From 70a7bd52ca2bbf64443b7227524e4600997ea1b9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Sep 2020 21:46:18 +0800 Subject: [PATCH] Support delete associations with Select when deleting --- callbacks/callbacks.go | 1 + callbacks/delete.go | 53 ++++++++++++++++++++++++++++++++++++++ tests/delete_test.go | 54 +++++++++++++++++++++++++++++++++++++++ tests/joins_table_test.go | 18 +++++++++++++ utils/utils.go | 2 +- 5 files changed, 127 insertions(+), 1 deletion(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 0a12468c..dda4b046 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback := db.Callback().Delete() deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Register("gorm:before_delete", BeforeDelete) + deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/callbacks/delete.go b/callbacks/delete.go index e95117a1..510dfae4 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -21,6 +21,59 @@ func BeforeDelete(db *gorm.DB) { } } +func DeleteBeforeAssociations(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + + if restricted { + for column, v := range selectColumns { + if v { + if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{}).Model(modelValue) + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + case schema.Many2Many: + var ( + queryConds []clause.Expression + foreignFields []*schema.Field + relForeignKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + } + } + } + } + } +} + func Delete(db *gorm.DB) { if db.Error == nil { if db.Statement.Schema != nil && !db.Statement.Unscoped { diff --git a/tests/delete_test.go b/tests/delete_test.go index 17299677..4945e837 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -5,6 +5,7 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) { t.Errorf("should returns no error while enable global update, but got err %v", err) } } + +func TestDeleteWithAssociations(t *testing.T) { + user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) + + if err := DB.Create(user).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations).Delete(&user).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + +func TestDeleteSliceWithAssociations(t *testing.T) { + users := []User{ + *GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}), + *GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}), + *GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}), + *GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}), + } + + if err := DB.Create(users).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go index b8c1be77..084c2f2c 100644 --- a/tests/joins_table_test.go +++ b/tests/joins_table_test.go @@ -5,12 +5,14 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type Person struct { ID int Name string Addresses []Address `gorm:"many2many:person_addresses;"` + DeletedAt gorm.DeletedAt } type Address struct { @@ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) { if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { t.Fatalf("address should be deleted when clear with unscoped") } + + address2_1 := Address{Name: "address 2-1"} + address2_2 := Address{Name: "address 2-2"} + person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}} + DB.Create(&person2) + if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil { + t.Fatalf("failed to delete person, got error: %v", err) + } + + if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 { + t.Errorf("person's addresses expects 2, got %v", count) + } + + if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 { + t.Errorf("person's addresses expects 2, got %v", count) + } } diff --git a/utils/utils.go b/utils/utils.go index 905001a5..ecba7fb9 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -30,7 +30,7 @@ func FileWithLineNum() string { } func IsValidDBNameChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } func CheckTruth(val interface{}) bool {