diff --git a/association_test.go b/association_test.go index 90ad4298..52d2303f 100644 --- a/association_test.go +++ b/association_test.go @@ -191,12 +191,12 @@ func TestBelongsToOverrideForeignKey1(t *testing.T) { ProfileRefer int } - DB.AutoMigrate(&User{}) - DB.AutoMigrate(&Profile{}) - - var user = User{Model: gorm.Model{ID: 1}, ProfileRefer: 10} - if err := DB.Model(&user).Association("Profile").Find(&[]Profile{}).Error; err != nil { - t.Errorf("Override belongs to foreign key with tag") + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "belongs_to" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { + t.Errorf("Override belongs to foreign key with tag") + } } } @@ -213,12 +213,12 @@ func TestBelongsToOverrideForeignKey2(t *testing.T) { ProfileID int } - DB.AutoMigrate(&User{}) - DB.AutoMigrate(&Profile{}) - - var user = User{Model: gorm.Model{ID: 1}, ProfileID: 10} - if err := DB.Model(&user).Association("Profile").Find(&[]Profile{}).Error; err != nil { - t.Errorf("Override belongs to foreign key with tag") + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "belongs_to" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { + t.Errorf("Override belongs to foreign key with tag") + } } } @@ -368,6 +368,49 @@ func TestHasOne(t *testing.T) { } } +func TestHasOneOverrideForeignKey1(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:UserRefer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_one" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestHasOneOverrideForeignKey2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_one" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + func TestHasMany(t *testing.T) { post := Post{ Title: "post has many", @@ -507,6 +550,49 @@ func TestHasMany(t *testing.T) { } } +func TestHasManyOverrideForeignKey1(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile []Profile `gorm:"ForeignKey:UserRefer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_many" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestHasManyOverrideForeignKey2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_many" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + func TestManyToMany(t *testing.T) { DB.Raw("delete from languages") var languages = []Language{{Name: "ZH"}, {Name: "EN"}} diff --git a/model_struct.go b/model_struct.go index a0d52352..6df615d1 100644 --- a/model_struct.go +++ b/model_struct.go @@ -298,7 +298,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if len(associationForeignKeys) == 0 { for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, associationType) { - associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) + associationForeignKey := strings.TrimPrefix(foreignKey, associationType) + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + associationForeignKeys = append(associationForeignKeys, associationForeignKey) + } } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { @@ -391,7 +394,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if len(associationForeignKeys) == 0 { for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, associationType) { - associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) + associationForeignKey := strings.TrimPrefix(foreignKey, associationType) + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + associationForeignKeys = append(associationForeignKeys, associationForeignKey) + } } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {