diff --git a/clause/where.go b/clause/where.go index a3774e1c..00b1a40e 100644 --- a/clause/where.go +++ b/clause/where.go @@ -26,17 +26,22 @@ func (where Where) Build(builder Builder) { } } + buildExprs(where.Exprs, builder, " AND ") +} + +func buildExprs(exprs []Expression, builder Builder, joinCond string) { wrapInParentheses := false - for idx, expr := range where.Exprs { + + for idx, expr := range exprs { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { builder.WriteString(" OR ") } else { - builder.WriteString(" AND ") + builder.WriteString(joinCond) } } - if len(where.Exprs) > 1 { + if len(exprs) > 1 { switch v := expr.(type) { case OrConditions: if len(v.Exprs) == 1 { @@ -97,19 +102,10 @@ type AndConditions struct { func (and AndConditions) Build(builder Builder) { if len(and.Exprs) > 1 { builder.WriteByte('(') - } - for idx, c := range and.Exprs { - if idx > 0 { - if orConditions, ok := c.(OrConditions); ok && len(orConditions.Exprs) == 1 { - builder.WriteString(" OR ") - } else { - builder.WriteString(" AND ") - } - } - c.Build(builder) - } - if len(and.Exprs) > 1 { + buildExprs(and.Exprs, builder, " AND ") builder.WriteByte(')') + } else { + buildExprs(and.Exprs, builder, " AND ") } } @@ -127,15 +123,10 @@ type OrConditions struct { func (or OrConditions) Build(builder Builder) { if len(or.Exprs) > 1 { builder.WriteByte('(') - } - for idx, c := range or.Exprs { - if idx > 0 { - builder.WriteString(" OR ") - } - c.Build(builder) - } - if len(or.Exprs) > 1 { + buildExprs(or.Exprs, builder, " OR ") builder.WriteByte(')') + } else { + buildExprs(or.Exprs, builder, " OR ") } } diff --git a/finisher_api.go b/finisher_api.go index 63061553..2951fdef 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -154,6 +154,8 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) } } + } else if andCond, ok := expr.(clause.AndConditions); ok { + tx.assignInterfacesToValue(andCond.Exprs) } } case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: diff --git a/soft_delete.go b/soft_delete.go index b15a8148..b3280ff7 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -57,6 +57,19 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + if c, ok := stmt.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 { + for _, expr := range where.Exprs { + if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { + where.Exprs = []clause.Expression{clause.And(where.Exprs...)} + c.Expression = where + stmt.Clauses["WHERE"] = c + break + } + } + } + } + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, }}) diff --git a/tests/count_test.go b/tests/count_test.go index 216fa3a1..0d348227 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -69,7 +69,7 @@ func TestCount(t *testing.T) { } var count4 int64 - if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { + if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { t.Errorf("count with join, got error: %v, count %v", err, count) } } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index c0176fc3..acb08130 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -198,17 +198,17 @@ func TestCombineStringConditions(t *testing.T) { } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String() - if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String() - if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR c = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() - if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) }