From 4b0da0e97a15979820790dd14023f47acc1848d0 Mon Sep 17 00:00:00 2001 From: black-06 Date: Tue, 11 Apr 2023 12:01:23 +0800 Subject: [PATCH] fix cond in scopes (#6152) * fix cond in scopes * replace quote * fix execute scopes --- callbacks.go | 6 +----- chainable_api.go | 30 ++++++++++++++++++++++++++ migrator.go | 6 +----- statement.go | 12 ++++++----- tests/scopes_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 15 deletions(-) diff --git a/callbacks.go b/callbacks.go index de979e45..ca6b6d50 100644 --- a/callbacks.go +++ b/callbacks.go @@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { - scopes := db.Statement.scopes - db.Statement.scopes = nil - for _, scope := range scopes { - db = scope(db) - } + db = db.executeScopes() } var ( diff --git a/chainable_api.go b/chainable_api.go index a85235e0..19d405cc 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -366,6 +366,36 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { return tx } +func (db *DB) executeScopes() (tx *DB) { + tx = db.getInstance() + scopes := db.Statement.scopes + if len(scopes) == 0 { + return tx + } + tx.Statement.scopes = nil + + conditions := make([]clause.Interface, 0, 4) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + + for _, scope := range scopes { + tx = scope(tx) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + } + + for _, condition := range conditions { + tx.Statement.AddClause(condition) + } + return tx +} + // Preload preload associations with given conditions // // // get all users, and preload all non-cancelled orders diff --git a/migrator.go b/migrator.go index 9c7cc2c4..037afc35 100644 --- a/migrator.go +++ b/migrator.go @@ -13,11 +13,7 @@ func (db *DB) Migrator() Migrator { // apply scopes to migrator for len(tx.Statement.scopes) > 0 { - scopes := tx.Statement.scopes - tx.Statement.scopes = nil - for _, scope := range scopes { - tx = scope(tx) - } + tx = tx.executeScopes() } return tx.Dialector.Migrator(tx.Session(&Session{})) diff --git a/statement.go b/statement.go index bc959f0b..59c0b772 100644 --- a/statement.go +++ b/statement.go @@ -324,11 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case clause.Expression: conds = append(conds, v) case *DB: - for _, scope := range v.Statement.scopes { - v = scope(v) - } + v.executeScopes() - if cs, ok := v.Statement.Clauses["WHERE"]; ok { + if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { @@ -336,9 +334,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } conds = append(conds, clause.And(where.Exprs...)) - } else if cs.Expression != nil { + } else { conds = append(conds, cs.Expression) } + if v.Statement == stmt { + cs.Expression = nil + stmt.Statement.Clauses["WHERE"] = cs + } } case map[interface{}]interface{}: for i, j := range v { diff --git a/tests/scopes_test.go b/tests/scopes_test.go index ab3807ea..52c6b37b 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -72,3 +72,54 @@ func TestScopes(t *testing.T) { t.Errorf("select max(id)") } } + +func TestComplexScopes(t *testing.T) { + tests := []struct { + name string + queryFn func(tx *gorm.DB) *gorm.DB + expected string + }{ + { + name: "depth_1", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, + }, { + name: "depth_1_pre_cond", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Where("z = 0").Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`, + }, { + name: "depth_2", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) }, + func(d *gorm.DB) *gorm.DB { + return d. + Or(d.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") }, + )). + Or("c = 3") + }, + func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn)) + }) + } +}