diff --git a/association.go b/association.go index 47ec500e..c90258ec 100644 --- a/association.go +++ b/association.go @@ -97,28 +97,34 @@ func (association *Association) Replace(values ...interface{}) error { } case schema.HasOne, schema.HasMany: var ( - primaryFields []*schema.Field - foreignKeys []string - updateMap = map[string]interface{}{} - modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + tx = association.DB + primaryFields []*schema.Field + foreignKeys []string + updateMap = map[string]interface{}{} + relPrimaryKeys = []string{} + relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() ) - if rel.Type == schema.BelongsTo { - modelValue = reflect.New(rel.Schema.ModelType).Interface() + + for _, field := range rel.FieldSchema.PrimaryFields { + relPrimaryKeys = append(relPrimaryKeys, field.DBName) + } + if _, qvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(qvs) > 0 { + if column, values := schema.ToQueryValues(relPrimaryKeys, qvs); len(values) > 0 { + tx = tx.Not(clause.IN{Column: column, Values: values}) + } } for _, ref := range rel.References { if ref.OwnPrimaryKey { primaryFields = append(primaryFields, ref.PrimaryKey) - } else { foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateMap[ref.ForeignKey.DBName] = nil } } - - _, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - if len(values) == 0 { - column, queryValues := schema.ToQueryValues(foreignKeys, values) - association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) + if _, qvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(qvs) > 0 { + column, values := schema.ToQueryValues(foreignKeys, qvs) + tx.Model(modelValue).Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap) } case schema.Many2Many: var primaryFields, relPrimaryFields []*schema.Field @@ -413,7 +419,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) + association.DB.Session(&Session{}).Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) } } diff --git a/callbacks/associations.go b/callbacks/associations.go index 37addd60..2342f110 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -124,6 +124,8 @@ func SaveAfterAssociations(db *gorm.DB) { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { elems = reflect.Append(elems, rv) + } else { + db.Session(&gorm.Session{}).Save(rv.Interface()) } } } @@ -149,6 +151,8 @@ func SaveAfterAssociations(db *gorm.DB) { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero { db.Session(&gorm.Session{}).Create(f.Interface()) + } else { + db.Session(&gorm.Session{}).Save(f.Interface()) } } } @@ -187,6 +191,8 @@ func SaveAfterAssociations(db *gorm.DB) { } else { elems = reflect.Append(elems, elem.Addr()) } + } else { + db.Session(&gorm.Session{}).Save(elem.Interface()) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 6a59e487..f9b20981 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -45,7 +45,11 @@ func BeforeUpdate(db *gorm.DB) { func Update(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Update{}) - db.Statement.AddClause(ConvertToAssignments(db.Statement)) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } db.Statement.Build("UPDATE", "SET", "WHERE") result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -198,5 +202,6 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } + return } diff --git a/clause/expression.go b/clause/expression.go index 8150f838..872736ce 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -55,9 +55,11 @@ func (in IN) NegationBuild(builder Builder) { switch len(in.Values) { case 0: case 1: + builder.WriteQuoted(in.Column) builder.WriteString(" <> ") builder.AddVar(builder, in.Values...) default: + builder.WriteQuoted(in.Column) builder.WriteString(" NOT IN (") builder.AddVar(builder, in.Values...) builder.WriteByte(')') diff --git a/schema/field.go b/schema/field.go index 9a5f1fc6..8b8b190d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -603,32 +603,40 @@ func (field *Field) setupValuerAndSetter() { if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner field.Set = func(value reflect.Value, v interface{}) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { + if v == nil { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + } + } else { err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } - } else { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { + if v == nil { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) + } + } else { err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) } - } else { - err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) } return } diff --git a/tests/associations_test.go b/tests/associations_test.go index c67e79c8..137b2c50 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -233,3 +233,79 @@ func TestBelongsToAssociationForSlice(t *testing.T) { AssertAssociationCount(t, users[0], "Company", 0, "After Delete") AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") } + +func TestHasOneAssociation(t *testing.T) { + var user = *GetUser("hasone", Config{Account: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Account").Find(&user2.Account) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Account", 1, "") + + // Append + var account = Account{Number: "account-has-one-append"} + + if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if account.ID == 0 { + t.Fatalf("Account's ID should be created") + } + + user.Account = account + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Account", 1, "AfterAppend") + + // Replace + var account2 = Account{Number: "account-has-one-replace"} + + if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil { + t.Fatalf("Error happened when append Account, got %v", err) + } + + if account2.ID == 0 { + t.Fatalf("account2's ID should be created") + } + + user.Account = account2 + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Account", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Account").Delete(&Company{}); err != nil { + t.Fatalf("Error happened when delete account, got %v", err) + } + AssertAssociationCount(t, user2, "Account", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Account").Delete(&account2); err != nil { + t.Fatalf("Error happened when delete Account, got %v", err) + } + AssertAssociationCount(t, user2, "Account", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { + t.Fatalf("Error happened when append Account, got %v", err) + } + + AssertAssociationCount(t, user2, "Account", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Account").Clear(); err != nil { + t.Errorf("Error happened when clear Account, got %v", err) + } + + AssertAssociationCount(t, user2, "Account", 0, "after clear") +}