diff --git a/association.go b/association.go index 4bc7269f..1d3777f5 100644 --- a/association.go +++ b/association.go @@ -82,85 +82,65 @@ func (association *Association) Append(values ...interface{}) *Association { } func (association *Association) Replace(values ...interface{}) *Association { - relationship := association.Field.Relationship - scope := association.Scope - field := association.Field.Field + var ( + relationship = association.Field.Relationship + scope = association.Scope + field = association.Field.Field + newDB = scope.NewDB() + ) - // get old primary keys - oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) - - // append new values + // Append new values association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Append(values...) - // get new primary keys - var newPrimaryKeys [][]interface{} - if len(values) > 0 { - newPrimaryKeys = association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) - } - - var addedPrimaryKeys = [][]interface{}{} - for _, newKey := range newPrimaryKeys { - hasEqual := false - for _, oldKey := range oldPrimaryKeys { - if equalAsString(newKey, oldKey) { - hasEqual = true - break - } - } - if !hasEqual { - addedPrimaryKeys = append(addedPrimaryKeys, newKey) - } - } - - for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) { - addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) - } - - query := scope.NewDB() - var foreignKeyMap = map[string]interface{}{} - + // Belongs To if relationship.Kind == "belongs_to" { - for idx, foreignKey := range relationship.AssociationForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + // Set foreign key to be null only when clearing value + if len(values) == 0 { + // Set foreign key to be nil + var foreignKeyMap = map[string]interface{}{} + for _, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil } + association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) } - - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - - if len(addedPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(addedPrimaryKeys)) - query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) - } - - modelValue := scope.Value - // if replacing with a new value, don't reset current foreign key - // if clearing foreign value, then reset the foreign key to null - if len(values) > 0 { - modelValue = reflect.New(scope.GetModelStruct().ModelType).Interface() - } - association.setErr(query.Model(modelValue).UpdateColumn(foreignKeyMap).Error) } else { + // Relations + var foreignKeyMap = map[string]interface{}{} for idx, foreignKey := range relationship.ForeignDBNames { foreignKeyMap[foreignKey] = nil if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) } } - if len(addedPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) - query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) + // Relations except new created + if len(values) > 0 { + var newPrimaryKeys [][]interface{} + var associationForeignFieldNames []string + + if relationship.Kind == "many2many" { + // If many to many relations, get it from foreign key + associationForeignFieldNames = relationship.AssociationForeignFieldNames + } else { + // If other relations, get real primary keys + for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() { + if field.IsPrimaryKey { + associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + } + } + } + + newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface()) + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) + newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) } if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { fieldValue := reflect.New(association.Field.Field.Type()).Interface() - association.setErr(query.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + association.setErr(newDB.Debug().Model(fieldValue).UpdateColumn(foreignKeyMap).Error) } } return association diff --git a/association_test.go b/association_test.go index 92854d89..37ed1bb8 100644 --- a/association_test.go +++ b/association_test.go @@ -8,12 +8,9 @@ import ( ) func TestBelongsTo(t *testing.T) { - DB.DropTable(Category{}, Post{}) - DB.CreateTable(Category{}, Post{}) - post := Post{ - Title: "post 1", - Body: "body 1", + Title: "post belongs to", + Body: "body belongs to", Category: Category{Name: "Category 1"}, MainCategory: Category{Name: "Main Category 1"}, } @@ -34,19 +31,19 @@ func TestBelongsTo(t *testing.T) { var category1 Category DB.Model(&post).Association("Category").Find(&category1) if category1.Name != "Category 1" { - t.Errorf("Query has one relations with Association") + t.Errorf("Query belongs to relations with Association") } var mainCategory1 Category DB.Model(&post).Association("MainCategory").Find(&mainCategory1) if mainCategory1.Name != "Main Category 1" { - t.Errorf("Query has one relations with Association") + t.Errorf("Query belongs to relations with Association") } var category11 Category DB.Model(&post).Related(&category11) if category11.Name != "Category 1" { - t.Errorf("Query has one relations with Related") + t.Errorf("Query belongs to relations with Related") } // Append @@ -114,12 +111,9 @@ func TestBelongsTo(t *testing.T) { } func TestHasMany(t *testing.T) { - DB.DropTable(Post{}, Comment{}) - DB.CreateTable(Post{}, Comment{}) - post := Post{ - Title: "post 1", - Body: "body 1", + Title: "post has many", + Body: "body has many", Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, } @@ -213,51 +207,69 @@ func TestHasMany(t *testing.T) { } } -func TestHasOneAndHasManyAssociation(t *testing.T) { - DB.DropTable(Category{}, Post{}, Comment{}) - DB.CreateTable(Category{}, Post{}, Comment{}) - - post := Post{ - Title: "post 1", - Body: "body 1", - Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, - Category: Category{Name: "Category 1"}, - MainCategory: Category{Name: "Main Category 1"}, +func TestHasOne(t *testing.T) { + user := User{ + Name: "has one", + CreditCard: CreditCard{Number: "411111111111"}, } - if err := DB.Save(&post).Error; err != nil { - t.Errorf("Got errors when save post", err.Error()) + if err := DB.Save(&user).Error; err != nil { + t.Errorf("Got errors when save user", err.Error()) } - if err := DB.First(&Category{}, "name = ?", "Category 1").Error; err != nil { - t.Errorf("Category should be saved", err.Error()) + if user.CreditCard.UserId.Int64 == 0 { + t.Errorf("CreditCard's foreign key should be updated") } - var p Post - DB.First(&p, post.Id) + // Query + var creditCard1 CreditCard + DB.Model(&user).Related(&creditCard1) - if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 { - t.Errorf("Category Id should exist") + if creditCard1.Number != "411111111111" { + t.Errorf("Query has one relations with Related") } - if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { - t.Errorf("Comment 1 should be saved") - } - if post.Comments[0].PostId == 0 { - t.Errorf("Comment Should have post id") + var creditCard11 CreditCard + DB.Model(&user).Association("CreditCard").Find(&creditCard11) + + if creditCard11.Number != "411111111111" { + t.Errorf("Query has one relations with Related") } - var comment Comment - if DB.First(&comment, "content = ?", "Comment 2").Error != nil { - t.Errorf("Comment 2 should be saved") + // Append + var creditcard2 = CreditCard{ + Number: "411111111112", + } + DB.Model(&user).Association("CreditCard").Append(&creditcard2) + + if creditcard2.ID == 0 { + t.Errorf("Creditcard should has ID when created with Append") } - if comment.PostId == 0 { - t.Errorf("Comment 2 Should have post id") + var creditcard21 CreditCard + DB.Model(&user).Related(&creditcard21) + if creditcard21.Number != "411111111112" { + t.Errorf("CreditCard should be updated with Append") } - comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}} - DB.Save(&comment3) + // Replace + var creditcard3 = CreditCard{ + Number: "411111111113", + } + DB.Model(&user).Association("CreditCard").Replace(&creditcard3) + + if creditcard3.ID == 0 { + t.Errorf("Creditcard should has ID when created with Replace") + } + + var creditcard31 CreditCard + DB.Model(&user).Related(&creditcard31) + if creditcard31.Number != "411111111113" { + t.Errorf("CreditCard should be updated with Replace") + } + + // Delete + // Clear } func TestRelated(t *testing.T) {