Fix call hook methods when updating with struct

This commit is contained in:
Jinzhu 2020-10-27 18:14:36 +08:00
parent d011ebe7af
commit 4009ec5816
4 changed files with 44 additions and 12 deletions

View File

@ -8,7 +8,7 @@ import (
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
tx := db.Session(&gorm.Session{}) tx := db.Session(&gorm.Session{})
if called := fc(db.Statement.Dest, tx); !called { if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0 db.Statement.CurDestIndex = 0

View File

@ -451,6 +451,27 @@ func (stmt *Statement) SetColumn(name string, value interface{}) {
v[name] = value v[name] = value
} else if stmt.Schema != nil { } else if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil { if field := stmt.Schema.LookUpField(name); field != nil {
destValue := reflect.ValueOf(stmt.Dest)
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
if stmt.ReflectValue != destValue {
if !destValue.CanAddr() {
destValueCanAddr := reflect.New(destValue.Type())
destValueCanAddr.Elem().Set(destValue)
stmt.Dest = destValueCanAddr.Interface()
destValue = destValueCanAddr.Elem()
}
switch destValue.Kind() {
case reflect.Struct:
field.Set(destValue, value)
default:
stmt.AddError(ErrInvalidData)
}
}
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
@ -467,11 +488,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}) {
// Changed check model changed or not when updating // Changed check model changed or not when updating
func (stmt *Statement) Changed(fields ...string) bool { func (stmt *Statement) Changed(fields ...string) bool {
modelValue := reflect.ValueOf(stmt.Model) modelValue := stmt.ReflectValue
for modelValue.Kind() == reflect.Ptr {
modelValue = modelValue.Elem()
}
switch modelValue.Kind() { switch modelValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
@ -488,8 +505,13 @@ func (stmt *Statement) Changed(fields ...string) bool {
return !utils.AssertEqual(fv, fieldValue) return !utils.AssertEqual(fv, fieldValue)
} }
} else { } else {
changedValue, _ := field.ValueOf(stmt.ReflectValue) destValue := reflect.ValueOf(stmt.Dest)
return !utils.AssertEqual(changedValue, fieldValue) for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
changedValue, zero := field.ValueOf(destValue)
return !zero && !utils.AssertEqual(changedValue, fieldValue)
} }
} }
return false return false

View File

@ -10,7 +10,7 @@ require (
gorm.io/driver/postgres v1.0.5 gorm.io/driver/postgres v1.0.5
gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlite v1.1.3
gorm.io/driver/sqlserver v1.0.5 gorm.io/driver/sqlserver v1.0.5
gorm.io/gorm v1.20.4 gorm.io/gorm v1.20.5
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../

View File

@ -354,10 +354,20 @@ func TestSetColumn(t *testing.T) {
AssertEqual(t, result, product) AssertEqual(t, result, product)
// Code changed, price not selected, price should not change // Select to change Code, but nothing updated, price should not change
DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"}) DB.Model(&product).Select("code").Updates(Product3{Name: "L1214", Code: "L1213"})
if product.Price != 220 || product.Code != "L1213" { if product.Price != 220 || product.Code != "L1213" || product.Name != "Product New3" {
t.Errorf("invalid data after update, got %+v", product)
}
DB.Model(&product).Updates(Product3{Code: "L1214"})
if product.Price != 270 || product.Code != "L1214" {
t.Errorf("invalid data after update, got %+v", product)
}
DB.Model(&product).UpdateColumns(Product3{Code: "L1215"})
if product.Price != 270 || product.Code != "L1215" {
t.Errorf("invalid data after update, got %+v", product) t.Errorf("invalid data after update, got %+v", product)
} }