mirror of https://github.com/go-gorm/gorm.git
Support delete associations with Select when deleting
This commit is contained in:
parent
53caa85cf4
commit
70a7bd52ca
|
@ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||||
deleteCallback := db.Callback().Delete()
|
deleteCallback := db.Callback().Delete()
|
||||||
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||||
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
||||||
|
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
||||||
deleteCallback.Register("gorm:delete", Delete)
|
deleteCallback.Register("gorm:delete", Delete)
|
||||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
|
|
|
@ -21,6 +21,59 @@ func BeforeDelete(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteBeforeAssociations(db *gorm.DB) {
|
||||||
|
if db.Error == nil && db.Statement.Schema != nil {
|
||||||
|
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
|
||||||
|
|
||||||
|
if restricted {
|
||||||
|
for column, v := range selectColumns {
|
||||||
|
if v {
|
||||||
|
if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
|
||||||
|
switch rel.Type {
|
||||||
|
case schema.HasOne, schema.HasMany:
|
||||||
|
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
|
||||||
|
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
|
tx := db.Session(&gorm.Session{}).Model(modelValue)
|
||||||
|
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case schema.Many2Many:
|
||||||
|
var (
|
||||||
|
queryConds []clause.Expression
|
||||||
|
foreignFields []*schema.Field
|
||||||
|
relForeignKeys []string
|
||||||
|
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||||
|
table = rel.JoinTable.Table
|
||||||
|
tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table)
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, ref := range rel.References {
|
||||||
|
if ref.OwnPrimaryKey {
|
||||||
|
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||||
|
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||||
|
} else if ref.PrimaryValue != "" {
|
||||||
|
queryConds = append(queryConds, clause.Eq{
|
||||||
|
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
|
||||||
|
Value: ref.PrimaryValue,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
|
||||||
|
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
|
||||||
|
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
|
||||||
|
|
||||||
|
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Delete(db *gorm.DB) {
|
func Delete(db *gorm.DB) {
|
||||||
if db.Error == nil {
|
if db.Error == nil {
|
||||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) {
|
||||||
t.Errorf("should returns no error while enable global update, but got err %v", err)
|
t.Errorf("should returns no error while enable global update, but got err %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteWithAssociations(t *testing.T) {
|
||||||
|
user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1})
|
||||||
|
|
||||||
|
if err := DB.Create(user).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Select(clause.Associations).Delete(&user).Error; err != nil {
|
||||||
|
t.Fatalf("failed to delete user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} {
|
||||||
|
if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value {
|
||||||
|
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} {
|
||||||
|
if count := DB.Model(&user).Association(key).Count(); count != value {
|
||||||
|
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteSliceWithAssociations(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
*GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}),
|
||||||
|
*GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}),
|
||||||
|
*GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}),
|
||||||
|
*GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(users).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil {
|
||||||
|
t.Fatalf("failed to delete user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} {
|
||||||
|
if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value {
|
||||||
|
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} {
|
||||||
|
if count := DB.Model(&users).Association(key).Count(); count != value {
|
||||||
|
t.Errorf("user's %v expects: %v, got %v", key, value, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -5,12 +5,14 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Person struct {
|
type Person struct {
|
||||||
ID int
|
ID int
|
||||||
Name string
|
Name string
|
||||||
Addresses []Address `gorm:"many2many:person_addresses;"`
|
Addresses []Address `gorm:"many2many:person_addresses;"`
|
||||||
|
DeletedAt gorm.DeletedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
type Address struct {
|
type Address struct {
|
||||||
|
@ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) {
|
||||||
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
|
if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 {
|
||||||
t.Fatalf("address should be deleted when clear with unscoped")
|
t.Fatalf("address should be deleted when clear with unscoped")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
address2_1 := Address{Name: "address 2-1"}
|
||||||
|
address2_2 := Address{Name: "address 2-2"}
|
||||||
|
person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}}
|
||||||
|
DB.Create(&person2)
|
||||||
|
if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil {
|
||||||
|
t.Fatalf("failed to delete person, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 {
|
||||||
|
t.Errorf("person's addresses expects 2, got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 {
|
||||||
|
t.Errorf("person's addresses expects 2, got %v", count)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ func FileWithLineNum() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsValidDBNameChar(c rune) bool {
|
func IsValidDBNameChar(c rune) bool {
|
||||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$'
|
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckTruth(val interface{}) bool {
|
func CheckTruth(val interface{}) bool {
|
||||||
|
|
Loading…
Reference in New Issue