From 4009ec58163b97294633edc19f5d792546cd612c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 27 Oct 2020 18:14:36 +0800 Subject: [PATCH] Fix call hook methods when updating with struct --- callbacks/callmethod.go | 2 +- statement.go | 36 +++++++++++++++++++++++++++++------- tests/go.mod | 2 +- tests/hooks_test.go | 16 +++++++++++++--- 4 files changed, 44 insertions(+), 12 deletions(-) diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index 0160f354..b81fc915 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -8,7 +8,7 @@ import ( func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { tx := db.Session(&gorm.Session{}) - if called := fc(db.Statement.Dest, tx); !called { + if called := fc(db.Statement.ReflectValue.Interface(), tx); !called { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: db.Statement.CurDestIndex = 0 diff --git a/statement.go b/statement.go index 567df869..82ebdd91 100644 --- a/statement.go +++ b/statement.go @@ -451,6 +451,27 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { v[name] = value } else if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + if stmt.ReflectValue != destValue { + if !destValue.CanAddr() { + destValueCanAddr := reflect.New(destValue.Type()) + destValueCanAddr.Elem().Set(destValue) + stmt.Dest = destValueCanAddr.Interface() + destValue = destValueCanAddr.Elem() + } + + switch destValue.Kind() { + case reflect.Struct: + field.Set(destValue, value) + default: + stmt.AddError(ErrInvalidData) + } + } + switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) @@ -467,11 +488,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { // Changed check model changed or not when updating func (stmt *Statement) Changed(fields ...string) bool { - modelValue := reflect.ValueOf(stmt.Model) - for modelValue.Kind() == reflect.Ptr { - modelValue = modelValue.Elem() - } - + modelValue := stmt.ReflectValue switch modelValue.Kind() { case reflect.Slice, reflect.Array: modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) @@ -488,8 +505,13 @@ func (stmt *Statement) Changed(fields ...string) bool { return !utils.AssertEqual(fv, fieldValue) } } else { - changedValue, _ := field.ValueOf(stmt.ReflectValue) - return !utils.AssertEqual(changedValue, fieldValue) + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + changedValue, zero := field.ValueOf(destValue) + return !zero && !utils.AssertEqual(changedValue, fieldValue) } } return false diff --git a/tests/go.mod b/tests/go.mod index 3fa011f1..55495de3 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -10,7 +10,7 @@ require ( gorm.io/driver/postgres v1.0.5 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.5 - gorm.io/gorm v1.20.4 + gorm.io/gorm v1.20.5 ) replace gorm.io/gorm => ../ diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 3612857b..d8b1770e 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -354,10 +354,20 @@ func TestSetColumn(t *testing.T) { AssertEqual(t, result, product) - // Code changed, price not selected, price should not change - DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"}) + // Select to change Code, but nothing updated, price should not change + DB.Model(&product).Select("code").Updates(Product3{Name: "L1214", Code: "L1213"}) - if product.Price != 220 || product.Code != "L1213" { + if product.Price != 220 || product.Code != "L1213" || product.Name != "Product New3" { + t.Errorf("invalid data after update, got %+v", product) + } + + DB.Model(&product).Updates(Product3{Code: "L1214"}) + if product.Price != 270 || product.Code != "L1214" { + t.Errorf("invalid data after update, got %+v", product) + } + + DB.Model(&product).UpdateColumns(Product3{Code: "L1215"}) + if product.Price != 270 || product.Code != "L1215" { t.Errorf("invalid data after update, got %+v", product) }