mirror of https://github.com/go-gorm/gorm.git
Add more updates test
This commit is contained in:
parent
dffc2713f0
commit
1559fe24e5
|
@ -86,6 +86,14 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||
case schema.BelongsTo:
|
||||
if len(values) == 0 {
|
||||
updateMap := map[string]interface{}{}
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
updateMap[ref.ForeignKey.DBName] = nil
|
||||
|
|
|
@ -24,6 +24,13 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||
if !ref.OwnPrimaryKey {
|
||||
pv, _ := ref.PrimaryKey.ValueOf(elem)
|
||||
ref.ForeignKey.Set(obj, pv)
|
||||
|
||||
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
||||
dest[ref.ForeignKey.DBName] = pv
|
||||
if _, ok := dest[rel.Name]; ok {
|
||||
dest[rel.Name] = elem.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,6 +37,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
|||
|
||||
updateCallback := db.Callback().Update()
|
||||
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
|
||||
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
||||
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
|
||||
updateCallback.Register("gorm:update", Update)
|
||||
|
|
|
@ -37,6 +37,19 @@ func Query(db *gorm.DB) {
|
|||
func BuildQuerySQL(db *gorm.DB) {
|
||||
clauseSelect := clause.Select{}
|
||||
|
||||
if db.Statement.ReflectValue.Kind() == reflect.Struct {
|
||||
var conds []clause.Expression
|
||||
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
||||
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
|
||||
}
|
||||
}
|
||||
|
||||
if len(conds) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: conds})
|
||||
}
|
||||
}
|
||||
|
||||
if len(db.Statement.Selects) > 0 {
|
||||
for _, name := range db.Statement.Selects {
|
||||
if db.Statement.Schema == nil {
|
||||
|
|
|
@ -9,6 +9,25 @@ import (
|
|||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
||||
func SetupUpdateReflectValue(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
|
||||
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
|
||||
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
|
||||
}
|
||||
|
||||
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
if _, ok := dest[rel.Name]; ok {
|
||||
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BeforeUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
|
@ -114,21 +133,20 @@ func AfterUpdate(db *gorm.DB) {
|
|||
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
var (
|
||||
selectColumns, restricted = SelectAndOmitColumns(stmt, false, true)
|
||||
reflectModelValue = reflect.Indirect(reflect.ValueOf(stmt.Model))
|
||||
assignValue func(field *schema.Field, value interface{})
|
||||
)
|
||||
|
||||
switch reflectModelValue.Kind() {
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
for i := 0; i < reflectModelValue.Len(); i++ {
|
||||
field.Set(reflectModelValue.Index(i), value)
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
field.Set(stmt.ReflectValue.Index(i), value)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
if reflectModelValue.CanAddr() {
|
||||
field.Set(reflectModelValue, value)
|
||||
if stmt.ReflectValue.CanAddr() {
|
||||
field.Set(stmt.ReflectValue, value)
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
@ -136,7 +154,12 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||
}
|
||||
}
|
||||
|
||||
switch value := stmt.Dest.(type) {
|
||||
updatingValue := reflect.ValueOf(stmt.Dest)
|
||||
for updatingValue.Kind() == reflect.Ptr {
|
||||
updatingValue = updatingValue.Elem()
|
||||
}
|
||||
|
||||
switch value := updatingValue.Interface().(type) {
|
||||
case map[string]interface{}:
|
||||
set = make([]clause.Assignment, 0, len(value))
|
||||
|
||||
|
@ -148,8 +171,12 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||
|
||||
for _, k := range keys {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
|
||||
if field.DBName != "" {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
|
||||
assignValue(field, value[k])
|
||||
}
|
||||
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
|
||||
assignValue(field, value[k])
|
||||
}
|
||||
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
|
@ -167,13 +194,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||
}
|
||||
}
|
||||
default:
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
switch updatingValue.Kind() {
|
||||
case reflect.Struct:
|
||||
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
||||
for _, field := range stmt.Schema.FieldsByDBName {
|
||||
if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) {
|
||||
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
value, isZero := field.ValueOf(stmt.ReflectValue)
|
||||
value, isZero := field.ValueOf(updatingValue)
|
||||
if !stmt.DisableUpdateTime {
|
||||
if field.AutoUpdateTime > 0 {
|
||||
value = stmt.DB.NowFunc()
|
||||
|
@ -187,7 +214,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||
if value, isZero := field.ValueOf(updatingValue); !isZero {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||
}
|
||||
}
|
||||
|
@ -195,16 +222,15 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||
}
|
||||
}
|
||||
|
||||
if !stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model))
|
||||
switch reflectValue.Kind() {
|
||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var priamryKeyExprs []clause.Expression
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||
var notZero bool
|
||||
for idx, field := range stmt.Schema.PrimaryFields {
|
||||
value, isZero := field.ValueOf(reflectValue.Index(i))
|
||||
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
|
||||
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
||||
notZero = notZero || !isZero
|
||||
}
|
||||
|
@ -215,7 +241,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}})
|
||||
case reflect.Struct:
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
if value, isZero := field.ValueOf(reflectValue); !isZero {
|
||||
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -347,6 +347,8 @@ func (field *Field) setupValuerAndSetter() {
|
|||
if v.Type().Elem().Kind() == reflect.Struct {
|
||||
if !v.IsNil() {
|
||||
v = v.Elem()
|
||||
} else {
|
||||
return nil, true
|
||||
}
|
||||
} else {
|
||||
return nil, true
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) {
|
||||
if count := DB.Model(data).Association(name).Count(); count != result {
|
||||
t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
|
||||
t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
|
||||
}
|
||||
|
||||
var newUser User
|
||||
|
@ -20,7 +20,7 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result
|
|||
|
||||
if newUser.ID != 0 {
|
||||
if count := DB.Model(&newUser).Association(name).Count(); count != result {
|
||||
t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
|
||||
t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -28,6 +28,6 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result
|
|||
func TestInvalidAssociation(t *testing.T) {
|
||||
var user = *GetUser("invalid", Config{Company: true, Manager: true})
|
||||
if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil {
|
||||
t.Errorf("should return errors for invalid association, but got nil")
|
||||
t.Fatalf("should return errors for invalid association, but got nil")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,12 +31,14 @@ func TestDelete(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, user := range []User{users[0], users[2]} {
|
||||
result = User{}
|
||||
if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil {
|
||||
t.Errorf("no error should returns when query %v, but got %v", user.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, user := range []User{users[0], users[2]} {
|
||||
result = User{}
|
||||
if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil {
|
||||
t.Errorf("no error should returns when query %v, but got %v", user.ID, err)
|
||||
}
|
||||
|
|
|
@ -264,10 +264,12 @@ func TestSearchWithEmptyChain(t *testing.T) {
|
|||
t.Errorf("Should not raise any error if searching with empty strings")
|
||||
}
|
||||
|
||||
result = User{}
|
||||
if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil {
|
||||
t.Errorf("Should not raise any error if searching with empty struct")
|
||||
}
|
||||
|
||||
result = User{}
|
||||
if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil {
|
||||
t.Errorf("Should not raise any error if searching with empty map")
|
||||
}
|
||||
|
@ -319,6 +321,7 @@ func TestSearchWithMap(t *testing.T) {
|
|||
DB.First(&user, map[string]interface{}{"name": users[0].Name})
|
||||
CheckUser(t, user, users[0])
|
||||
|
||||
user = User{}
|
||||
DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user)
|
||||
CheckUser(t, user, users[1])
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ package tests_test
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -218,3 +220,304 @@ func TestBlockGlobalUpdate(t *testing.T) {
|
|||
t.Errorf("should returns missing WHERE clause while updating error, got err %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWithUpdate(t *testing.T) {
|
||||
user := *GetUser("select_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
DB.Create(&user)
|
||||
|
||||
var result User
|
||||
DB.First(&result, user.ID)
|
||||
|
||||
user2 := *GetUser("select_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
result.Name = user2.Name
|
||||
result.Age = 50
|
||||
result.Account = user2.Account
|
||||
result.Pets = user2.Pets
|
||||
result.Toys = user2.Toys
|
||||
result.Company = user2.Company
|
||||
result.Manager = user2.Manager
|
||||
result.Team = user2.Team
|
||||
result.Languages = user2.Languages
|
||||
result.Friends = user2.Friends
|
||||
|
||||
DB.Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result)
|
||||
|
||||
var result2 User
|
||||
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)
|
||||
|
||||
result.Languages = append(user.Languages, result.Languages...)
|
||||
result.Toys = append(user.Toys, result.Toys...)
|
||||
|
||||
sort.Slice(result.Languages, func(i, j int) bool {
|
||||
return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0
|
||||
})
|
||||
|
||||
sort.Slice(result.Toys, func(i, j int) bool {
|
||||
return result.Toys[i].ID < result.Toys[j].ID
|
||||
})
|
||||
|
||||
sort.Slice(result2.Languages, func(i, j int) bool {
|
||||
return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0
|
||||
})
|
||||
|
||||
sort.Slice(result2.Toys, func(i, j int) bool {
|
||||
return result2.Toys[i].ID < result2.Toys[j].ID
|
||||
})
|
||||
|
||||
AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages")
|
||||
}
|
||||
|
||||
func TestSelectWithUpdateWithMap(t *testing.T) {
|
||||
user := *GetUser("select_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
DB.Create(&user)
|
||||
|
||||
var result User
|
||||
DB.First(&result, user.ID)
|
||||
|
||||
user2 := *GetUser("select_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
updateValues := map[string]interface{}{
|
||||
"Name": user2.Name,
|
||||
"Age": 50,
|
||||
"Account": user2.Account,
|
||||
"Pets": user2.Pets,
|
||||
"Toys": user2.Toys,
|
||||
"Company": user2.Company,
|
||||
"Manager": user2.Manager,
|
||||
"Team": user2.Team,
|
||||
"Languages": user2.Languages,
|
||||
"Friends": user2.Friends,
|
||||
}
|
||||
|
||||
DB.Model(&result).Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues)
|
||||
|
||||
var result2 User
|
||||
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)
|
||||
|
||||
result.Languages = append(user.Languages, result.Languages...)
|
||||
result.Toys = append(user.Toys, result.Toys...)
|
||||
|
||||
sort.Slice(result.Languages, func(i, j int) bool {
|
||||
return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0
|
||||
})
|
||||
|
||||
sort.Slice(result.Toys, func(i, j int) bool {
|
||||
return result.Toys[i].ID < result.Toys[j].ID
|
||||
})
|
||||
|
||||
sort.Slice(result2.Languages, func(i, j int) bool {
|
||||
return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0
|
||||
})
|
||||
|
||||
sort.Slice(result2.Toys, func(i, j int) bool {
|
||||
return result2.Toys[i].ID < result2.Toys[j].ID
|
||||
})
|
||||
|
||||
AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages")
|
||||
}
|
||||
|
||||
func TestOmitWithUpdate(t *testing.T) {
|
||||
user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
DB.Create(&user)
|
||||
|
||||
var result User
|
||||
DB.First(&result, user.ID)
|
||||
|
||||
user2 := *GetUser("omit_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
result.Name = user2.Name
|
||||
result.Age = 50
|
||||
result.Account = user2.Account
|
||||
result.Pets = user2.Pets
|
||||
result.Toys = user2.Toys
|
||||
result.Company = user2.Company
|
||||
result.Manager = user2.Manager
|
||||
result.Team = user2.Team
|
||||
result.Languages = user2.Languages
|
||||
result.Friends = user2.Friends
|
||||
|
||||
DB.Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result)
|
||||
|
||||
var result2 User
|
||||
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)
|
||||
|
||||
result.Pets = append(user.Pets, result.Pets...)
|
||||
result.Team = append(user.Team, result.Team...)
|
||||
result.Friends = append(user.Friends, result.Friends...)
|
||||
|
||||
sort.Slice(result.Pets, func(i, j int) bool {
|
||||
return result.Pets[i].ID < result.Pets[j].ID
|
||||
})
|
||||
sort.Slice(result.Team, func(i, j int) bool {
|
||||
return result.Team[i].ID < result.Team[j].ID
|
||||
})
|
||||
sort.Slice(result.Friends, func(i, j int) bool {
|
||||
return result.Friends[i].ID < result.Friends[j].ID
|
||||
})
|
||||
sort.Slice(result2.Pets, func(i, j int) bool {
|
||||
return result2.Pets[i].ID < result2.Pets[j].ID
|
||||
})
|
||||
sort.Slice(result2.Team, func(i, j int) bool {
|
||||
return result2.Team[i].ID < result2.Team[j].ID
|
||||
})
|
||||
sort.Slice(result2.Friends, func(i, j int) bool {
|
||||
return result2.Friends[i].ID < result2.Friends[j].ID
|
||||
})
|
||||
|
||||
AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends")
|
||||
}
|
||||
|
||||
func TestOmitWithUpdateWithMap(t *testing.T) {
|
||||
user := *GetUser("omit_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
DB.Create(&user)
|
||||
|
||||
var result User
|
||||
DB.First(&result, user.ID)
|
||||
|
||||
user2 := *GetUser("omit_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
updateValues := map[string]interface{}{
|
||||
"Name": user2.Name,
|
||||
"Age": 50,
|
||||
"Account": user2.Account,
|
||||
"Pets": user2.Pets,
|
||||
"Toys": user2.Toys,
|
||||
"Company": user2.Company,
|
||||
"Manager": user2.Manager,
|
||||
"Team": user2.Team,
|
||||
"Languages": user2.Languages,
|
||||
"Friends": user2.Friends,
|
||||
}
|
||||
|
||||
DB.Model(&result).Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues)
|
||||
|
||||
var result2 User
|
||||
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)
|
||||
|
||||
result.Pets = append(user.Pets, result.Pets...)
|
||||
result.Team = append(user.Team, result.Team...)
|
||||
result.Friends = append(user.Friends, result.Friends...)
|
||||
|
||||
sort.Slice(result.Pets, func(i, j int) bool {
|
||||
return result.Pets[i].ID < result.Pets[j].ID
|
||||
})
|
||||
sort.Slice(result.Team, func(i, j int) bool {
|
||||
return result.Team[i].ID < result.Team[j].ID
|
||||
})
|
||||
sort.Slice(result.Friends, func(i, j int) bool {
|
||||
return result.Friends[i].ID < result.Friends[j].ID
|
||||
})
|
||||
sort.Slice(result2.Pets, func(i, j int) bool {
|
||||
return result2.Pets[i].ID < result2.Pets[j].ID
|
||||
})
|
||||
sort.Slice(result2.Team, func(i, j int) bool {
|
||||
return result2.Team[i].ID < result2.Team[j].ID
|
||||
})
|
||||
sort.Slice(result2.Friends, func(i, j int) bool {
|
||||
return result2.Friends[i].ID < result2.Friends[j].ID
|
||||
})
|
||||
|
||||
AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends")
|
||||
}
|
||||
|
||||
func TestSelectWithUpdateColumn(t *testing.T) {
|
||||
user := *GetUser("select_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
DB.Create(&user)
|
||||
|
||||
updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
|
||||
|
||||
var result User
|
||||
DB.First(&result, user.ID)
|
||||
DB.Model(&result).Select("Name").UpdateColumns(updateValues)
|
||||
|
||||
var result2 User
|
||||
DB.First(&result2, user.ID)
|
||||
|
||||
if result2.Name == user.Name || result2.Age != user.Age {
|
||||
t.Errorf("Should only update users with name column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOmitWithUpdateColumn(t *testing.T) {
|
||||
user := *GetUser("omit_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
DB.Create(&user)
|
||||
|
||||
updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
|
||||
|
||||
var result User
|
||||
DB.First(&result, user.ID)
|
||||
DB.Model(&result).Omit("Name").UpdateColumns(updateValues)
|
||||
|
||||
var result2 User
|
||||
DB.First(&result2, user.ID)
|
||||
|
||||
if result2.Name != user.Name || result2.Age == user.Age {
|
||||
t.Errorf("Should only update users with name column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateColumnsSkipsAssociations(t *testing.T) {
|
||||
user := *GetUser("update_column_skips_association", Config{})
|
||||
DB.Create(&user)
|
||||
|
||||
// Update a single field of the user and verify that the changed address is not stored.
|
||||
newAge := uint(100)
|
||||
user.Account.Number = "new_account_number"
|
||||
db := DB.Model(&user).UpdateColumns(User{Age: newAge})
|
||||
|
||||
if db.RowsAffected != 1 {
|
||||
t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", db.RowsAffected)
|
||||
}
|
||||
|
||||
// Verify that Age now=`newAge`.
|
||||
result := &User{}
|
||||
result.ID = user.ID
|
||||
DB.Preload("Account").First(result)
|
||||
|
||||
if result.Age != newAge {
|
||||
t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, result.Age)
|
||||
}
|
||||
|
||||
if result.Account.Number != user.Account.Number {
|
||||
t.Errorf("account number should not been changed, expects: %v, got %v", user.Account.Number, result.Account.Number)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdatesWithBlankValues(t *testing.T) {
|
||||
user := *GetUser("updates_with_blank_value", Config{})
|
||||
DB.Save(&user)
|
||||
|
||||
var user2 User
|
||||
user2.ID = user.ID
|
||||
DB.Model(&user2).Updates(&User{Age: 100})
|
||||
|
||||
var result User
|
||||
DB.First(&result, user.ID)
|
||||
|
||||
if result.Name != user.Name || result.Age != 100 {
|
||||
t.Errorf("user's name should not be updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdatesTableWithIgnoredValues(t *testing.T) {
|
||||
type ElementWithIgnoredField struct {
|
||||
Id int64
|
||||
Value string
|
||||
IgnoredField int64 `gorm:"-"`
|
||||
}
|
||||
DB.Migrator().DropTable(&ElementWithIgnoredField{})
|
||||
DB.AutoMigrate(&ElementWithIgnoredField{})
|
||||
|
||||
elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10}
|
||||
DB.Save(&elem)
|
||||
|
||||
DB.Model(&ElementWithIgnoredField{}).
|
||||
Where("id = ?", elem.Id).
|
||||
Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100})
|
||||
|
||||
var result ElementWithIgnoredField
|
||||
if err := DB.First(&result, elem.Id).Error; err != nil {
|
||||
t.Errorf("error getting an element from database: %s", err.Error())
|
||||
}
|
||||
|
||||
if result.IgnoredField != 0 {
|
||||
t.Errorf("element's ignored field should not be updated")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package tests
|
|||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
@ -126,6 +127,37 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
|||
return
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Kind() == reflect.Slice {
|
||||
if reflect.ValueOf(expect).Kind() == reflect.Slice {
|
||||
if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() {
|
||||
for i := 0; i < reflect.ValueOf(got).Len(); i++ {
|
||||
name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface())
|
||||
})
|
||||
}
|
||||
} else {
|
||||
name := reflect.ValueOf(got).Type().Elem().Name()
|
||||
t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Kind() == reflect.Struct {
|
||||
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
|
||||
for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
|
||||
if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
field := reflect.ValueOf(got).Field(i)
|
||||
t.Run(fieldStruct.Name, func(t *testing.T) {
|
||||
AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
|
||||
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
|
||||
isEqual()
|
||||
|
|
Loading…
Reference in New Issue