Fixed the use of "or" to be " OR ", to account for words that contain "or" or "and" (e.g., 'score', 'band') in a sql statement as the name of a field.

This commit is contained in:
sammyrnycreal 2022-02-14 14:13:26 -05:00 committed by Jinzhu
parent 48ced75d1d
commit 5edc78116f
2 changed files with 61 additions and 17 deletions

View File

@ -4,6 +4,11 @@ import (
"strings" "strings"
) )
const (
AndWithSpace = " AND "
OrWithSpace = " OR "
)
// Where where clause // Where where clause
type Where struct { type Where struct {
Exprs []Expression Exprs []Expression
@ -26,7 +31,7 @@ func (where Where) Build(builder Builder) {
} }
} }
buildExprs(where.Exprs, builder, " AND ") buildExprs(where.Exprs, builder, AndWithSpace)
} }
func buildExprs(exprs []Expression, builder Builder, joinCond string) { func buildExprs(exprs []Expression, builder Builder, joinCond string) {
@ -35,7 +40,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
for idx, expr := range exprs { for idx, expr := range exprs {
if idx > 0 { if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(" OR ") builder.WriteString(OrWithSpace)
} else { } else {
builder.WriteString(joinCond) builder.WriteString(joinCond)
} }
@ -46,23 +51,23 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
case OrConditions: case OrConditions:
if len(v.Exprs) == 1 { if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok { if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL) sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
} }
} }
case AndConditions: case AndConditions:
if len(v.Exprs) == 1 { if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok { if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL) sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
} }
} }
case Expr: case Expr:
sql := strings.ToLower(v.SQL) sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
case NamedExpr: case NamedExpr:
sql := strings.ToLower(v.SQL) sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
} }
} }
@ -110,10 +115,10 @@ type AndConditions struct {
func (and AndConditions) Build(builder Builder) { func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 { if len(and.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
buildExprs(and.Exprs, builder, " AND ") buildExprs(and.Exprs, builder, AndWithSpace)
builder.WriteByte(')') builder.WriteByte(')')
} else { } else {
buildExprs(and.Exprs, builder, " AND ") buildExprs(and.Exprs, builder, AndWithSpace)
} }
} }
@ -131,10 +136,10 @@ type OrConditions struct {
func (or OrConditions) Build(builder Builder) { func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 { if len(or.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
buildExprs(or.Exprs, builder, " OR ") buildExprs(or.Exprs, builder, OrWithSpace)
builder.WriteByte(')') builder.WriteByte(')')
} else { } else {
buildExprs(or.Exprs, builder, " OR ") buildExprs(or.Exprs, builder, OrWithSpace)
} }
} }
@ -156,7 +161,7 @@ func (not NotConditions) Build(builder Builder) {
for idx, c := range not.Exprs { for idx, c := range not.Exprs {
if idx > 0 { if idx > 0 {
builder.WriteString(" AND ") builder.WriteString(AndWithSpace)
} }
if negationBuilder, ok := c.(NegationExpressionBuilder); ok { if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
@ -165,8 +170,8 @@ func (not NotConditions) Build(builder Builder) {
builder.WriteString("NOT ") builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr) e, wrapInParentheses := c.(Expr)
if wrapInParentheses { if wrapInParentheses {
sql := strings.ToLower(e.SQL) sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(') builder.WriteByte('(')
} }
} }

View File

@ -66,6 +66,45 @@ func TestWhere(t *testing.T) {
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
[]interface{}{18, "jinzhu"}, []interface{}{18, "jinzhu"},
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}),
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
[]interface{}{"1", 100},
},
} }
for idx, result := range results { for idx, result := range results {