diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index a0e9b0e7..0160f354 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -11,8 +11,10 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { if called := fc(db.Statement.Dest, tx); !called { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: + db.Statement.CurDestIndex = 0 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx) + fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) + db.Statement.CurDestIndex++ } case reflect.Struct: fc(db.Statement.ReflectValue.Addr().Interface(), tx) diff --git a/statement.go b/statement.go index 164ddbd7..e65a064f 100644 --- a/statement.go +++ b/statement.go @@ -38,6 +38,7 @@ type Statement struct { SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg + CurDestIndex int attrs []interface{} assigns []interface{} } @@ -379,7 +380,12 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { v[name] = value } else if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { - field.Set(stmt.ReflectValue, value) + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + case reflect.Struct: + field.Set(stmt.ReflectValue, value) + } } else { stmt.AddError(ErrInvalidField) } @@ -395,17 +401,20 @@ func (stmt *Statement) Changed(fields ...string) bool { modelValue = modelValue.Elem() } + switch modelValue.Kind() { + case reflect.Slice, reflect.Array: + modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) + } + selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) changed := func(field *schema.Field) bool { - fieldValue, isZero := field.ValueOf(modelValue) + fieldValue, _ := field.ValueOf(modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := stmt.Dest.(map[string]interface{}); ok { if fv, ok := v[field.Name]; ok { return !utils.AssertEqual(fv, fieldValue) } else if fv, ok := v[field.DBName]; ok { return !utils.AssertEqual(fv, fieldValue) - } else if isZero { - return true } } else { changedValue, _ := field.ValueOf(stmt.ReflectValue) diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 8f8c60f5..ed5ee746 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -366,3 +366,51 @@ func TestSetColumn(t *testing.T) { AssertEqual(t, result2, product) } + +func TestHooksForSlice(t *testing.T) { + products := []*Product3{ + {Name: "Product-1", Price: 100}, + {Name: "Product-2", Price: 200}, + {Name: "Product-3", Price: 300}, + } + + DB.Create(&products) + + for idx, value := range []int64{200, 300, 400} { + if products[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) + } + } + + DB.Model(&products).Update("Name", "product-name") + + // will set all product's price to last product's price + 10 + for idx, value := range []int64{410, 410, 410} { + if products[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) + } + } + + products2 := []Product3{ + {Name: "Product-1", Price: 100}, + {Name: "Product-2", Price: 200}, + {Name: "Product-3", Price: 300}, + } + + DB.Create(&products2) + + for idx, value := range []int64{200, 300, 400} { + if products2[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) + } + } + + DB.Model(&products2).Update("Name", "product-name") + + // will set all product's price to last product's price + 10 + for idx, value := range []int64{410, 410, 410} { + if products2[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) + } + } +}