diff --git a/association.go b/association.go index 45d7367a..1ba1d519 100644 --- a/association.go +++ b/association.go @@ -126,14 +126,6 @@ func (association *Association) Replace(values ...interface{}) *Association { } } else { // Relations - var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - if relationship.PolymorphicDBName != "" { newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } @@ -164,8 +156,22 @@ func (association *Association) Replace(values ...interface{}) *Association { } if relationship.Kind == "many_to_many" { + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + var foreignKeyMap = map[string]interface{}{} + for idx, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + fieldValue := reflect.New(association.Field.Field.Type()).Interface() association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) } diff --git a/main.go b/main.go index 6db37dab..9fe6cf4e 100644 --- a/main.go +++ b/main.go @@ -384,6 +384,10 @@ func (s *DB) CreateTable(values ...interface{}) *DB { func (s *DB) DropTable(values ...interface{}) *DB { db := s.clone() for _, value := range values { + if tableName, ok := value.(string); ok { + db = db.Table(tableName) + } + db = db.NewScope(value).dropTable().db } return db diff --git a/model_struct.go b/model_struct.go index 15df1082..204da5e7 100644 --- a/model_struct.go +++ b/model_struct.go @@ -203,11 +203,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ";") + foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") } if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ";") + associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") } for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { @@ -343,11 +343,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ";") + tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") } if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ";") + tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go index 27997fb7..ea80326e 100644 --- a/multi_primary_keys_test.go +++ b/multi_primary_keys_test.go @@ -14,6 +14,7 @@ type Blog struct { Body string Tags []Tag `gorm:"many2many:blog_tags;"` SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` } type Tag struct { @@ -135,9 +136,9 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { blog2 := Blog{ ID: blog.ID, - Locale: "ZH", + Locale: "EN", } - DB.Save(&blog2) + DB.Create(&blog2) if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { t.Errorf("Blog should has two tags") @@ -233,3 +234,148 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } } } + +func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + DB.DropTable(&Blog{}, &Tag{}) + DB.DropTable("locale_blog_tags") + DB.CreateTable(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + LocaleTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) + if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog should has 0 tags after ZH Blog Append") + } + + var tags []Tag + DB.Model(&blog).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + DB.Model(&blog2).Related(&tags, "LocaleTags") + if len(tags) != 0 { + t.Errorf("Should find 0 tags with Related for EN Blog") + } + + var blog1 Blog + DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) + if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("LocaleTags").Append(tag4) + + DB.Model(&blog).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related for EN Blog") + } + + DB.Model(&blog2).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag4"}) { + t.Errorf("Should find 1 tags with Related for EN Blog") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) + + var tags2 []Tag + DB.Model(&blog).Related(&tags2, "LocaleTags") + if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") + } + + var blog11 Blog + DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) + if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") + } + + DB.Model(&blog2).Related(&tags2, "LocaleTags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + var blog21 Blog + DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) + if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { + t.Errorf("EN Blog's tags should be changed after Replace") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after Replace") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Errorf("EN Blog should has two tags after Replace") + } + + // Delete + DB.Model(&blog).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") + } + + DB.Model(&blog2).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { + t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") + } + + // Clear + DB.Model(&blog2).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") + } + + DB.Model(&blog).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 0 { + t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog's tags should be cleared") + } + } +} diff --git a/preload.go b/preload.go index 9eaf0ab0..d12995f3 100644 --- a/preload.go +++ b/preload.go @@ -277,10 +277,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } } - var associationForeignStructFieldNames []string - for _, dbName := range relation.AssociationForeignFieldNames { + var foreignFieldNames []string + for _, dbName := range relation.ForeignFieldNames { if field, ok := scope.FieldByName(dbName); ok { - associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name) + foreignFieldNames = append(foreignFieldNames, field.Name) } } @@ -288,7 +288,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - source := getRealValue(object, associationForeignStructFieldNames) + source := getRealValue(object, foreignFieldNames) field := object.FieldByName(field.Name) for _, link := range linkHash[toString(source)] { field.Set(reflect.Append(field, link)) @@ -296,7 +296,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } } else { if object := scope.IndirectValue(); object.IsValid() { - source := getRealValue(object, associationForeignStructFieldNames) + source := getRealValue(object, foreignFieldNames) field := object.FieldByName(field.Name) for _, link := range linkHash[toString(source)] { field.Set(reflect.Append(field, link))