mirror of https://github.com/go-gorm/gorm.git
Fix soft delete with OrCondition, close #3627
This commit is contained in:
parent
9dbef26feb
commit
9b2181199d
|
@ -26,17 +26,22 @@ func (where Where) Build(builder Builder) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
buildExprs(where.Exprs, builder, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
|
||||||
wrapInParentheses := false
|
wrapInParentheses := false
|
||||||
for idx, expr := range where.Exprs {
|
|
||||||
|
for idx, expr := range exprs {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
||||||
builder.WriteString(" OR ")
|
builder.WriteString(" OR ")
|
||||||
} else {
|
} else {
|
||||||
builder.WriteString(" AND ")
|
builder.WriteString(joinCond)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(where.Exprs) > 1 {
|
if len(exprs) > 1 {
|
||||||
switch v := expr.(type) {
|
switch v := expr.(type) {
|
||||||
case OrConditions:
|
case OrConditions:
|
||||||
if len(v.Exprs) == 1 {
|
if len(v.Exprs) == 1 {
|
||||||
|
@ -97,19 +102,10 @@ type AndConditions struct {
|
||||||
func (and AndConditions) Build(builder Builder) {
|
func (and AndConditions) Build(builder Builder) {
|
||||||
if len(and.Exprs) > 1 {
|
if len(and.Exprs) > 1 {
|
||||||
builder.WriteByte('(')
|
builder.WriteByte('(')
|
||||||
}
|
buildExprs(and.Exprs, builder, " AND ")
|
||||||
for idx, c := range and.Exprs {
|
|
||||||
if idx > 0 {
|
|
||||||
if orConditions, ok := c.(OrConditions); ok && len(orConditions.Exprs) == 1 {
|
|
||||||
builder.WriteString(" OR ")
|
|
||||||
} else {
|
|
||||||
builder.WriteString(" AND ")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Build(builder)
|
|
||||||
}
|
|
||||||
if len(and.Exprs) > 1 {
|
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
|
} else {
|
||||||
|
buildExprs(and.Exprs, builder, " AND ")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,15 +123,10 @@ type OrConditions struct {
|
||||||
func (or OrConditions) Build(builder Builder) {
|
func (or OrConditions) Build(builder Builder) {
|
||||||
if len(or.Exprs) > 1 {
|
if len(or.Exprs) > 1 {
|
||||||
builder.WriteByte('(')
|
builder.WriteByte('(')
|
||||||
}
|
buildExprs(or.Exprs, builder, " OR ")
|
||||||
for idx, c := range or.Exprs {
|
|
||||||
if idx > 0 {
|
|
||||||
builder.WriteString(" OR ")
|
|
||||||
}
|
|
||||||
c.Build(builder)
|
|
||||||
}
|
|
||||||
if len(or.Exprs) > 1 {
|
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
|
} else {
|
||||||
|
buildExprs(or.Exprs, builder, " OR ")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -154,6 +154,8 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
||||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if andCond, ok := expr.(clause.AndConditions); ok {
|
||||||
|
tx.assignInterfacesToValue(andCond.Exprs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
||||||
|
|
|
@ -57,6 +57,19 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
|
||||||
|
|
||||||
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
||||||
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
|
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
|
||||||
|
if c, ok := stmt.Clauses["WHERE"]; ok {
|
||||||
|
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 {
|
||||||
|
for _, expr := range where.Exprs {
|
||||||
|
if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
|
||||||
|
where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
|
||||||
|
c.Expression = where
|
||||||
|
stmt.Clauses["WHERE"] = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
|
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
|
||||||
}})
|
}})
|
||||||
|
|
|
@ -69,7 +69,7 @@ func TestCount(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var count4 int64
|
var count4 int64
|
||||||
if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
|
if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 {
|
||||||
t.Errorf("count with join, got error: %v, count %v", err, count)
|
t.Errorf("count with join, got error: %v, count %v", err, count)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -198,17 +198,17 @@ func TestCombineStringConditions(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String()
|
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String()
|
||||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||||
t.Fatalf("invalid sql generated, got %v", sql)
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String()
|
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String()
|
||||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR c = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||||
t.Fatalf("invalid sql generated, got %v", sql)
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String()
|
||||||
if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||||
t.Fatalf("invalid sql generated, got %v", sql)
|
t.Fatalf("invalid sql generated, got %v", sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue