diff --git a/callbacks.go b/callbacks.go index 315eea17..f2ee0ea5 100644 --- a/callbacks.go +++ b/callbacks.go @@ -77,12 +77,23 @@ func (p *processor) Execute(db *DB) { stmt = db.Statement ) + // call scopes + for len(stmt.scopes) > 0 { + scopes := stmt.scopes + stmt.scopes = nil + for _, scope := range scopes { + db = scope(db) + } + } + + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest } else if stmt.Dest == nil { stmt.Dest = stmt.Model } + // parse model values if stmt.Model != nil { if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { @@ -93,6 +104,7 @@ func (p *processor) Execute(db *DB) { } } + // assign stmt.ReflectValue if stmt.Dest != nil { stmt.ReflectValue = reflect.ValueOf(stmt.Dest) for stmt.ReflectValue.Kind() == reflect.Ptr { @@ -108,15 +120,6 @@ func (p *processor) Execute(db *DB) { } } - // call scopes - for len(stmt.scopes) > 0 { - scopes := stmt.scopes - stmt.scopes = nil - for _, scope := range scopes { - db = scope(db) - } - } - for _, f := range p.fns { f(db) } diff --git a/chainable_api.go b/chainable_api.go index 12db6830..e17d9bb2 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -103,7 +103,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { Distinct: db.Statement.Distinct, Expression: clause.NamedExpr{SQL: v, Vars: args}, }) - } else { + } else { tx.Statement.Selects = []string{v} for _, arg := range args { diff --git a/tests/count_test.go b/tests/count_test.go index ffe675d9..0fef82f7 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -121,4 +121,12 @@ func TestCount(t *testing.T) { }) AssertEqual(t, users, expects) + + var count9 int64 + if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB { + fmt.Println("kdkdkdkdk") + return tx.Table("users") + }).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } } diff --git a/tests/go.mod b/tests/go.mod index 7743e63a..d4b0c975 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -10,8 +10,8 @@ require ( gorm.io/driver/mysql v1.0.5 gorm.io/driver/postgres v1.0.8 gorm.io/driver/sqlite v1.1.4 - gorm.io/driver/sqlserver v1.0.6 - gorm.io/gorm v1.21.3 + gorm.io/driver/sqlserver v1.0.7 + gorm.io/gorm v1.21.4 ) replace gorm.io/gorm => ../