Fix soft delete with OrCondition, close #3627

This commit is contained in:
Jinzhu 2020-10-19 14:49:42 +08:00
parent 9dbef26feb
commit 9b2181199d
5 changed files with 33 additions and 27 deletions

View File

@ -26,17 +26,22 @@ func (where Where) Build(builder Builder) {
} }
} }
buildExprs(where.Exprs, builder, " AND ")
}
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
wrapInParentheses := false wrapInParentheses := false
for idx, expr := range where.Exprs {
for idx, expr := range exprs {
if idx > 0 { if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(" OR ") builder.WriteString(" OR ")
} else { } else {
builder.WriteString(" AND ") builder.WriteString(joinCond)
} }
} }
if len(where.Exprs) > 1 { if len(exprs) > 1 {
switch v := expr.(type) { switch v := expr.(type) {
case OrConditions: case OrConditions:
if len(v.Exprs) == 1 { if len(v.Exprs) == 1 {
@ -97,19 +102,10 @@ type AndConditions struct {
func (and AndConditions) Build(builder Builder) { func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 { if len(and.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
} buildExprs(and.Exprs, builder, " AND ")
for idx, c := range and.Exprs {
if idx > 0 {
if orConditions, ok := c.(OrConditions); ok && len(orConditions.Exprs) == 1 {
builder.WriteString(" OR ")
} else {
builder.WriteString(" AND ")
}
}
c.Build(builder)
}
if len(and.Exprs) > 1 {
builder.WriteByte(')') builder.WriteByte(')')
} else {
buildExprs(and.Exprs, builder, " AND ")
} }
} }
@ -127,15 +123,10 @@ type OrConditions struct {
func (or OrConditions) Build(builder Builder) { func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 { if len(or.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
} buildExprs(or.Exprs, builder, " OR ")
for idx, c := range or.Exprs {
if idx > 0 {
builder.WriteString(" OR ")
}
c.Build(builder)
}
if len(or.Exprs) > 1 {
builder.WriteByte(')') builder.WriteByte(')')
} else {
buildExprs(or.Exprs, builder, " OR ")
} }
} }

View File

@ -154,6 +154,8 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
} }
} }
} else if andCond, ok := expr.(clause.AndConditions); ok {
tx.assignInterfacesToValue(andCond.Exprs)
} }
} }
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:

View File

@ -57,6 +57,19 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
if c, ok := stmt.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 {
for _, expr := range where.Exprs {
if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
c.Expression = where
stmt.Clauses["WHERE"] = c
break
}
}
}
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{ stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
}}) }})

View File

@ -69,7 +69,7 @@ func TestCount(t *testing.T) {
} }
var count4 int64 var count4 int64
if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
t.Errorf("count with join, got error: %v, count %v", err, count) t.Errorf("count with join, got error: %v, count %v", err, count)
} }
} }

View File

@ -198,17 +198,17 @@ func TestCombineStringConditions(t *testing.T) {
} }
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String() sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql) t.Fatalf("invalid sql generated, got %v", sql)
} }
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String() sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR c = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql) t.Fatalf("invalid sql generated, got %v", sql)
} }
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql) t.Fatalf("invalid sql generated, got %v", sql)
} }