Fix combine conditions when using string conditions, close #3358

This commit is contained in:
Jinzhu 2020-09-02 20:09:51 +08:00
parent dbaa6b0ec3
commit 680dda2c15
2 changed files with 105 additions and 1 deletions

View File

@ -1,5 +1,9 @@
package clause package clause
import (
"strings"
)
// Where where clause // Where where clause
type Where struct { type Where struct {
Exprs []Expression Exprs []Expression
@ -22,6 +26,7 @@ func (where Where) Build(builder Builder) {
} }
} }
wrapInParentheses := false
for idx, expr := range where.Exprs { for idx, expr := range where.Exprs {
if idx > 0 { if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
@ -31,7 +36,36 @@ func (where Where) Build(builder Builder) {
} }
} }
if len(where.Exprs) > 1 {
switch v := expr.(type) {
case OrConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
case AndConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
case Expr:
sql := strings.ToLower(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
}
}
if wrapInParentheses {
builder.WriteString(`(`)
expr.Build(builder) expr.Build(builder)
builder.WriteString(`)`)
wrapInParentheses = false
} else {
expr.Build(builder)
}
} }
} }
@ -50,6 +84,8 @@ func (where Where) MergeClause(clause *Clause) {
func And(exprs ...Expression) Expression { func And(exprs ...Expression) Expression {
if len(exprs) == 0 { if len(exprs) == 0 {
return nil return nil
} else if len(exprs) == 1 {
return exprs[0]
} }
return AndConditions{Exprs: exprs} return AndConditions{Exprs: exprs}
} }
@ -118,6 +154,7 @@ func (not NotConditions) Build(builder Builder) {
if len(not.Exprs) > 1 { if len(not.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
} }
for idx, c := range not.Exprs { for idx, c := range not.Exprs {
if idx > 0 { if idx > 0 {
builder.WriteString(" AND ") builder.WriteString(" AND ")
@ -127,9 +164,22 @@ func (not NotConditions) Build(builder Builder) {
negationBuilder.NegationBuild(builder) negationBuilder.NegationBuild(builder)
} else { } else {
builder.WriteString("NOT ") builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToLower(e.SQL)
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder) c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
} }
} }
}
if len(not.Exprs) > 1 { if len(not.Exprs) > 1 {
builder.WriteByte(')') builder.WriteByte(')')
} }

View File

@ -1,6 +1,7 @@
package tests_test package tests_test
import ( import (
"regexp"
"strings" "strings"
"testing" "testing"
@ -188,3 +189,56 @@ func TestGroupConditions(t *testing.T) {
t.Errorf("expects: %v, got %v", expects, result) t.Errorf("expects: %v, got %v", expects, result)
} }
} }
func TestCombineStringConditions(t *testing.T) {
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
sql = dryRunDB.Not("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) {
t.Fatalf("invalid sql generated, got %v", sql)
}
}