Test Hooks For Slice

This commit is contained in:
Jinzhu 2020-06-30 22:47:21 +08:00
parent 66dcd7e3ca
commit 3e4dbde920
3 changed files with 64 additions and 5 deletions

View File

@ -11,8 +11,10 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
if called := fc(db.Statement.Dest, tx); !called { if called := fc(db.Statement.Dest, tx); !called {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx) fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx)
db.Statement.CurDestIndex++
} }
case reflect.Struct: case reflect.Struct:
fc(db.Statement.ReflectValue.Addr().Interface(), tx) fc(db.Statement.ReflectValue.Addr().Interface(), tx)

View File

@ -38,6 +38,7 @@ type Statement struct {
SQL strings.Builder SQL strings.Builder
Vars []interface{} Vars []interface{}
NamedVars []sql.NamedArg NamedVars []sql.NamedArg
CurDestIndex int
attrs []interface{} attrs []interface{}
assigns []interface{} assigns []interface{}
} }
@ -379,7 +380,12 @@ func (stmt *Statement) SetColumn(name string, value interface{}) {
v[name] = value v[name] = value
} else if stmt.Schema != nil { } else if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil { if field := stmt.Schema.LookUpField(name); field != nil {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
case reflect.Struct:
field.Set(stmt.ReflectValue, value) field.Set(stmt.ReflectValue, value)
}
} else { } else {
stmt.AddError(ErrInvalidField) stmt.AddError(ErrInvalidField)
} }
@ -395,17 +401,20 @@ func (stmt *Statement) Changed(fields ...string) bool {
modelValue = modelValue.Elem() modelValue = modelValue.Elem()
} }
switch modelValue.Kind() {
case reflect.Slice, reflect.Array:
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
}
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
changed := func(field *schema.Field) bool { changed := func(field *schema.Field) bool {
fieldValue, isZero := field.ValueOf(modelValue) fieldValue, _ := field.ValueOf(modelValue)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, ok := stmt.Dest.(map[string]interface{}); ok { if v, ok := stmt.Dest.(map[string]interface{}); ok {
if fv, ok := v[field.Name]; ok { if fv, ok := v[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue) return !utils.AssertEqual(fv, fieldValue)
} else if fv, ok := v[field.DBName]; ok { } else if fv, ok := v[field.DBName]; ok {
return !utils.AssertEqual(fv, fieldValue) return !utils.AssertEqual(fv, fieldValue)
} else if isZero {
return true
} }
} else { } else {
changedValue, _ := field.ValueOf(stmt.ReflectValue) changedValue, _ := field.ValueOf(stmt.ReflectValue)

View File

@ -366,3 +366,51 @@ func TestSetColumn(t *testing.T) {
AssertEqual(t, result2, product) AssertEqual(t, result2, product)
} }
func TestHooksForSlice(t *testing.T) {
products := []*Product3{
{Name: "Product-1", Price: 100},
{Name: "Product-2", Price: 200},
{Name: "Product-3", Price: 300},
}
DB.Create(&products)
for idx, value := range []int64{200, 300, 400} {
if products[idx].Price != value {
t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price)
}
}
DB.Model(&products).Update("Name", "product-name")
// will set all product's price to last product's price + 10
for idx, value := range []int64{410, 410, 410} {
if products[idx].Price != value {
t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price)
}
}
products2 := []Product3{
{Name: "Product-1", Price: 100},
{Name: "Product-2", Price: 200},
{Name: "Product-3", Price: 300},
}
DB.Create(&products2)
for idx, value := range []int64{200, 300, 400} {
if products2[idx].Price != value {
t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price)
}
}
DB.Model(&products2).Update("Name", "product-name")
// will set all product's price to last product's price + 10
for idx, value := range []int64{410, 410, 410} {
if products2[idx].Price != value {
t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price)
}
}
}