mirror of https://github.com/go-gorm/gorm.git
Fix combine conditions when using string conditions, close #3358
This commit is contained in:
parent
dbaa6b0ec3
commit
680dda2c15
|
@ -1,5 +1,9 @@
|
|||
package clause
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Where where clause
|
||||
type Where struct {
|
||||
Exprs []Expression
|
||||
|
@ -22,6 +26,7 @@ func (where Where) Build(builder Builder) {
|
|||
}
|
||||
}
|
||||
|
||||
wrapInParentheses := false
|
||||
for idx, expr := range where.Exprs {
|
||||
if idx > 0 {
|
||||
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
||||
|
@ -31,7 +36,36 @@ func (where Where) Build(builder Builder) {
|
|||
}
|
||||
}
|
||||
|
||||
if len(where.Exprs) > 1 {
|
||||
switch v := expr.(type) {
|
||||
case OrConditions:
|
||||
if len(v.Exprs) == 1 {
|
||||
if e, ok := v.Exprs[0].(Expr); ok {
|
||||
sql := strings.ToLower(e.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||
}
|
||||
}
|
||||
case AndConditions:
|
||||
if len(v.Exprs) == 1 {
|
||||
if e, ok := v.Exprs[0].(Expr); ok {
|
||||
sql := strings.ToLower(e.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||
}
|
||||
}
|
||||
case Expr:
|
||||
sql := strings.ToLower(v.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||
}
|
||||
}
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteString(`(`)
|
||||
expr.Build(builder)
|
||||
builder.WriteString(`)`)
|
||||
wrapInParentheses = false
|
||||
} else {
|
||||
expr.Build(builder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,6 +84,8 @@ func (where Where) MergeClause(clause *Clause) {
|
|||
func And(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
} else if len(exprs) == 1 {
|
||||
return exprs[0]
|
||||
}
|
||||
return AndConditions{Exprs: exprs}
|
||||
}
|
||||
|
@ -118,6 +154,7 @@ func (not NotConditions) Build(builder Builder) {
|
|||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(" AND ")
|
||||
|
@ -127,9 +164,22 @@ func (not NotConditions) Build(builder Builder) {
|
|||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
if wrapInParentheses {
|
||||
sql := strings.ToLower(e.SQL)
|
||||
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -188,3 +189,56 @@ func TestGroupConditions(t *testing.T) {
|
|||
t.Errorf("expects: %v, got %v", expects, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineStringConditions(t *testing.T) {
|
||||
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = dryRunDB.Not("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue