diff --git a/callbacks/create.go b/callbacks/create.go index 0b88e263..99140612 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -12,31 +12,31 @@ func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { - var ok bool + var called bool if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { - ok = true + called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeCreate { if i, ok := value.(gorm.BeforeCreateInterface); ok { - ok = true + called = true db.AddError(i.BeforeCreate(tx)) } } - return ok + return called } if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } @@ -184,31 +184,31 @@ func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { - var ok bool + var called bool if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { - ok = true + called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterCreate { if i, ok := value.(gorm.AfterCreateInterface); ok { - ok = true + called = true db.AddError(i.AfterCreate(tx)) } } - return ok + return called } if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } diff --git a/callbacks/delete.go b/callbacks/delete.go index b8691ff9..f1a49c11 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -25,10 +25,10 @@ func BeforeDelete(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } @@ -101,10 +101,10 @@ func AfterDelete(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } diff --git a/callbacks/query.go b/callbacks/query.go index 16202187..b6667414 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -203,10 +203,10 @@ func AfterQuery(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 2589370f..9c922956 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -29,34 +29,34 @@ func SetupUpdateReflectValue(db *gorm.DB) { } func BeforeUpdate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { - var ok bool + var called bool if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { - ok = true + called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeUpdate { if i, ok := value.(gorm.BeforeUpdateInterface); ok { - ok = true + called = true db.AddError(i.BeforeUpdate(tx)) } } - return ok + return called } if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } @@ -98,34 +98,34 @@ func Update(db *gorm.DB) { } func AfterUpdate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { - var ok bool + var called bool if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { - ok = true + called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterUpdate { if i, ok := value.(gorm.AfterUpdateInterface); ok { - ok = true + called = true db.AddError(i.AfterUpdate(tx)) } } - return ok + return called } if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } @@ -191,7 +191,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - if !stmt.DisableUpdateTime && stmt.Schema != nil { + if !stmt.UpdatingColumn && stmt.Schema != nil { for _, field := range stmt.Schema.FieldsByDBName { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { now := stmt.DB.NowFunc() @@ -215,7 +215,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(updatingValue) - if !stmt.DisableUpdateTime { + if !stmt.UpdatingColumn { if field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() diff --git a/finisher_api.go b/finisher_api.go index d6de7aa3..e94fd095 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -207,7 +207,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.Statement.DisableUpdateTime = true + tx.Statement.UpdatingColumn = true tx.callbacks.Update().Execute(tx) return } @@ -215,7 +215,7 @@ func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.Statement.DisableUpdateTime = true + tx.Statement.UpdatingColumn = true tx.callbacks.Update().Execute(tx) return } diff --git a/statement.go b/statement.go index 755d93ac..e3f324b9 100644 --- a/statement.go +++ b/statement.go @@ -33,7 +33,7 @@ type Statement struct { Schema *schema.Schema Context context.Context RaiseErrorOnNotFound bool - DisableUpdateTime bool + UpdatingColumn bool SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg diff --git a/tests/hooks_test.go b/tests/hooks_test.go index e2850c27..c74e8f10 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -3,9 +3,11 @@ package tests_test import ( "errors" "reflect" + "strings" "testing" "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) type Product struct { @@ -98,7 +100,7 @@ func TestRunCallbacks(t *testing.T) { DB.Save(&p) if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { - t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) + t.Fatalf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) } DB.Where("Code = ?", "unique_code").First(&p) @@ -114,7 +116,7 @@ func TestRunCallbacks(t *testing.T) { var products []Product DB.Find(&products, "code = ?", "unique_code") - if products[0].AfterFindCallTimes != 1 { + if products[0].AfterFindCallTimes != 2 { t.Fatalf("AfterFind callbacks should work with slice, called %v", products[0].AfterFindCallTimes) } @@ -198,3 +200,88 @@ func TestCallbacksWithErrors(t *testing.T) { t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback") } } + +type Product2 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string +} + +func (s Product2) BeforeCreate(tx *gorm.DB) (err error) { + if !strings.HasSuffix(s.Name, "_clone") { + newProduft := s + newProduft.Price *= 2 + newProduft.Name += "_clone" + err = tx.Create(&newProduft).Error + } + + if s.Name == "Invalid" { + return errors.New("invalid") + } + + return nil +} + +func (s *Product2) BeforeUpdate(tx *gorm.DB) (err error) { + tx.Statement.Where("owner != ?", "admin") + return +} + +func TestUseDBInHooks(t *testing.T) { + DB.Migrator().DropTable(&Product2{}) + DB.AutoMigrate(&Product2{}) + + product := Product2{Name: "Invalid", Price: 100} + + if err := DB.Create(&product).Error; err == nil { + t.Fatalf("should returns error %v when creating product, but got nil", err) + } + + product2 := Product2{Name: "Nice", Price: 100} + + if err := DB.Create(&product2).Error; err != nil { + t.Fatalf("Failed to create product, got error: %v", err) + } + + var result Product2 + if err := DB.First(&result, "name = ?", "Nice").Error; err != nil { + t.Fatalf("Failed to query product, got error: %v", err) + } + + var resultClone Product2 + if err := DB.First(&resultClone, "name = ?", "Nice_clone").Error; err != nil { + t.Fatalf("Failed to find cloned product, got error: %v", err) + } + + result.Price *= 2 + result.Name += "_clone" + AssertObjEqual(t, result, resultClone, "Price", "Name") + + DB.Model(&result).Update("Price", 500) + var result2 Product2 + DB.First(&result2, "name = ?", "Nice") + + if result2.Price != 500 { + t.Errorf("Failed to update product's price, expects: %v, got %v", 500, result2.Price) + } + + product3 := Product2{Name: "Nice2", Price: 600, Owner: "admin"} + if err := DB.Create(&product3).Error; err != nil { + t.Fatalf("Failed to create product, got error: %v", err) + } + + var result3 Product2 + if err := DB.First(&result3, "name = ?", "Nice2").Error; err != nil { + t.Fatalf("Failed to query product, got error: %v", err) + } + + DB.Model(&result3).Update("Price", 800) + var result4 Product2 + DB.First(&result4, "name = ?", "Nice2") + + if result4.Price != 600 { + t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price) + } +}