diff --git a/association.go b/association.go index bed89837..55dd7772 100644 --- a/association.go +++ b/association.go @@ -86,6 +86,14 @@ func (association *Association) Replace(values ...interface{}) error { case schema.BelongsTo: if len(values) == 0 { updateMap := map[string]interface{}{} + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + } + case reflect.Struct: + rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + } for _, ref := range rel.References { updateMap[ref.ForeignKey.DBName] = nil diff --git a/callbacks/associations.go b/callbacks/associations.go index 3c8c2a50..d19f7339 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -24,6 +24,13 @@ func SaveBeforeAssociations(db *gorm.DB) { if !ref.OwnPrimaryKey { pv, _ := ref.PrimaryKey.ValueOf(elem) ref.ForeignKey.Set(obj, pv) + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + dest[ref.ForeignKey.DBName] = pv + if _, ok := dest[rel.Name]; ok { + dest[rel.Name] = elem.Interface() + } + } } } } diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 1985aec2..1c1d6ade 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -37,6 +37,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback := db.Callback().Update() updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) updateCallback.Register("gorm:update", Update) diff --git a/callbacks/query.go b/callbacks/query.go index 91948031..e4e76665 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -37,6 +37,19 @@ func Query(db *gorm.DB) { func BuildQuerySQL(db *gorm.DB) { clauseSelect := clause.Select{} + if db.Statement.ReflectValue.Kind() == reflect.Struct { + var conds []clause.Expression + for _, primaryField := range db.Statement.Schema.PrimaryFields { + if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) + } + } + + if len(conds) > 0 { + db.Statement.AddClause(clause.Where{Exprs: conds}) + } + } + if len(db.Statement.Selects) > 0 { for _, name := range db.Statement.Selects { if db.Statement.Schema == nil { diff --git a/callbacks/update.go b/callbacks/update.go index cbbcddf7..fda07676 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -9,6 +9,25 @@ import ( "github.com/jinzhu/gorm/schema" ) +func SetupUpdateReflectValue(db *gorm.DB) { + if db.Error == nil { + if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() + } + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if _, ok := dest[rel.Name]; ok { + rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) + } + } + } + } + } +} + func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { tx := db.Session(&gorm.Session{}) @@ -114,21 +133,20 @@ func AfterUpdate(db *gorm.DB) { func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { var ( selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) - reflectModelValue = reflect.Indirect(reflect.ValueOf(stmt.Model)) assignValue func(field *schema.Field, value interface{}) ) - switch reflectModelValue.Kind() { + switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { - for i := 0; i < reflectModelValue.Len(); i++ { - field.Set(reflectModelValue.Index(i), value) + for i := 0; i < stmt.ReflectValue.Len(); i++ { + field.Set(stmt.ReflectValue.Index(i), value) } } case reflect.Struct: assignValue = func(field *schema.Field, value interface{}) { - if reflectModelValue.CanAddr() { - field.Set(reflectModelValue, value) + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.ReflectValue, value) } } default: @@ -136,7 +154,12 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - switch value := stmt.Dest.(type) { + updatingValue := reflect.ValueOf(stmt.Dest) + for updatingValue.Kind() == reflect.Ptr { + updatingValue = updatingValue.Elem() + } + + switch value := updatingValue.Interface().(type) { case map[string]interface{}: set = make([]clause.Assignment, 0, len(value)) @@ -148,8 +171,12 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { for _, k := range keys { if field := stmt.Schema.LookUpField(k); field != nil { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + if field.DBName != "" { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + assignValue(field, value[k]) + } + } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { assignValue(field, value[k]) } } else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { @@ -167,13 +194,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } default: - switch stmt.ReflectValue.Kind() { + switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, field := range stmt.Schema.FieldsByDBName { - if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) { + if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - value, isZero := field.ValueOf(stmt.ReflectValue) + value, isZero := field.ValueOf(updatingValue) if !stmt.DisableUpdateTime { if field.AutoUpdateTime > 0 { value = stmt.DB.NowFunc() @@ -187,7 +214,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } else { - if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if value, isZero := field.ValueOf(updatingValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } @@ -195,16 +222,15 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - if !stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model { - reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model)) - switch reflectValue.Kind() { + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var priamryKeyExprs []clause.Expression - for i := 0; i < reflectValue.Len(); i++ { + for i := 0; i < stmt.ReflectValue.Len(); i++ { var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(reflectValue.Index(i)) + value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) exprs[idx] = clause.Eq{Column: field.DBName, Value: value} notZero = notZero || !isZero } @@ -215,7 +241,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(reflectValue); !isZero { + if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } diff --git a/schema/field.go b/schema/field.go index f52dd6a6..8a0f01bf 100644 --- a/schema/field.go +++ b/schema/field.go @@ -347,6 +347,8 @@ func (field *Field) setupValuerAndSetter() { if v.Type().Elem().Kind() == reflect.Struct { if !v.IsNil() { v = v.Elem() + } else { + return nil, true } } else { return nil, true diff --git a/tests/associations_test.go b/tests/associations_test.go index 89bbe142..3668b44b 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -8,7 +8,7 @@ import ( func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { if count := DB.Model(data).Association(name).Count(); count != result { - t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) } var newUser User @@ -20,7 +20,7 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result if newUser.ID != 0 { if count := DB.Model(&newUser).Association(name).Count(); count != result { - t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) } } } @@ -28,6 +28,6 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result func TestInvalidAssociation(t *testing.T) { var user = *GetUser("invalid", Config{Company: true, Manager: true}) if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil { - t.Errorf("should return errors for invalid association, but got nil") + t.Fatalf("should return errors for invalid association, but got nil") } } diff --git a/tests/delete_test.go b/tests/delete_test.go index 4288253f..e7076aa6 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -31,12 +31,14 @@ func TestDelete(t *testing.T) { } for _, user := range []User{users[0], users[2]} { + result = User{} if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) } } for _, user := range []User{users[0], users[2]} { + result = User{} if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) } diff --git a/tests/query_test.go b/tests/query_test.go index 6efadc8e..73b6dca3 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -264,10 +264,12 @@ func TestSearchWithEmptyChain(t *testing.T) { t.Errorf("Should not raise any error if searching with empty strings") } + result = User{} if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil { t.Errorf("Should not raise any error if searching with empty struct") } + result = User{} if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil { t.Errorf("Should not raise any error if searching with empty map") } @@ -319,6 +321,7 @@ func TestSearchWithMap(t *testing.T) { DB.First(&user, map[string]interface{}{"name": users[0].Name}) CheckUser(t, user, users[0]) + user = User{} DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user) CheckUser(t, user, users[1]) diff --git a/tests/update_test.go b/tests/update_test.go index 869ce4cd..a5a62237 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -2,6 +2,8 @@ package tests_test import ( "errors" + "sort" + "strings" "testing" "time" @@ -218,3 +220,304 @@ func TestBlockGlobalUpdate(t *testing.T) { t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) } } + +func TestSelectWithUpdate(t *testing.T) { + user := *GetUser("select_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("select_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + result.Name = user2.Name + result.Age = 50 + result.Account = user2.Account + result.Pets = user2.Pets + result.Toys = user2.Toys + result.Company = user2.Company + result.Manager = user2.Manager + result.Team = user2.Team + result.Languages = user2.Languages + result.Friends = user2.Friends + + DB.Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Languages = append(user.Languages, result.Languages...) + result.Toys = append(user.Toys, result.Toys...) + + sort.Slice(result.Languages, func(i, j int) bool { + return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0 + }) + + sort.Slice(result.Toys, func(i, j int) bool { + return result.Toys[i].ID < result.Toys[j].ID + }) + + sort.Slice(result2.Languages, func(i, j int) bool { + return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0 + }) + + sort.Slice(result2.Toys, func(i, j int) bool { + return result2.Toys[i].ID < result2.Toys[j].ID + }) + + AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") +} + +func TestSelectWithUpdateWithMap(t *testing.T) { + user := *GetUser("select_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("select_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + updateValues := map[string]interface{}{ + "Name": user2.Name, + "Age": 50, + "Account": user2.Account, + "Pets": user2.Pets, + "Toys": user2.Toys, + "Company": user2.Company, + "Manager": user2.Manager, + "Team": user2.Team, + "Languages": user2.Languages, + "Friends": user2.Friends, + } + + DB.Model(&result).Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Languages = append(user.Languages, result.Languages...) + result.Toys = append(user.Toys, result.Toys...) + + sort.Slice(result.Languages, func(i, j int) bool { + return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0 + }) + + sort.Slice(result.Toys, func(i, j int) bool { + return result.Toys[i].ID < result.Toys[j].ID + }) + + sort.Slice(result2.Languages, func(i, j int) bool { + return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0 + }) + + sort.Slice(result2.Toys, func(i, j int) bool { + return result2.Toys[i].ID < result2.Toys[j].ID + }) + + AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") +} + +func TestOmitWithUpdate(t *testing.T) { + user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("omit_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + result.Name = user2.Name + result.Age = 50 + result.Account = user2.Account + result.Pets = user2.Pets + result.Toys = user2.Toys + result.Company = user2.Company + result.Manager = user2.Manager + result.Team = user2.Team + result.Languages = user2.Languages + result.Friends = user2.Friends + + DB.Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Pets = append(user.Pets, result.Pets...) + result.Team = append(user.Team, result.Team...) + result.Friends = append(user.Friends, result.Friends...) + + sort.Slice(result.Pets, func(i, j int) bool { + return result.Pets[i].ID < result.Pets[j].ID + }) + sort.Slice(result.Team, func(i, j int) bool { + return result.Team[i].ID < result.Team[j].ID + }) + sort.Slice(result.Friends, func(i, j int) bool { + return result.Friends[i].ID < result.Friends[j].ID + }) + sort.Slice(result2.Pets, func(i, j int) bool { + return result2.Pets[i].ID < result2.Pets[j].ID + }) + sort.Slice(result2.Team, func(i, j int) bool { + return result2.Team[i].ID < result2.Team[j].ID + }) + sort.Slice(result2.Friends, func(i, j int) bool { + return result2.Friends[i].ID < result2.Friends[j].ID + }) + + AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends") +} + +func TestOmitWithUpdateWithMap(t *testing.T) { + user := *GetUser("omit_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("omit_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + updateValues := map[string]interface{}{ + "Name": user2.Name, + "Age": 50, + "Account": user2.Account, + "Pets": user2.Pets, + "Toys": user2.Toys, + "Company": user2.Company, + "Manager": user2.Manager, + "Team": user2.Team, + "Languages": user2.Languages, + "Friends": user2.Friends, + } + + DB.Model(&result).Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Pets = append(user.Pets, result.Pets...) + result.Team = append(user.Team, result.Team...) + result.Friends = append(user.Friends, result.Friends...) + + sort.Slice(result.Pets, func(i, j int) bool { + return result.Pets[i].ID < result.Pets[j].ID + }) + sort.Slice(result.Team, func(i, j int) bool { + return result.Team[i].ID < result.Team[j].ID + }) + sort.Slice(result.Friends, func(i, j int) bool { + return result.Friends[i].ID < result.Friends[j].ID + }) + sort.Slice(result2.Pets, func(i, j int) bool { + return result2.Pets[i].ID < result2.Pets[j].ID + }) + sort.Slice(result2.Team, func(i, j int) bool { + return result2.Team[i].ID < result2.Team[j].ID + }) + sort.Slice(result2.Friends, func(i, j int) bool { + return result2.Friends[i].ID < result2.Friends[j].ID + }) + + AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends") +} + +func TestSelectWithUpdateColumn(t *testing.T) { + user := *GetUser("select_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var result User + DB.First(&result, user.ID) + DB.Model(&result).Select("Name").UpdateColumns(updateValues) + + var result2 User + DB.First(&result2, user.ID) + + if result2.Name == user.Name || result2.Age != user.Age { + t.Errorf("Should only update users with name column") + } +} + +func TestOmitWithUpdateColumn(t *testing.T) { + user := *GetUser("omit_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var result User + DB.First(&result, user.ID) + DB.Model(&result).Omit("Name").UpdateColumns(updateValues) + + var result2 User + DB.First(&result2, user.ID) + + if result2.Name != user.Name || result2.Age == user.Age { + t.Errorf("Should only update users with name column") + } +} + +func TestUpdateColumnsSkipsAssociations(t *testing.T) { + user := *GetUser("update_column_skips_association", Config{}) + DB.Create(&user) + + // Update a single field of the user and verify that the changed address is not stored. + newAge := uint(100) + user.Account.Number = "new_account_number" + db := DB.Model(&user).UpdateColumns(User{Age: newAge}) + + if db.RowsAffected != 1 { + t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", db.RowsAffected) + } + + // Verify that Age now=`newAge`. + result := &User{} + result.ID = user.ID + DB.Preload("Account").First(result) + + if result.Age != newAge { + t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, result.Age) + } + + if result.Account.Number != user.Account.Number { + t.Errorf("account number should not been changed, expects: %v, got %v", user.Account.Number, result.Account.Number) + } +} + +func TestUpdatesWithBlankValues(t *testing.T) { + user := *GetUser("updates_with_blank_value", Config{}) + DB.Save(&user) + + var user2 User + user2.ID = user.ID + DB.Model(&user2).Updates(&User{Age: 100}) + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Name || result.Age != 100 { + t.Errorf("user's name should not be updated") + } +} + +func TestUpdatesTableWithIgnoredValues(t *testing.T) { + type ElementWithIgnoredField struct { + Id int64 + Value string + IgnoredField int64 `gorm:"-"` + } + DB.Migrator().DropTable(&ElementWithIgnoredField{}) + DB.AutoMigrate(&ElementWithIgnoredField{}) + + elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} + DB.Save(&elem) + + DB.Model(&ElementWithIgnoredField{}). + Where("id = ?", elem.Id). + Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) + + var result ElementWithIgnoredField + if err := DB.First(&result, elem.Id).Error; err != nil { + t.Errorf("error getting an element from database: %s", err.Error()) + } + + if result.IgnoredField != 0 { + t.Errorf("element's ignored field should not be updated") + } +} diff --git a/tests/utils.go b/tests/utils.go index 7cc6d2bc..97b5d5c8 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -3,6 +3,7 @@ package tests import ( "database/sql/driver" "fmt" + "go/ast" "reflect" "sort" "strconv" @@ -126,6 +127,37 @@ func AssertEqual(t *testing.T, got, expect interface{}) { return } + if reflect.ValueOf(got).Kind() == reflect.Slice { + if reflect.ValueOf(expect).Kind() == reflect.Slice { + if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { + for i := 0; i < reflect.ValueOf(got).Len(); i++ { + name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) + t.Run(name, func(t *testing.T) { + AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) + }) + } + } else { + name := reflect.ValueOf(got).Type().Elem().Name() + t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) + } + return + } + } + + if reflect.ValueOf(got).Kind() == reflect.Struct { + if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + for i := 0; i < reflect.ValueOf(got).NumField(); i++ { + if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + field := reflect.ValueOf(got).Field(i) + t.Run(fieldStruct.Name, func(t *testing.T) { + AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) + }) + } + } + return + } + } + if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() isEqual()