fix cond in scopes (#6152)

* fix cond in scopes

* replace quote

* fix execute scopes
This commit is contained in:
black-06 2023-04-11 12:01:23 +08:00 committed by GitHub
parent ccc3cb758a
commit 4b0da0e97a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 90 additions and 15 deletions

View File

@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor {
func (p *processor) Execute(db *DB) *DB { func (p *processor) Execute(db *DB) *DB {
// call scopes // call scopes
for len(db.Statement.scopes) > 0 { for len(db.Statement.scopes) > 0 {
scopes := db.Statement.scopes db = db.executeScopes()
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
} }
var ( var (

View File

@ -366,6 +366,36 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
return tx return tx
} }
func (db *DB) executeScopes() (tx *DB) {
tx = db.getInstance()
scopes := db.Statement.scopes
if len(scopes) == 0 {
return tx
}
tx.Statement.scopes = nil
conditions := make([]clause.Interface, 0, 4)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
for _, scope := range scopes {
tx = scope(tx)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
}
for _, condition := range conditions {
tx.Statement.AddClause(condition)
}
return tx
}
// Preload preload associations with given conditions // Preload preload associations with given conditions
// //
// // get all users, and preload all non-cancelled orders // // get all users, and preload all non-cancelled orders

View File

@ -13,11 +13,7 @@ func (db *DB) Migrator() Migrator {
// apply scopes to migrator // apply scopes to migrator
for len(tx.Statement.scopes) > 0 { for len(tx.Statement.scopes) > 0 {
scopes := tx.Statement.scopes tx = tx.executeScopes()
tx.Statement.scopes = nil
for _, scope := range scopes {
tx = scope(tx)
}
} }
return tx.Dialector.Migrator(tx.Session(&Session{})) return tx.Dialector.Migrator(tx.Session(&Session{}))

View File

@ -324,11 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
case clause.Expression: case clause.Expression:
conds = append(conds, v) conds = append(conds, v)
case *DB: case *DB:
for _, scope := range v.Statement.scopes { v.executeScopes()
v = scope(v)
}
if cs, ok := v.Statement.Clauses["WHERE"]; ok { if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
if where, ok := cs.Expression.(clause.Where); ok { if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 { if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
@ -336,9 +334,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
} }
} }
conds = append(conds, clause.And(where.Exprs...)) conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil { } else {
conds = append(conds, cs.Expression) conds = append(conds, cs.Expression)
} }
if v.Statement == stmt {
cs.Expression = nil
stmt.Statement.Clauses["WHERE"] = cs
}
} }
case map[interface{}]interface{}: case map[interface{}]interface{}:
for i, j := range v { for i, j := range v {

View File

@ -72,3 +72,54 @@ func TestScopes(t *testing.T) {
t.Errorf("select max(id)") t.Errorf("select max(id)")
} }
} }
func TestComplexScopes(t *testing.T) {
tests := []struct {
name string
queryFn func(tx *gorm.DB) *gorm.DB
expected string
}{
{
name: "depth_1",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
}, {
name: "depth_1_pre_cond",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Where("z = 0").Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`,
}, {
name: "depth_2",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) },
func(d *gorm.DB) *gorm.DB {
return d.
Or(d.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") },
)).
Or("c = 3")
},
func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn))
})
}
}