From e66a059b823218ec6d7efc765f67d788bb900f75 Mon Sep 17 00:00:00 2001 From: black-06 Date: Sat, 18 Feb 2023 09:20:29 +0800 Subject: [PATCH] fix: update panic if model is not ptr (#6037) * fix: update panic if model is not ptr * fix: update panic if model is not ptr * fix: update panic if model is not ptr * fix: raise an error if the value is not addressable * fix: return --- callbacks/callmethod.go | 13 +++++++++-- callbacks/update.go | 4 +++- schema/utils.go | 2 +- tests/hooks_test.go | 52 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 4 deletions(-) diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index bcaa03f3..fb900037 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -13,11 +13,20 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { case reflect.Slice, reflect.Array: db.Statement.CurDestIndex = 0 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) + if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() { + fc(value.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + return + } db.Statement.CurDestIndex++ } case reflect.Struct: - fc(db.Statement.ReflectValue.Addr().Interface(), tx) + if db.Statement.ReflectValue.CanAddr() { + fc(db.Statement.ReflectValue.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + } } } } diff --git a/callbacks/update.go b/callbacks/update.go index b596df9a..fe6f0994 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -137,7 +137,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + } } } case reflect.Struct: diff --git a/schema/utils.go b/schema/utils.go index acf1a739..65d012e5 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -133,7 +133,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, for i := 0; i < reflectValue.Len(); i++ { elem := reflectValue.Index(i) elemKey := elem.Interface() - if elem.Kind() != reflect.Ptr { + if elem.Kind() != reflect.Ptr && elem.CanAddr() { elemKey = elem.Addr().Interface() } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 8e964fd8..0753dd0b 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -514,3 +514,55 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) { t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes) } } + +type Product5 struct { + gorm.Model + Name string +} + +var beforeUpdateCall int + +func (p *Product5) BeforeUpdate(*gorm.DB) error { + beforeUpdateCall = beforeUpdateCall + 1 + return nil +} + +func TestUpdateCallbacks(t *testing.T) { + DB.Migrator().DropTable(&Product5{}) + DB.AutoMigrate(&Product5{}) + + p := Product5{Name: "unique_code"} + DB.Model(&Product5{}).Create(&p) + + err := DB.Model(&Product5{}).Where("id", p.ID).Update("name", "update_name_1").Error + if err != nil { + t.Fatalf("should update success, but got err %v", err) + } + if beforeUpdateCall != 1 { + t.Fatalf("before update should be called") + } + + err = DB.Model(Product5{}).Where("id", p.ID).Update("name", "update_name_2").Error + if !errors.Is(err, gorm.ErrInvalidValue) { + t.Fatalf("should got RecordNotFound, but got %v", err) + } + if beforeUpdateCall != 1 { + t.Fatalf("before update should not be called") + } + + err = DB.Model([1]*Product5{&p}).Update("name", "update_name_3").Error + if err != nil { + t.Fatalf("should update success, but got err %v", err) + } + if beforeUpdateCall != 2 { + t.Fatalf("before update should be called") + } + + err = DB.Model([1]Product5{p}).Update("name", "update_name_4").Error + if !errors.Is(err, gorm.ErrInvalidValue) { + t.Fatalf("should got RecordNotFound, but got %v", err) + } + if beforeUpdateCall != 2 { + t.Fatalf("before update should not be called") + } +}