Support delete associations with Select when deleting

This commit is contained in:
Jinzhu 2020-09-10 21:46:18 +08:00
parent 53caa85cf4
commit 70a7bd52ca
5 changed files with 127 additions and 1 deletions

View File

@ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)

View File

@ -21,6 +21,59 @@ func BeforeDelete(db *gorm.DB) {
}
}
func DeleteBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
if restricted {
for column, v := range selectColumns {
if v {
if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{}).Model(modelValue)
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
case schema.Many2Many:
var (
queryConds []clause.Expression
foreignFields []*schema.Field
relForeignKeys []string
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
}
}
}
}
}
}
func Delete(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {

View File

@ -5,6 +5,7 @@ import (
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
@ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) {
t.Errorf("should returns no error while enable global update, but got err %v", err)
}
}
func TestDeleteWithAssociations(t *testing.T) {
user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1})
if err := DB.Create(user).Error; err != nil {
t.Fatalf("failed to create user, got error %v", err)
}
if err := DB.Select(clause.Associations).Delete(&user).Error; err != nil {
t.Fatalf("failed to delete user, got error %v", err)
}
for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} {
if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} {
if count := DB.Model(&user).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
}
func TestDeleteSliceWithAssociations(t *testing.T) {
users := []User{
*GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}),
*GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}),
*GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}),
*GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}),
}
if err := DB.Create(users).Error; err != nil {
t.Fatalf("failed to create user, got error %v", err)
}
if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil {
t.Fatalf("failed to delete user, got error %v", err)
}
for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} {
if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} {
if count := DB.Model(&users).Association(key).Count(); count != value {
t.Errorf("user's %v expects: %v, got %v", key, value, count)
}
}
}

View File

@ -5,12 +5,14 @@ import (
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type Person struct {
ID int
Name string
Addresses []Address `gorm:"many2many:person_addresses;"`
DeletedAt gorm.DeletedAt
}
type Address struct {
@ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) {
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
t.Fatalf("address should be deleted when clear with unscoped")
}
address2_1 := Address{Name: "address 2-1"}
address2_2 := Address{Name: "address 2-2"}
person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}}
DB.Create(&person2)
if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil {
t.Fatalf("failed to delete person, got error: %v", err)
}
if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 {
t.Errorf("person's addresses expects 2, got %v", count)
}
if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 {
t.Errorf("person's addresses expects 2, got %v", count)
}
}

View File

@ -30,7 +30,7 @@ func FileWithLineNum() string {
}
func IsValidDBNameChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$'
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
}
func CheckTruth(val interface{}) bool {