forked from mirror/gorm
Test Hooks For Slice
This commit is contained in:
parent
66dcd7e3ca
commit
3e4dbde920
|
@ -11,8 +11,10 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
|
||||||
if called := fc(db.Statement.Dest, tx); !called {
|
if called := fc(db.Statement.Dest, tx); !called {
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
db.Statement.CurDestIndex = 0
|
||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx)
|
fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx)
|
||||||
|
db.Statement.CurDestIndex++
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
||||||
|
|
17
statement.go
17
statement.go
|
@ -38,6 +38,7 @@ type Statement struct {
|
||||||
SQL strings.Builder
|
SQL strings.Builder
|
||||||
Vars []interface{}
|
Vars []interface{}
|
||||||
NamedVars []sql.NamedArg
|
NamedVars []sql.NamedArg
|
||||||
|
CurDestIndex int
|
||||||
attrs []interface{}
|
attrs []interface{}
|
||||||
assigns []interface{}
|
assigns []interface{}
|
||||||
}
|
}
|
||||||
|
@ -379,7 +380,12 @@ func (stmt *Statement) SetColumn(name string, value interface{}) {
|
||||||
v[name] = value
|
v[name] = value
|
||||||
} else if stmt.Schema != nil {
|
} else if stmt.Schema != nil {
|
||||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||||
field.Set(stmt.ReflectValue, value)
|
switch stmt.ReflectValue.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
|
||||||
|
case reflect.Struct:
|
||||||
|
field.Set(stmt.ReflectValue, value)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
stmt.AddError(ErrInvalidField)
|
stmt.AddError(ErrInvalidField)
|
||||||
}
|
}
|
||||||
|
@ -395,17 +401,20 @@ func (stmt *Statement) Changed(fields ...string) bool {
|
||||||
modelValue = modelValue.Elem()
|
modelValue = modelValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch modelValue.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
|
||||||
|
}
|
||||||
|
|
||||||
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
||||||
changed := func(field *schema.Field) bool {
|
changed := func(field *schema.Field) bool {
|
||||||
fieldValue, isZero := field.ValueOf(modelValue)
|
fieldValue, _ := field.ValueOf(modelValue)
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||||
if fv, ok := v[field.Name]; ok {
|
if fv, ok := v[field.Name]; ok {
|
||||||
return !utils.AssertEqual(fv, fieldValue)
|
return !utils.AssertEqual(fv, fieldValue)
|
||||||
} else if fv, ok := v[field.DBName]; ok {
|
} else if fv, ok := v[field.DBName]; ok {
|
||||||
return !utils.AssertEqual(fv, fieldValue)
|
return !utils.AssertEqual(fv, fieldValue)
|
||||||
} else if isZero {
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
changedValue, _ := field.ValueOf(stmt.ReflectValue)
|
changedValue, _ := field.ValueOf(stmt.ReflectValue)
|
||||||
|
|
|
@ -366,3 +366,51 @@ func TestSetColumn(t *testing.T) {
|
||||||
|
|
||||||
AssertEqual(t, result2, product)
|
AssertEqual(t, result2, product)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHooksForSlice(t *testing.T) {
|
||||||
|
products := []*Product3{
|
||||||
|
{Name: "Product-1", Price: 100},
|
||||||
|
{Name: "Product-2", Price: 200},
|
||||||
|
{Name: "Product-3", Price: 300},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&products)
|
||||||
|
|
||||||
|
for idx, value := range []int64{200, 300, 400} {
|
||||||
|
if products[idx].Price != value {
|
||||||
|
t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&products).Update("Name", "product-name")
|
||||||
|
|
||||||
|
// will set all product's price to last product's price + 10
|
||||||
|
for idx, value := range []int64{410, 410, 410} {
|
||||||
|
if products[idx].Price != value {
|
||||||
|
t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
products2 := []Product3{
|
||||||
|
{Name: "Product-1", Price: 100},
|
||||||
|
{Name: "Product-2", Price: 200},
|
||||||
|
{Name: "Product-3", Price: 300},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&products2)
|
||||||
|
|
||||||
|
for idx, value := range []int64{200, 300, 400} {
|
||||||
|
if products2[idx].Price != value {
|
||||||
|
t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&products2).Update("Name", "product-name")
|
||||||
|
|
||||||
|
// will set all product's price to last product's price + 10
|
||||||
|
for idx, value := range []int64{410, 410, 410} {
|
||||||
|
if products2[idx].Price != value {
|
||||||
|
t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue