Add more updates test

This commit is contained in:
Jinzhu 2020-06-01 19:41:33 +08:00
parent dffc2713f0
commit 1559fe24e5
11 changed files with 419 additions and 22 deletions

View File

@ -86,6 +86,14 @@ func (association *Association) Replace(values ...interface{}) error {
case schema.BelongsTo: case schema.BelongsTo:
if len(values) == 0 { if len(values) == 0 {
updateMap := map[string]interface{}{} updateMap := map[string]interface{}{}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
}
case reflect.Struct:
rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
}
for _, ref := range rel.References { for _, ref := range rel.References {
updateMap[ref.ForeignKey.DBName] = nil updateMap[ref.ForeignKey.DBName] = nil

View File

@ -24,6 +24,13 @@ func SaveBeforeAssociations(db *gorm.DB) {
if !ref.OwnPrimaryKey { if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem) pv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(obj, pv) ref.ForeignKey.Set(obj, pv)
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv
if _, ok := dest[rel.Name]; ok {
dest[rel.Name] = elem.Interface()
}
}
} }
} }
} }

View File

@ -37,6 +37,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
updateCallback := db.Callback().Update() updateCallback := db.Callback().Update()
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
updateCallback.Register("gorm:update", Update) updateCallback.Register("gorm:update", Update)

View File

