mirror of https://github.com/go-gorm/gorm.git
Fix call hook methods when updating with struct
This commit is contained in:
parent
d011ebe7af
commit
4009ec5816
|
@ -8,7 +8,7 @@ import (
|
||||||
|
|
||||||
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
|
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
|
||||||
tx := db.Session(&gorm.Session{})
|
tx := db.Session(&gorm.Session{})
|
||||||
if called := fc(db.Statement.Dest, tx); !called {
|
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
db.Statement.CurDestIndex = 0
|
db.Statement.CurDestIndex = 0
|
||||||
|
|
36
statement.go
36
statement.go
|
@ -451,6 +451,27 @@ func (stmt *Statement) SetColumn(name string, value interface{}) {
|
||||||
v[name] = value
|
v[name] = value
|
||||||
} else if stmt.Schema != nil {
|
} else if stmt.Schema != nil {
|
||||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||||
|
destValue := reflect.ValueOf(stmt.Dest)
|
||||||
|
for destValue.Kind() == reflect.Ptr {
|
||||||
|
destValue = destValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if stmt.ReflectValue != destValue {
|
||||||
|
if !destValue.CanAddr() {
|
||||||
|
destValueCanAddr := reflect.New(destValue.Type())
|
||||||
|
destValueCanAddr.Elem().Set(destValue)
|
||||||
|
stmt.Dest = destValueCanAddr.Interface()
|
||||||
|
destValue = destValueCanAddr.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch destValue.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
field.Set(destValue, value)
|
||||||
|
default:
|
||||||
|
stmt.AddError(ErrInvalidData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch stmt.ReflectValue.Kind() {
|
switch stmt.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
|
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
|
||||||
|
@ -467,11 +488,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}) {
|
||||||
|
|
||||||
// Changed check model changed or not when updating
|
// Changed check model changed or not when updating
|
||||||
func (stmt *Statement) Changed(fields ...string) bool {
|
func (stmt *Statement) Changed(fields ...string) bool {
|
||||||
modelValue := reflect.ValueOf(stmt.Model)
|
modelValue := stmt.ReflectValue
|
||||||
for modelValue.Kind() == reflect.Ptr {
|
|
||||||
modelValue = modelValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
switch modelValue.Kind() {
|
switch modelValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
|
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
|
||||||
|
@ -488,8 +505,13 @@ func (stmt *Statement) Changed(fields ...string) bool {
|
||||||
return !utils.AssertEqual(fv, fieldValue)
|
return !utils.AssertEqual(fv, fieldValue)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
changedValue, _ := field.ValueOf(stmt.ReflectValue)
|
destValue := reflect.ValueOf(stmt.Dest)
|
||||||
return !utils.AssertEqual(changedValue, fieldValue)
|
for destValue.Kind() == reflect.Ptr {
|
||||||
|
destValue = destValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
changedValue, zero := field.ValueOf(destValue)
|
||||||
|
return !zero && !utils.AssertEqual(changedValue, fieldValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -10,7 +10,7 @@ require (
|
||||||
gorm.io/driver/postgres v1.0.5
|
gorm.io/driver/postgres v1.0.5
|
||||||
gorm.io/driver/sqlite v1.1.3
|
gorm.io/driver/sqlite v1.1.3
|
||||||
gorm.io/driver/sqlserver v1.0.5
|
gorm.io/driver/sqlserver v1.0.5
|
||||||
gorm.io/gorm v1.20.4
|
gorm.io/gorm v1.20.5
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
|
|
@ -354,10 +354,20 @@ func TestSetColumn(t *testing.T) {
|
||||||
|
|
||||||
AssertEqual(t, result, product)
|
AssertEqual(t, result, product)
|
||||||
|
|
||||||
// Code changed, price not selected, price should not change
|
// Select to change Code, but nothing updated, price should not change
|
||||||
DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"})
|
DB.Model(&product).Select("code").Updates(Product3{Name: "L1214", Code: "L1213"})
|
||||||
|
|
||||||
if product.Price != 220 || product.Code != "L1213" {
|
if product.Price != 220 || product.Code != "L1213" || product.Name != "Product New3" {
|
||||||
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&product).Updates(Product3{Code: "L1214"})
|
||||||
|
if product.Price != 270 || product.Code != "L1214" {
|
||||||
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&product).UpdateColumns(Product3{Code: "L1215"})
|
||||||
|
if product.Price != 270 || product.Code != "L1215" {
|
||||||
t.Errorf("invalid data after update, got %+v", product)
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue