diff --git a/clause/where.go b/clause/where.go index 6399a2d5..f7cd3318 100644 --- a/clause/where.go +++ b/clause/where.go @@ -66,7 +66,11 @@ func (and AndConditions) Build(builder Builder) { } for idx, c := range and.Exprs { if idx > 0 { - builder.WriteString(" AND ") + if orConditions, ok := c.(OrConditions); ok && len(orConditions.Exprs) == 1 { + builder.WriteString(" OR ") + } else { + builder.WriteString(" AND ") + } } c.Build(builder) } diff --git a/clause/where_test.go b/clause/where_test.go index 95bba820..2fa11d76 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -53,6 +53,12 @@ func TestWhere(t *testing.T) { }}, "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, + }}, + "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"}, + }, } for idx, result := range results { diff --git a/statement.go b/statement.go index e3c882ee..7cc01bb8 100644 --- a/statement.go +++ b/statement.go @@ -245,6 +245,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c switch v := arg.(type) { case clause.Expression: conds = append(conds, v) + case *DB: + if cs, ok := v.Statement.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + conds = append(conds, clause.And(where.Exprs...)) + } else if cs.Expression != nil { + conds = append(conds, cs.Expression) + } + } case map[interface{}]interface{}: for i, j := range v { conds = append(conds, clause.Eq{Column: i, Value: j}) diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index a60514c9..b78c2484 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "strings" "testing" "gorm.io/gorm" @@ -138,3 +139,27 @@ func TestDryRun(t *testing.T) { t.Errorf("Failed to generate sql, got %v", stmt2.SQL.String()) } } + +func TestGroupConditions(t *testing.T) { + type Pizza struct { + ID uint + Name string + Size string + } + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Where( + DB.Where("pizza = ?", "pepperoni").Where(DB.Where("size = ?", "small").Or("size = ?", "medium")), + ).Or( + DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"), + ).Find(&Pizza{}).Statement + + execStmt := dryRunDB.Exec("WHERE (pizza = ? AND (size = ? OR size = ?)) OR (pizza = ? AND size = ?)", "pepperoni", "small", "medium", "hawaiian", "xlarge").Statement + + result := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + expects := DB.Dialector.Explain(execStmt.SQL.String(), execStmt.Vars...) + + if !strings.HasSuffix(result, expects) { + t.Errorf("expects: %v, got %v", expects, result) + } +}