diff --git a/association.go b/association.go index c179a148..ff1e155f 100644 --- a/association.go +++ b/association.go @@ -366,6 +366,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey { + ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } } break } @@ -382,6 +387,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ case reflect.Struct: if clear && len(values) == 0 { association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey { + ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } } for idx, value := range values { @@ -392,10 +402,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) } - if hasZero { - association.DB.Save(reflectValue.Addr().Interface()) - } else { - association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface()) + if len(values) > 0 { + if hasZero { + association.DB.Create(reflectValue.Addr().Interface()) + } else { + association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) + } } for _, assignBack := range assignBacks { diff --git a/callbacks/update.go b/callbacks/update.go index be9fe30a..6a59e487 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -173,10 +173,28 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } if stmt.Dest != stmt.Model { - reflectValue := reflect.ValueOf(stmt.Model) - for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(reflectValue); !isZero { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + var priamryKeyExprs []clause.Expression + for i := 0; i < reflectValue.Len(); i++ { + var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) + var notZero bool + for idx, field := range stmt.Schema.PrimaryFields { + value, isZero := field.ValueOf(reflectValue.Index(i)) + exprs[idx] = clause.Eq{Column: field.DBName, Value: value} + notZero = notZero || !isZero + } + if notZero { + priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...)) + } + } + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) + case reflect.Struct: + for _, field := range stmt.Schema.PrimaryFields { + if value, isZero := field.ValueOf(reflectValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } } } } diff --git a/errors.go b/errors.go index a990cc4a..4f2bd4fa 100644 --- a/errors.go +++ b/errors.go @@ -19,4 +19,6 @@ var ( ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") // ErrUnsupportedRelation unsupported relations ErrUnsupportedRelation = errors.New("unsupported relations") + // ErrPtrStructSupported only ptr of struct supported + ErrPtrStructSupported = errors.New("only ptr of struct supported") ) diff --git a/finisher_api.go b/finisher_api.go index 6a787576..c64ecdda 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -23,12 +23,17 @@ func (db *DB) Save(value interface{}) (tx *DB) { if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} - reflectValue := reflect.ValueOf(value) - for idx, pf := range tx.Statement.Schema.PrimaryFields { - if pv, isZero := pf.ValueOf(reflectValue); isZero { - tx.callbacks.Create().Execute(tx) - where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} - return + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + tx.AddError(ErrPtrStructSupported) + case reflect.Struct: + for idx, pf := range tx.Statement.Schema.PrimaryFields { + if pv, isZero := pf.ValueOf(reflectValue); isZero { + tx.callbacks.Create().Execute(tx) + where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} + return + } } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 159f7f3a..77a5ce47 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -6,7 +6,26 @@ import ( . "github.com/jinzhu/gorm/tests" ) -func TestAssociationForBelongsTo(t *testing.T) { +func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { + if count := DB.Model(data).Association(name).Count(); count != result { + t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + } + + var newUser User + if user, ok := data.(User); ok { + DB.Find(&newUser, "id = ?", user.ID) + } else if user, ok := data.(*User); ok { + DB.Find(&newUser, "id = ?", user.ID) + } + + if newUser.ID != 0 { + if count := DB.Model(&newUser).Association(name).Count(); count != result { + t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + } + } +} + +func TestBelongsToAssociation(t *testing.T) { var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) if err := DB.Create(&user).Error; err != nil { @@ -24,13 +43,8 @@ func TestAssociationForBelongsTo(t *testing.T) { CheckUser(t, user2, user) // Count - if count := DB.Model(&user).Association("Company").Count(); count != 1 { - t.Errorf("invalid company count, got %v", count) - } - - if count := DB.Model(&user).Association("Manager").Count(); count != 1 { - t.Errorf("invalid manager count, got %v", count) - } + AssertAssociationCount(t, user, "Company", 1, "") + AssertAssociationCount(t, user, "Manager", 1, "") // Append var company = Company{Name: "company-belongs-to-append"} @@ -58,6 +72,9 @@ func TestAssociationForBelongsTo(t *testing.T) { user.ManagerID = &manager.ID CheckUser(t, user2, user) + AssertAssociationCount(t, user2, "Company", 1, "AfterAppend") + AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") + // Replace var company2 = Company{Name: "company-belongs-to-replace"} var manager2 = GetUser("manager-belongs-to-replace", Config{}) @@ -84,40 +101,31 @@ func TestAssociationForBelongsTo(t *testing.T) { user.ManagerID = &manager2.ID CheckUser(t, user2, user) + AssertAssociationCount(t, user2, "Company", 1, "AfterReplace") + AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace") + // Delete if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { t.Fatalf("Error happened when delete Company, got %v", err) } - - if count := DB.Model(&user2).Association("Company").Count(); count != 1 { - t.Errorf("Invalid company count after delete non-existing association, got %v", count) - } + AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { t.Fatalf("Error happened when delete Company, got %v", err) } - - if count := DB.Model(&user2).Association("Company").Count(); count != 0 { - t.Errorf("Invalid company count after delete, got %v", count) - } + AssertAssociationCount(t, user2, "Company", 0, "after delete") if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { t.Fatalf("Error happened when delete Manager, got %v", err) } - - if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { - t.Errorf("Invalid manager count after delete non-existing association, got %v", count) - } + AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { t.Fatalf("Error happened when delete Manager, got %v", err) } + AssertAssociationCount(t, user2, "Manager", 0, "after delete") - if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { - t.Errorf("Invalid manager count after delete, got %v", count) - } - - // Prepare Data + // Prepare Data for Clear if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { t.Fatalf("Error happened when append Company, got %v", err) } @@ -126,13 +134,8 @@ func TestAssociationForBelongsTo(t *testing.T) { t.Fatalf("Error happened when append Manager, got %v", err) } - if count := DB.Model(&user2).Association("Company").Count(); count != 1 { - t.Errorf("Invalid company count after append, got %v", count) - } - - if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { - t.Errorf("Invalid manager count after append, got %v", count) - } + AssertAssociationCount(t, user2, "Company", 1, "after prepare data") + AssertAssociationCount(t, user2, "Manager", 1, "after prepare data") // Clear if err := DB.Model(&user2).Association("Company").Clear(); err != nil { @@ -143,11 +146,43 @@ func TestAssociationForBelongsTo(t *testing.T) { t.Errorf("Error happened when clear Manager, got %v", err) } - if count := DB.Model(&user2).Association("Company").Count(); count != 0 { - t.Errorf("Invalid company count after clear, got %v", count) + AssertAssociationCount(t, user2, "Company", 0, "after clear") + AssertAssociationCount(t, user2, "Manager", 0, "after clear") +} + +func TestBelongsToAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), + *GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), + *GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), } - if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { - t.Errorf("Invalid manager count after clear, got %v", count) + DB.Create(&users) + + AssertAssociationCount(t, users, "Company", 3, "") + AssertAssociationCount(t, users, "Manager", 2, "") + + // Find + var companies []Company + if DB.Model(users).Association("Company").Find(&companies); len(companies) != 3 { + t.Errorf("companies count should be %v, but got %v", 3, len(companies)) } + + var managers []User + if DB.Model(users).Association("Manager").Find(&managers); len(managers) != 2 { + t.Errorf("managers count should be %v, but got %v", 2, len(managers)) + } + + // Append + + // Replace + + // Delete + + // Clear + DB.Model(&users).Association("Company").Clear() + AssertAssociationCount(t, users, "Company", 0, "After Clear") + + DB.Model(&users).Association("Manager").Clear() + AssertAssociationCount(t, users, "Manager", 0, "After Clear") }