@ -37,6 +37,19 @@ func Query(db *gorm.DB) {
func BuildQuerySQL(db *gorm.DB) { func BuildQuerySQL(db *gorm.DB) {
clauseSelect := clause.Select{} clauseSelect := clause.Select{}
if db.Statement.ReflectValue.Kind() == reflect.Struct {
var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
}
}
if len(conds) > 0 {
db.Statement.AddClause(clause.Where{Exprs: conds})
}
}
if len(db.Statement.Selects) > 0 { if len(db.Statement.Selects) > 0 {
for _, name := range db.Statement.Selects { for _, name := range db.Statement.Selects {
if db.Statement.Schema == nil { if db.Statement.Schema == nil {

View File

@ -9,6 +9,25 @@ import (
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
) )
func SetupUpdateReflectValue(db *gorm.DB) {
if db.Error == nil {
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
}
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok {
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
}
}
}
}
}
}
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
tx := db.Session(&gorm.Session{}) tx := db.Session(&gorm.Session{})
@ -114,21 +133,20 @@ func AfterUpdate(db *gorm.DB) {
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
var ( var (
selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) selectColumns, restricted = SelectAndOmitColumns(stmt, false, true)
reflectModelValue = reflect.Indirect(reflect.ValueOf(stmt.Model))
assignValue func(field *schema.Field, value interface{}) assignValue func(field *schema.Field, value interface{})
) )
switch reflectModelValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < reflectModelValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(reflectModelValue.Index(i), value) field.Set(stmt.ReflectValue.Index(i), value)
} }
} }
case reflect.Struct: case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
if reflectModelValue.CanAddr() { if stmt.ReflectValue.CanAddr() {
field.Set(reflectModelValue, value) field.Set(stmt.ReflectValue, value)
} }
} }
default: default:
@ -136,7 +154,12 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
switch value := stmt.Dest.(type) { updatingValue := reflect.ValueOf(stmt.Dest)
for updatingValue.Kind() == reflect.Ptr {
updatingValue = updatingValue.Elem()
}
switch value := updatingValue.Interface().(type) {
case map[string]interface{}: case map[string]interface{}:
set = make([]clause.Assignment, 0, len(value)) set = make([]clause.Assignment, 0, len(value))
@ -148,8 +171,12 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
for _, k := range keys { for _, k := range keys {
if field := stmt.Schema.LookUpField(k); field != nil { if field := stmt.Schema.LookUpField(k); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if field.DBName != "" {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
assignValue(field, value[k])
}
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
assignValue(field, value[k]) assignValue(field, value[k])
} }
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { } else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
@ -167,13 +194,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
default: default:
switch stmt.ReflectValue.Kind() { switch updatingValue.Kind() {
case reflect.Struct: case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, field := range stmt.Schema.FieldsByDBName { for _, field := range stmt.Schema.FieldsByDBName {
if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) { if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
value, isZero := field.ValueOf(stmt.ReflectValue) value, isZero := field.ValueOf(updatingValue)
if !stmt.DisableUpdateTime { if !stmt.DisableUpdateTime {
if field.AutoUpdateTime > 0 { if field.AutoUpdateTime > 0 {
value = stmt.DB.NowFunc() value = stmt.DB.NowFunc()
@ -187,7 +214,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
} else { } else {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { if value, isZero := field.ValueOf(updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
} }
} }
@ -195,16 +222,15 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
if !stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model { if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model)) switch stmt.ReflectValue.Kind() {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
var priamryKeyExprs []clause.Expression var priamryKeyExprs []clause.Expression
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool var notZero bool
for idx, field := range stmt.Schema.PrimaryFields { for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(reflectValue.Index(i)) value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value} exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero notZero = notZero || !isZero
} }
@ -215,7 +241,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}})
case reflect.Struct: case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(reflectValue); !isZero { if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
} }
} }

View File

@ -347,6 +347,8 @@ func (field *Field) setupValuerAndSetter() {
if v.Type().Elem().Kind() == reflect.Struct { if v.Type().Elem().Kind() == reflect.Struct {
if !v.IsNil() { if !v.IsNil() {
v = v.Elem() v = v.Elem()
} else {
return nil, true
} }
} else { } else {
return nil, true return nil, true

View File

@ -8,7 +8,7 @@ import (
func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) {
if count := DB.Model(data).Association(name).Count(); count != result { if count := DB.Model(data).Association(name).Count(); count != result {
t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
} }
var newUser User var newUser User
@ -20,7 +20,7 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result
if newUser.ID != 0 { if newUser.ID != 0 {
if count := DB.Model(&newUser).Association(name).Count(); count != result { if count := DB.Model(&newUser).Association(name).Count(); count != result {
t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
} }
} }
} }
@ -28,6 +28,6 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result
func TestInvalidAssociation(t *testing.T) { func TestInvalidAssociation(t *testing.T) {
var user = *GetUser("invalid", Config{Company: true, Manager: true}) var user = *GetUser("invalid", Config{Company: true, Manager: true})
if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil { if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil {
t.Errorf("should return errors for invalid association, but got nil") t.Fatalf("should return errors for invalid association, but got nil")
} }
} }

View File

@ -31,12 +31,14 @@ func TestDelete(t *testing.T) {
} }
for _, user := range []User{users[0], users[2]} { for _, user := range []User{users[0], users[2]} {
result = User{}
if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil {
t.Errorf("no error should returns when query %v, but got %v", user.ID, err) t.Errorf("no error should returns when query %v, but got %v", user.ID, err)
} }
} }
for _, user := range []User{users[0], users[2]} { for _, user := range []User{users[0], users[2]} {
result = User{}
if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil {
t.Errorf("no error should returns when query %v, but got %v", user.ID, err) t.Errorf("no error should returns when query %v, but got %v", user.ID, err)
} }

View File

@ -264,10 +264,12 @@ func TestSearchWithEmptyChain(t *testing.T) {
t.Errorf("Should not raise any error if searching with empty strings") t.Errorf("Should not raise any error if searching with empty strings")
} }
result = User{}
if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil { if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil {
t.Errorf("Should not raise any error if searching with empty struct") t.Errorf("Should not raise any error if searching with empty struct")
} }
result = User{}
if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil { if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil {
t.Errorf("Should not raise any error if searching with empty map") t.Errorf("Should not raise any error if searching with empty map")
} }
@ -319,6 +321,7 @@ func TestSearchWithMap(t *testing.T) {
DB.First(&user, map[string]interface{}{"name": users[0].Name}) DB.First(&user, map[string]interface{}{"name": users[0].Name})
CheckUser(t, user, users[0]) CheckUser(t, user, users[0])
user = User{}
DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user) DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user)
CheckUser(t, user, users[1]) CheckUser(t, user, users[1])

View File

