Fix call hook methods when updating with struct

This commit is contained in:
Jinzhu 2020-10-27 18:14:36 +08:00
parent d011ebe7af
commit 4009ec5816
4 changed files with 44 additions and 12 deletions

View File

@ -8,7 +8,7 @@ import (
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
tx := db.Session(&gorm.Session{})
if called := fc(db.Statement.Dest, tx); !called {
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0

View File

@ -451,6 +451,27 @@ func (stmt *Statement) SetColumn(name string, value interface{}) {
v[name] = value
} else if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
destValue := reflect.ValueOf(stmt.Dest)
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
if stmt.ReflectValue != destValue {
if !destValue.CanAddr() {
destValueCanAddr := reflect.New(destValue.Type())
destValueCanAddr.Elem().Set(destValue)
stmt.Dest = destValueCanAddr.Interface()
destValue = destValueCanAddr.Elem()
}
switch destValue.Kind() {
case reflect.Struct:
field.Set(destValue, value)
default:
stmt.AddError(ErrInvalidData)
}
}
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
@ -467,11 +488,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}) {
// Changed check model changed or not when updating
func (stmt *Statement) Changed(fields ...string) bool {
modelValue := reflect.ValueOf(stmt.Model)
for modelValue.Kind() == reflect.Ptr {
modelValue = modelValue.Elem()
}
modelValue := stmt.ReflectValue
switch modelValue.Kind() {
case reflect.Slice, reflect.Array:
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
@ -488,8 +505,13 @@ func (stmt *Statement) Changed(fields ...string) bool {
return !utils.AssertEqual(fv, fieldValue)
}
} else {
changedValue, _ := field.ValueOf(stmt.ReflectValue)
return !utils.AssertEqual(changedValue, fieldValue)
destValue := reflect.ValueOf(stmt.Dest)
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
changedValue, zero := field.ValueOf(destValue)
return !zero && !utils.AssertEqual(changedValue, fieldValue)
}
}
return false

View File

@ -10,7 +10,7 @@ require (
gorm.io/driver/postgres v1.0.5
gorm.io/driver/sqlite v1.1.3
gorm.io/driver/sqlserver v1.0.5
gorm.io/gorm v1.20.4
gorm.io/gorm v1.20.5
)
replace gorm.io/gorm => ../

View File

@ -354,10 +354,20 @@ func TestSetColumn(t *testing.T) {
AssertEqual(t, result, product)
// Code changed, price not selected, price should not change
DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"})
// Select to change Code, but nothing updated, price should not change
DB.Model(&product).Select("code").Updates(Product3{Name: "L1214", Code: "L1213"})
if product.Price != 220 || product.Code != "L1213" {
if product.Price != 220 || product.Code != "L1213" || product.Name != "Product New3" {
t.Errorf("invalid data after update, got %+v", product)
}
DB.Model(&product).Updates(Product3{Code: "L1214"})
if product.Price != 270 || product.Code != "L1214" {
t.Errorf("invalid data after update, got %+v", product)
}
DB.Model(&product).UpdateColumns(Product3{Code: "L1215"})
if product.Price != 270 || product.Code != "L1215" {
t.Errorf("invalid data after update, got %+v", product)
}