From 3e2c4fc446f0a601124771d012ee167cbb0d7de0 Mon Sep 17 00:00:00 2001 From: tsuba3 Date: Tue, 5 Mar 2024 11:23:51 +0900 Subject: [PATCH] Fix regression in db.Not introduced in v1.25.6. (#6844) * Fix regression in db.Not introduced in 940358e. * Fix --- clause/where.go | 73 +++++++++++++++++++++++++++++++++----------- clause/where_test.go | 8 +++++ tests/query_test.go | 5 +++ 3 files changed, 69 insertions(+), 17 deletions(-) diff --git a/clause/where.go b/clause/where.go index 46d0b319..9ac78578 100644 --- a/clause/where.go +++ b/clause/where.go @@ -21,11 +21,11 @@ func (where Where) Name() string { // Build build where clause func (where Where) Build(builder Builder) { - if len(where.Exprs) == 1 { - if andCondition, ok := where.Exprs[0].(AndConditions); ok { - where.Exprs = andCondition.Exprs - } - } + if len(where.Exprs) == 1 { + if andCondition, ok := where.Exprs[0].(AndConditions); ok { + where.Exprs = andCondition.Exprs + } + } // Switch position if the first query expression is a single Or condition for idx, expr := range where.Exprs { @@ -166,19 +166,58 @@ type NotConditions struct { } func (not NotConditions) Build(builder Builder) { - if len(not.Exprs) > 1 { - builder.WriteByte('(') + anyNegationBuilder := false + for _, c := range not.Exprs { + if _, ok := c.(NegationExpressionBuilder); ok { + anyNegationBuilder = true + break + } } - for idx, c := range not.Exprs { - if idx > 0 { - builder.WriteString(AndWithSpace) + if anyNegationBuilder { + if len(not.Exprs) > 1 { + builder.WriteByte('(') } - if negationBuilder, ok := c.(NegationExpressionBuilder); ok { - negationBuilder.NegationBuild(builder) - } else { - builder.WriteString("NOT ") + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } + } else { + builder.WriteString("NOT ") + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + e, wrapInParentheses := c.(Expr) if wrapInParentheses { sql := strings.ToUpper(e.SQL) @@ -193,9 +232,9 @@ func (not NotConditions) Build(builder Builder) { builder.WriteByte(')') } } - } - if len(not.Exprs) > 1 { - builder.WriteByte(')') + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } } } diff --git a/clause/where_test.go b/clause/where_test.go index aa9d06eb..7d5aca1f 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -105,6 +105,14 @@ func TestWhere(t *testing.T) { "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)", []interface{}{"1", 100}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}}, + clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})}, + }}, + "SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)", + []interface{}{100, 60}, + }, } for idx, result := range results { diff --git a/tests/query_test.go b/tests/query_test.go index cadf7164..e780e3bf 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -554,6 +554,11 @@ func TestNot(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + + result = dryDB.Not(DB.Where("manager IS NULL").Where("age >= ?", 20)).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } } func TestNotWithAllFields(t *testing.T) {