@ -2,6 +2,8 @@ package tests_test
import ( import (
"errors" "errors"
"sort"
"strings"
"testing" "testing"
"time" "time"
@ -218,3 +220,304 @@ func TestBlockGlobalUpdate(t *testing.T) {
t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) t.Errorf("should returns missing WHERE clause while updating error, got err %v", err)
} }
} }
func TestSelectWithUpdate(t *testing.T) {
user := *GetUser("select_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
DB.Create(&user)
var result User
DB.First(&result, user.ID)
user2 := *GetUser("select_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
result.Name = user2.Name
result.Age = 50
result.Account = user2.Account
result.Pets = user2.Pets
result.Toys = user2.Toys
result.Company = user2.Company
result.Manager = user2.Manager
result.Team = user2.Team
result.Languages = user2.Languages
result.Friends = user2.Friends
DB.Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result)
var result2 User
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)
result.Languages = append(user.Languages, result.Languages...)
result.Toys = append(user.Toys, result.Toys...)
sort.Slice(result.Languages, func(i, j int) bool {
return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0
})
sort.Slice(result.Toys, func(i, j int) bool {
return result.Toys[i].ID < result.Toys[j].ID
})
sort.Slice(result2.Languages, func(i, j int) bool {
return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0
})
sort.Slice(result2.Toys, func(i, j int) bool {
return result2.Toys[i].ID < result2.Toys[j].ID
})
AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages")
}
func TestSelectWithUpdateWithMap(t *testing.T) {
user := *GetUser("select_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
DB.Create(&user)
var result User
DB.First(&result, user.ID)
user2 := *GetUser("select_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
updateValues := map[string]interface{}{
"Name": user2.Name,
"Age": 50,
"Account": user2.Account,
"Pets": user2.Pets,
"Toys": user2.Toys,
"Company": user2.Company,
"Manager": user2.Manager,
"Team": user2.Team,
"Languages": user2.Languages,
"Friends": user2.Friends,
}
DB.Model(&result).Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues)
var result2 User
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)
result.Languages = append(user.Languages, result.Languages...)
result.Toys = append(user.Toys, result.Toys...)
sort.Slice(result.Languages, func(i, j int) bool {
return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0
})
sort.Slice(result.Toys, func(i, j int) bool {
return result.Toys[i].ID < result.Toys[j].ID
})
sort.Slice(result2.Languages, func(i, j int) bool {
return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0
})
sort.Slice(result2.Toys, func(i, j int) bool {
return result2.Toys[i].ID < result2.Toys[j].ID
})
AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages")
}
func TestOmitWithUpdate(t *testing.T) {
user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
DB.Create(&user)
var result User
DB.First(&result, user.ID)
user2 := *GetUser("omit_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
result.Name = user2.Name
result.Age = 50
result.Account = user2.Account
result.Pets = user2.Pets
result.Toys = user2.Toys
result.Company = user2.Company
result.Manager = user2.Manager
result.Team = user2.Team
result.Languages = user2.Languages
result.Friends = user2.Friends
DB.Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result)
var result2 User
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)
result.Pets = append(user.Pets, result.Pets...)
result.Team = append(user.Team, result.Team...)
result.Friends = append(user.Friends, result.Friends...)
sort.Slice(result.Pets, func(i, j int) bool {
return result.Pets[i].ID < result.Pets[j].ID
})
sort.Slice(result.Team, func(i, j int) bool {
return result.Team[i].ID < result.Team[j].ID
})
sort.Slice(result.Friends, func(i, j int) bool {
return result.Friends[i].ID < result.Friends[j].ID
})
sort.Slice(result2.Pets, func(i, j int) bool {
return result2.Pets[i].ID < result2.Pets[j].ID
})
sort.Slice(result2.Team, func(i, j int) bool {
return result2.Team[i].ID < result2.Team[j].ID
})
sort.Slice(result2.Friends, func(i, j int) bool {
return result2.Friends[i].ID < result2.Friends[j].ID
})
AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends")
}
func TestOmitWithUpdateWithMap(t *testing.T) {
user := *GetUser("omit_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
DB.Create(&user)
var result User
DB.First(&result, user.ID)
user2 := *GetUser("omit_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
updateValues := map[string]interface{}{
"Name": user2.Name,
"Age": 50,
"Account": user2.Account,
"Pets": user2.Pets,
"Toys": user2.Toys,
"Company": user2.Company,
"Manager": user2.Manager,
"Team": user2.Team,
"Languages": user2.Languages,
"Friends": user2.Friends,
}
DB.Model(&result).Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues)
var result2 User
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)
result.Pets = append(user.Pets, result.Pets...)
result.Team = append(user.Team, result.Team...)
result.Friends = append(user.Friends, result.Friends...)
sort.Slice(result.Pets, func(i, j int) bool {
return result.Pets[i].ID < result.Pets[j].ID
})
sort.Slice(result.Team, func(i, j int) bool {
return result.Team[i].ID < result.Team[j].ID
})
sort.Slice(result.Friends, func(i, j int) bool {
return result.Friends[i].ID < result.Friends[j].ID
})
sort.Slice(result2.Pets, func(i, j int) bool {
return result2.Pets[i].ID < result2.Pets[j].ID
})
sort.Slice(result2.Team, func(i, j int) bool {
return result2.Team[i].ID < result2.Team[j].ID
})
sort.Slice(result2.Friends, func(i, j int) bool {
return result2.Friends[i].ID < result2.Friends[j].ID
})
AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends")
}
func TestSelectWithUpdateColumn(t *testing.T) {
user := *GetUser("select_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
DB.Create(&user)
updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
var result User
DB.First(&result, user.ID)
DB.Model(&result).Select("Name").UpdateColumns(updateValues)
var result2 User
DB.First(&result2, user.ID)
if result2.Name == user.Name || result2.Age != user.Age {
t.Errorf("Should only update users with name column")
}
}
func TestOmitWithUpdateColumn(t *testing.T) {
user := *GetUser("omit_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
DB.Create(&user)
updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
var result User
DB.First(&result, user.ID)
DB.Model(&result).Omit("Name").UpdateColumns(updateValues)
var result2 User
DB.First(&result2, user.ID)
if result2.Name != user.Name || result2.Age == user.Age {
t.Errorf("Should only update users with name column")
}
}
func TestUpdateColumnsSkipsAssociations(t *testing.T) {
user := *GetUser("update_column_skips_association", Config{})
DB.Create(&user)
// Update a single field of the user and verify that the changed address is not stored.
newAge := uint(100)
user.Account.Number = "new_account_number"
db := DB.Model(&user).UpdateColumns(User{Age: newAge})
if db.RowsAffected != 1 {
t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", db.RowsAffected)
}
// Verify that Age now=`newAge`.
result := &User{}
result.ID = user.ID
DB.Preload("Account").First(result)
if result.Age != newAge {
t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, result.Age)
}
if result.Account.Number != user.Account.Number {
t.Errorf("account number should not been changed, expects: %v, got %v", user.Account.Number, result.Account.Number)
}
}
func TestUpdatesWithBlankValues(t *testing.T) {
user := *GetUser("updates_with_blank_value", Config{})
DB.Save(&user)
var user2 User
user2.ID = user.ID
DB.Model(&user2).Updates(&User{Age: 100})
var result User
DB.First(&result, user.ID)
if result.Name != user.Name || result.Age != 100 {
t.Errorf("user's name should not be updated")
}
}
func TestUpdatesTableWithIgnoredValues(t *testing.T) {
type ElementWithIgnoredField struct {
Id int64
Value string
IgnoredField int64 `gorm:"-"`
}
DB.Migrator().DropTable(&ElementWithIgnoredField{})
DB.AutoMigrate(&ElementWithIgnoredField{})
elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10}
DB.Save(&elem)
DB.Model(&ElementWithIgnoredField{}).
Where("id = ?", elem.Id).
Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100})
var result ElementWithIgnoredField
if err := DB.First(&result, elem.Id).Error; err != nil {
t.Errorf("error getting an element from database: %s", err.Error())
}
if result.IgnoredField != 0 {
t.Errorf("element's ignored field should not be updated")
}
}

View File

@ -3,6 +3,7 @@ package tests
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"go/ast"
"reflect" "reflect"
"sort" "sort"
"strconv" "strconv"
@ -126,6 +127,37 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
return return
} }
if reflect.ValueOf(got).Kind() == reflect.Slice {
if reflect.ValueOf(expect).Kind() == reflect.Slice {
if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() {
for i := 0; i < reflect.ValueOf(got).Len(); i++ {
name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i)
t.Run(name, func(t *testing.T) {
AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface())
})
}
} else {
name := reflect.ValueOf(got).Type().Elem().Name()
t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len())
}
return
}
}
if reflect.ValueOf(got).Kind() == reflect.Struct {
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
field := reflect.ValueOf(got).Field(i)
t.Run(fieldStruct.Name, func(t *testing.T) {
AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
})
}
}
return
}
}
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
isEqual() isEqual()