forked from mirror/gorm
Fix soft delete with OrCondition, close #3627
This commit is contained in:
parent
9dbef26feb
commit
9b2181199d
|
@ -26,17 +26,22 @@ func (where Where) Build(builder Builder) {
|
|||
}
|
||||
}
|
||||
|
||||
buildExprs(where.Exprs, builder, " AND ")
|
||||
}
|
||||
|
||||
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
|
||||
wrapInParentheses := false
|
||||
for idx, expr := range where.Exprs {
|
||||
|
||||
for idx, expr := range exprs {
|
||||
if idx > 0 {
|
||||
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
||||
builder.WriteString(" OR ")
|
||||
} else {
|
||||
builder.WriteString(" AND ")
|
||||
builder.WriteString(joinCond)
|
||||
}
|
||||
}
|
||||
|
||||
if len(where.Exprs) > 1 {
|
||||
if len(exprs) > 1 {
|
||||
switch v := expr.(type) {
|
||||
case OrConditions:
|
||||
if len(v.Exprs) == 1 {
|
||||
|
@ -97,19 +102,10 @@ type AndConditions struct {
|
|||
func (and AndConditions) Build(builder Builder) {
|
||||
if len(and.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
for idx, c := range and.Exprs {
|
||||
if idx > 0 {
|
||||
if orConditions, ok := c.(OrConditions); ok && len(orConditions.Exprs) == 1 {
|
||||
builder.WriteString(" OR ")
|
||||
} else {
|
||||
builder.WriteString(" AND ")
|
||||
}
|
||||
}
|
||||
c.Build(builder)
|
||||
}
|
||||
if len(and.Exprs) > 1 {
|
||||
buildExprs(and.Exprs, builder, " AND ")
|
||||
builder.WriteByte(')')
|
||||
} else {
|
||||
buildExprs(and.Exprs, builder, " AND ")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -127,15 +123,10 @@ type OrConditions struct {
|
|||
func (or OrConditions) Build(builder Builder) {
|
||||
if len(or.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
for idx, c := range or.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(" OR ")
|
||||
}
|
||||
c.Build(builder)
|
||||
}
|
||||
if len(or.Exprs) > 1 {
|
||||
buildExprs(or.Exprs, builder, " OR ")
|
||||
builder.WriteByte(')')
|
||||
} else {
|
||||
buildExprs(or.Exprs, builder, " OR ")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -154,6 +154,8 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
|||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
}
|
||||
} else if andCond, ok := expr.(clause.AndConditions); ok {
|
||||
tx.assignInterfacesToValue(andCond.Exprs)
|
||||
}
|
||||
}
|
||||
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
||||
|
|
|
@ -57,6 +57,19 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
|
|||
|
||||
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
||||
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
|
||||
if c, ok := stmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 {
|
||||
for _, expr := range where.Exprs {
|
||||
if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
|
||||
where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
|
||||
c.Expression = where
|
||||
stmt.Clauses["WHERE"] = c
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
|
||||
}})
|
||||
|
|
|
@ -69,7 +69,7 @@ func TestCount(t *testing.T) {
|
|||
}
|
||||
|
||||
var count4 int64
|
||||
if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
|
||||
if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
|
||||
t.Errorf("count with join, got error: %v, count %v", err, count)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -198,17 +198,17 @@ func TestCombineStringConditions(t *testing.T) {
|
|||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR c = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue