diff --git a/association.go b/association.go index dfc79598..45d7367a 100644 --- a/association.go +++ b/association.go @@ -143,7 +143,7 @@ func (association *Association) Replace(values ...interface{}) *Association { var newPrimaryKeys [][]interface{} var associationForeignFieldNames []string - if relationship.Kind == "many2many" { + if relationship.Kind == "many_to_many" { // If many to many relations, get it from foreign key associationForeignFieldNames = relationship.AssociationForeignFieldNames } else { @@ -156,6 +156,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface()) + if len(newPrimaryKeys) > 0 { sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) diff --git a/model_struct.go b/model_struct.go index e98f0718..89e7a169 100644 --- a/model_struct.go +++ b/model_struct.go @@ -85,15 +85,14 @@ func (structField *StructField) clone() *StructField { } type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignStructFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface + Kind string + PolymorphicType string + PolymorphicDBName string + ForeignFieldNames []string + ForeignDBNames []string + AssociationForeignFieldNames []string + AssociationForeignDBNames []string + JoinTableHandler JoinTableHandlerInterface } func (scope *Scope) GetModelStruct() *ModelStruct { @@ -263,7 +262,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, name := range associationForeignKeys { if field, ok := toScope.FieldByName(name); ok { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - relationship.AssociationForeignStructFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go index 9ca68d13..061822ef 100644 --- a/multi_primary_keys_test.go +++ b/multi_primary_keys_test.go @@ -1,8 +1,9 @@ package gorm_test import ( - "fmt" "os" + "reflect" + "sort" "testing" ) @@ -20,10 +21,21 @@ type Tag struct { Value string } +func compareTags(tags []Tag, contents []string) bool { + var tagContents []string + for _, tag := range tags { + tagContents = append(tagContents, tag.Value) + } + sort.Strings(tagContents) + sort.Strings(contents) + return reflect.DeepEqual(tagContents, contents) +} + func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { - DB.Exec(fmt.Sprintf("drop table blog_tags;")) - DB.AutoMigrate(&Blog{}, &Tag{}) + DB.DropTable(&Blog{}, &Tag{}) + DB.DropTable("blog_tags") + DB.CreateTable(&Blog{}, &Tag{}) blog := Blog{ Locale: "ZH", Subject: "subject", @@ -35,12 +47,70 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { } DB.Save(&blog) - DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}}) + if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { + t.Errorf("Blog should has two tags") + } + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) + if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("Tags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } var tags []Tag DB.Model(&blog).Related(&tags, "Tags") - if len(tags) != 3 { - t.Errorf("should found 3 tags with blog") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + var blog1 Blog + DB.Preload("Tags").Find(&blog1) + if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Preload many2many relations") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog).Association("Tags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Related(&tags2, "Tags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("Tags").Count() != 2 { + t.Errorf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("Tags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Related(&tags3, "Tags") + if !compareTags(tags3, []string{"tag6"}) { + t.Errorf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("Tags").Count() != 1 { + t.Errorf("Blog should has three tags after Delete") + } + + DB.Model(&blog).Association("Tags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Related(&tags4, "Tags") + if !compareTags(tags4, []string{"tag6"}) { + t.Errorf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog).Association("Tags").Clear() + if DB.Model(&blog).Association("Tags").Count() != 0 { + t.Errorf("All tags should be cleared") } } }