From 91a695893c4c5c5e830631fa58d63b9a26d50aed Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 May 2020 17:24:23 +0800 Subject: [PATCH] Test Association For BelongsTo --- association.go | 79 +++++++++++++++++------- callbacks/associations.go | 2 +- callbacks/helper.go | 2 +- callbacks/update.go | 30 +++++++-- gorm.go | 11 ++-- schema/field.go | 29 +++++---- schema/relationship.go | 1 + statement.go | 33 ++++++++++ tests/associations_test.go | 121 +++++++++++++++++++++++++++++++++++++ tests/count_test.go | 2 +- 10 files changed, 265 insertions(+), 45 deletions(-) diff --git a/association.go b/association.go index bd2a7cdd..c179a148 100644 --- a/association.go +++ b/association.go @@ -19,8 +19,10 @@ type Association struct { func (db *DB) Association(column string) *Association { association := &Association{DB: db} + table := db.Statement.Table if err := db.Statement.Parse(db.Statement.Model); err == nil { + db.Statement.Table = table association.Relationship = db.Statement.Schema.Relationships.Relations[column] if association.Relationship == nil { @@ -83,6 +85,16 @@ func (association *Association) Replace(values ...interface{}) error { rel := association.Relationship switch rel.Type { + case schema.BelongsTo: + if len(values) == 0 { + updateMap := map[string]interface{}{} + + for _, ref := range rel.References { + updateMap[ref.ForeignKey.DBName] = nil + } + + association.DB.UpdateColumns(updateMap) + } case schema.HasOne, schema.HasMany: var ( primaryFields []*schema.Field @@ -90,6 +102,9 @@ func (association *Association) Replace(values ...interface{}) error { updateMap = map[string]interface{}{} modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() ) + if rel.Type == schema.BelongsTo { + modelValue = reflect.New(rel.Schema.ModelType).Interface() + } for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -101,7 +116,7 @@ func (association *Association) Replace(values ...interface{}) error { } _, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - if len(values) > 0 { + if len(values) == 0 { column, queryValues := schema.ToQueryValues(foreignKeys, values) association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) } @@ -158,13 +173,13 @@ func (association *Association) Replace(values ...interface{}) error { func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { var ( - tx = association.DB - rel = association.Relationship - reflectValue = tx.Statement.ReflectValue - conds = rel.ToQueryConditions(reflectValue) - relFields []*schema.Field - foreignKeys []string - updateAttrs = map[string]interface{}{} + tx = association.DB + rel = association.Relationship + reflectValue = tx.Statement.ReflectValue + relFields []*schema.Field + foreignKeyFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} ) for _, ref := range rel.References { @@ -174,6 +189,7 @@ func (association *Association) Delete(values ...interface{}) error { relFields = append(relFields, ref.ForeignKey) } else { relFields = append(relFields, ref.PrimaryKey) + foreignKeyFields = append(foreignKeyFields, ref.ForeignKey) } foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) @@ -189,11 +205,14 @@ func (association *Association) Delete(values ...interface{}) error { switch rel.Type { case schema.HasOne, schema.HasMany: modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + conds := rel.ToQueryConditions(reflectValue) tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) case schema.BelongsTo: - tx.Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) + modelValue := reflect.New(rel.Schema.ModelType).Interface() + tx.Model(modelValue).UpdateColumns(updateAttrs) case schema.Many2Many: modelValue := reflect.New(rel.JoinTable.ModelType).Interface() + conds := rel.ToQueryConditions(reflectValue) tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) } @@ -216,13 +235,16 @@ func (association *Association) Delete(values ...interface{}) error { } } - rel.Field.Set(data, validFieldValues) + rel.Field.Set(data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range relFields { - fieldValues[idx], _ = field.ValueOf(data) + fieldValues[idx], _ = field.ValueOf(fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { - rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType)) + rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()) + for _, field := range foreignKeyFields { + field.Set(data, reflect.Zero(field.FieldType).Interface()) + } } } } @@ -275,7 +297,11 @@ func (association *Association) Count() (count int64) { } func (association *Association) saveAssociation(clear bool, values ...interface{}) { - reflectValue := association.DB.Statement.ReflectValue + var ( + reflectValue = association.DB.Statement.ReflectValue + assignBacks = [][2]reflect.Value{} + assignBack = association.Relationship.Field.FieldType.Kind() == reflect.Struct + ) appendToRelations := func(source, rv reflect.Value, clear bool) { switch association.Relationship.Type { @@ -283,10 +309,16 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() > 0 { - association.Error = association.Relationship.Field.Set(source, rv.Index(0)) + association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) + if assignBack { + assignBacks = append(assignBacks, [2]reflect.Value{source, rv.Index(0)}) + } } case reflect.Struct: - association.Error = association.Relationship.Field.Set(source, rv) + association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) + if assignBack { + assignBacks = append(assignBacks, [2]reflect.Value{source, rv}) + } } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() @@ -315,7 +347,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if association.Error == nil { - association.Error = association.Relationship.Field.Set(source, fieldValue) + association.Error = association.Relationship.Field.Set(source, fieldValue.Addr().Interface()) } } } @@ -333,7 +365,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if len(values) != reflectValue.Len() { if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { - association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType)) + association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) } break } @@ -349,19 +381,24 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case reflect.Struct: if clear && len(values) == 0 { - association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType)) + association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) } for idx, value := range values { - appendToRelations(reflectValue, reflect.Indirect(reflect.ValueOf(value)), clear && idx == 0) + rv := reflect.Indirect(reflect.ValueOf(value)) + appendToRelations(reflectValue, rv, clear && idx == 0) } _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) } if hasZero { - association.DB.Save(reflectValue.Interface()) + association.DB.Save(reflectValue.Addr().Interface()) } else { - association.DB.Select(selectedColumns).Save(reflectValue.Interface()) + association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface()) + } + + for _, assignBack := range assignBacks { + reflect.Indirect(assignBack[1]).Set(association.Relationship.Field.ReflectValueOf(assignBack[0])) } } diff --git a/callbacks/associations.go b/callbacks/associations.go index ef040b71..37addd60 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -73,8 +73,8 @@ func SaveBeforeAssociations(db *gorm.DB) { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { db.Session(&gorm.Session{}).Create(rv.Interface()) - setupReferences(db.Statement.ReflectValue, rv) } + setupReferences(db.Statement.ReflectValue, rv) } } } diff --git a/callbacks/helper.go b/callbacks/helper.go index 43e90b8a..8da74690 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -22,7 +22,7 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo break } - if field := stmt.Schema.LookUpField(column); field != nil { + if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = true } else { results[column] = true diff --git a/callbacks/update.go b/callbacks/update.go index 53c646e9..be9fe30a 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -7,6 +7,7 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) func BeforeUpdate(db *gorm.DB) { @@ -91,8 +92,27 @@ func AfterUpdate(db *gorm.DB) { // ConvertToAssignments convert to update assignments func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { - selectColumns, restricted := SelectAndOmitColumns(stmt, false, true) - reflectModelValue := reflect.ValueOf(stmt.Model) + var ( + selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) + reflectModelValue = reflect.Indirect(reflect.ValueOf(stmt.Model)) + assignValue func(field *schema.Field, value interface{}) + ) + + switch reflectModelValue.Kind() { + case reflect.Slice, reflect.Array: + assignValue = func(field *schema.Field, value interface{}) { + for i := 0; i < reflectModelValue.Len(); i++ { + field.Set(reflectModelValue.Index(i), value) + } + } + case reflect.Struct: + assignValue = func(field *schema.Field, value interface{}) { + field.Set(reflectModelValue, value) + } + default: + assignValue = func(field *schema.Field, value interface{}) { + } + } switch value := stmt.Dest.(type) { case map[string]interface{}: @@ -111,7 +131,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { value[k] = time.Now() } set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) - field.Set(reflectModelValue, value[k]) + assignValue(field, value[k]) } } else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) @@ -122,7 +142,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { now := time.Now() set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) - field.Set(reflectModelValue, now) + assignValue(field, now) } } default: @@ -140,7 +160,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if ok || !isZero { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - field.Set(reflectModelValue, value) + assignValue(field, value) } } } else { diff --git a/gorm.go b/gorm.go index f8c944af..1fa69383 100644 --- a/gorm.go +++ b/gorm.go @@ -105,11 +105,12 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { func (db *DB) Session(config *Session) *DB { var ( tx = db.getInstance() + stmt = tx.Statement.clone() txConfig = *tx.Config ) if config.Context != nil { - tx.Statement.Context = config.Context + stmt.Context = config.Context } if config.Logger != nil { @@ -120,9 +121,11 @@ func (db *DB) Session(config *Session) *DB { txConfig.NowFunc = config.NowFunc } - tx.Config = &txConfig - tx.clone = true - return tx + return &DB{ + Config: &txConfig, + Statement: stmt, + clone: true, + } } // WithContext change current instance db's context to ctx diff --git a/schema/field.go b/schema/field.go index 7b37733b..9a5f1fc6 100644 --- a/schema/field.go +++ b/schema/field.go @@ -372,19 +372,24 @@ func (field *Field) setupValuerAndSetter() { } recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - return setter(value, v) - } - } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) - } else if reflectV.Kind() == reflect.Ptr { - return field.Set(value, reflectV.Elem().Interface()) + if v == nil { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + reflectV := reflect.ValueOf(v) + + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + return setter(value, v) + } + } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + } else if reflectV.Kind() == reflect.Ptr { + return field.Set(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } } return err } diff --git a/schema/relationship.go b/schema/relationship.go index 59aaa7e4..d10bfe30 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -387,6 +387,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) column, values := ToQueryValues(relForeignKeys, foreignValues) + conds = append(conds, clause.IN{Column: column, Values: values}) return } diff --git a/statement.go b/statement.go index 0abf7a7e..d37622dd 100644 --- a/statement.go +++ b/statement.go @@ -278,6 +278,39 @@ func (stmt *Statement) Parse(value interface{}) (err error) { return err } +func (stmt *Statement) clone() *Statement { + newStmt := &Statement{ + DB: stmt.DB, + Table: stmt.Table, + Model: stmt.Model, + Dest: stmt.Dest, + ReflectValue: stmt.ReflectValue, + Clauses: map[string]clause.Clause{}, + Selects: stmt.Selects, + Omits: stmt.Omits, + Joins: map[string][]interface{}{}, + Preloads: map[string][]interface{}{}, + ConnPool: stmt.ConnPool, + Schema: stmt.Schema, + Context: stmt.Context, + RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, + } + + for k, c := range stmt.Clauses { + newStmt.Clauses[k] = c + } + + for k, p := range stmt.Preloads { + newStmt.Preloads[k] = p + } + + for k, j := range stmt.Joins { + newStmt.Joins[k] = j + } + + return newStmt +} + func (stmt *Statement) reinit() { // stmt.Table = "" // stmt.Model = nil diff --git a/tests/associations_test.go b/tests/associations_test.go index 845ee65e..159f7f3a 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -15,6 +15,7 @@ func TestAssociationForBelongsTo(t *testing.T) { CheckUser(t, user, user) + // Find var user2 User DB.Find(&user2, "id = ?", user.ID) DB.Model(&user2).Association("Company").Find(&user2.Company) @@ -22,6 +23,7 @@ func TestAssociationForBelongsTo(t *testing.T) { DB.Model(&user2).Association("Manager").Find(user2.Manager) CheckUser(t, user2, user) + // Count if count := DB.Model(&user).Association("Company").Count(); count != 1 { t.Errorf("invalid company count, got %v", count) } @@ -29,4 +31,123 @@ func TestAssociationForBelongsTo(t *testing.T) { if count := DB.Model(&user).Association("Manager").Count(); count != 1 { t.Errorf("invalid manager count, got %v", count) } + + // Append + var company = Company{Name: "company-belongs-to-append"} + var manager = GetUser("manager-belongs-to-append", Config{}) + + if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { + t.Fatalf("Error happened when append Company, got %v", err) + } + + if company.ID == 0 { + t.Fatalf("Company's ID should be created") + } + + if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { + t.Fatalf("Error happened when append Manager, got %v", err) + } + + if manager.ID == 0 { + t.Fatalf("Manager's ID should be created") + } + + user.Company = company + user.Manager = manager + user.CompanyID = &company.ID + user.ManagerID = &manager.ID + CheckUser(t, user2, user) + + // Replace + var company2 = Company{Name: "company-belongs-to-replace"} + var manager2 = GetUser("manager-belongs-to-replace", Config{}) + + if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { + t.Fatalf("Error happened when replace Company, got %v", err) + } + + if company2.ID == 0 { + t.Fatalf("Company's ID should be created") + } + + if err := DB.Model(&user2).Association("Manager").Replace(manager2); err != nil { + t.Fatalf("Error happened when replace Manager, got %v", err) + } + + if manager2.ID == 0 { + t.Fatalf("Manager's ID should be created") + } + + user.Company = company2 + user.Manager = manager2 + user.CompanyID = &company2.ID + user.ManagerID = &manager2.ID + CheckUser(t, user2, user) + + // Delete + if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { + t.Fatalf("Error happened when delete Company, got %v", err) + } + + if count := DB.Model(&user2).Association("Company").Count(); count != 1 { + t.Errorf("Invalid company count after delete non-existing association, got %v", count) + } + + if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { + t.Fatalf("Error happened when delete Company, got %v", err) + } + + if count := DB.Model(&user2).Association("Company").Count(); count != 0 { + t.Errorf("Invalid company count after delete, got %v", count) + } + + if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete Manager, got %v", err) + } + + if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { + t.Errorf("Invalid manager count after delete non-existing association, got %v", count) + } + + if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { + t.Fatalf("Error happened when delete Manager, got %v", err) + } + + if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { + t.Errorf("Invalid manager count after delete, got %v", count) + } + + // Prepare Data + if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { + t.Fatalf("Error happened when append Company, got %v", err) + } + + if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { + t.Fatalf("Error happened when append Manager, got %v", err) + } + + if count := DB.Model(&user2).Association("Company").Count(); count != 1 { + t.Errorf("Invalid company count after append, got %v", count) + } + + if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { + t.Errorf("Invalid manager count after append, got %v", count) + } + + // Clear + if err := DB.Model(&user2).Association("Company").Clear(); err != nil { + t.Errorf("Error happened when clear Company, got %v", err) + } + + if err := DB.Model(&user2).Association("Manager").Clear(); err != nil { + t.Errorf("Error happened when clear Manager, got %v", err) + } + + if count := DB.Model(&user2).Association("Company").Count(); count != 0 { + t.Errorf("Invalid company count after clear, got %v", count) + } + + if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { + t.Errorf("Invalid manager count after clear, got %v", count) + } } diff --git a/tests/count_test.go b/tests/count_test.go index 960db167..257959c3 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -33,7 +33,7 @@ func TestCount(t *testing.T) { var count3 int64 if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { - t.Errorf("No error should happen when count with group, but got %v", err) + t.Errorf("Error happened when count with group, but got %v", err) } if count3 != 2 {