diff --git a/association.go b/association.go index ff1e155f..47ec500e 100644 --- a/association.go +++ b/association.go @@ -195,6 +195,8 @@ func (association *Association) Delete(values ...interface{}) error { foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateAttrs[ref.ForeignKey.DBName] = nil } + } else { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryKey}) } } @@ -208,6 +210,15 @@ func (association *Association) Delete(values ...interface{}) error { conds := rel.ToQueryConditions(reflectValue) tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) case schema.BelongsTo: + primaryKeys := []string{} + for _, field := range rel.Schema.PrimaryFields { + primaryKeys = append(primaryKeys, field.DBName) + } + _, queryValues := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + if column, values := schema.ToQueryValues(primaryKeys, queryValues); len(values) > 0 { + tx.Where(clause.IN{Column: column, Values: values}) + } + modelValue := reflect.New(rel.Schema.ModelType).Interface() tx.Model(modelValue).UpdateColumns(updateAttrs) case schema.Many2Many: @@ -353,7 +364,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } selectedColumns := []string{association.Relationship.Name} - hasZero := false for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey { selectedColumns = append(selectedColumns, ref.ForeignKey.Name) @@ -375,13 +385,16 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ break } association.Error = errors.New("invalid association values, length doesn't match") + return } for i := 0; i < reflectValue.Len(); i++ { appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) - if !hasZero { - _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue.Index(i)) + if len(values) > 0 { + // TODO support save slice data, sql with case + err := association.DB.Session(&Session{}).Select(selectedColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error + association.DB.AddError(err) } } case reflect.Struct: @@ -399,13 +412,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToRelations(reflectValue, rv, clear && idx == 0) } - _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) - } - - if len(values) > 0 { - if hasZero { - association.DB.Create(reflectValue.Addr().Interface()) - } else { + if len(values) > 0 { association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 77a5ce47..c67e79c8 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -164,20 +164,49 @@ func TestBelongsToAssociationForSlice(t *testing.T) { // Find var companies []Company - if DB.Model(users).Association("Company").Find(&companies); len(companies) != 3 { + if DB.Model(&users).Association("Company").Find(&companies); len(companies) != 3 { t.Errorf("companies count should be %v, but got %v", 3, len(companies)) } var managers []User - if DB.Model(users).Association("Manager").Find(&managers); len(managers) != 2 { + if DB.Model(&users).Association("Manager").Find(&managers); len(managers) != 2 { t.Errorf("managers count should be %v, but got %v", 2, len(managers)) } // Append + DB.Model(&users).Association("Company").Append( + &Company{Name: "company-slice-append-1"}, + &Company{Name: "company-slice-append-2"}, + &Company{Name: "company-slice-append-3"}, + ) - // Replace + AssertAssociationCount(t, users, "Company", 3, "After Append") + + DB.Model(&users).Association("Manager").Append( + GetUser("manager-slice-belongs-to-1", Config{}), + GetUser("manager-slice-belongs-to-2", Config{}), + GetUser("manager-slice-belongs-to-3", Config{}), + ) + AssertAssociationCount(t, users, "Manager", 3, "After Append") + + if err := DB.Model(&users).Association("Manager").Append( + GetUser("manager-slice-belongs-to-test-1", Config{}), + ).Error; err == nil { + t.Errorf("unmatched length when update user's manager") + } + + // Replace -> same as append // Delete + if err := DB.Model(&users).Association("Company").Delete(&users[0].Company); err != nil { + t.Errorf("no error should happend when deleting company, but got %v", err) + } + + if users[0].CompanyID != nil || users[0].Company.ID != 0 { + t.Errorf("users[0]'s company should be deleted'") + } + + AssertAssociationCount(t, users, "Company", 2, "After Delete") // Clear DB.Model(&users).Association("Company").Clear() @@ -185,4 +214,22 @@ func TestBelongsToAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Manager").Clear() AssertAssociationCount(t, users, "Manager", 0, "After Clear") + + // shared company + company := Company{Name: "shared"} + if err := DB.Model(&users[0]).Association("Company").Append(&company); err != nil { + t.Errorf("Error happened when append company to user, got %v", err) + } + + if err := DB.Model(&users[1]).Association("Company").Append(&company); err != nil { + t.Errorf("Error happened when append company to user, got %v", err) + } + + if users[0].CompanyID == nil || users[1].CompanyID == nil || *users[0].CompanyID != *users[1].CompanyID { + t.Errorf("user's company id should exists and equal, but its: %v, %v", users[0].CompanyID, users[1].CompanyID) + } + + DB.Model(&users[0]).Association("Company").Delete(&company) + AssertAssociationCount(t, users[0], "Company", 0, "After Delete") + AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") }