Call scopes before parse model value, close #4209

This commit is contained in:
Jinzhu 2021-03-24 16:17:49 +08:00
parent 4d5cec8bdd
commit 704e53a774
4 changed files with 23 additions and 12 deletions

View File

@ -77,12 +77,23 @@ func (p *processor) Execute(db *DB) {
stmt = db.Statement
)
// call scopes
for len(stmt.scopes) > 0 {
scopes := stmt.scopes
stmt.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
}
// assign model values
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
}
// parse model values
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
@ -93,6 +104,7 @@ func (p *processor) Execute(db *DB) {
}
}
// assign stmt.ReflectValue
if stmt.Dest != nil {
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
for stmt.ReflectValue.Kind() == reflect.Ptr {
@ -108,15 +120,6 @@ func (p *processor) Execute(db *DB) {
}
}
// call scopes
for len(stmt.scopes) > 0 {
scopes := stmt.scopes
stmt.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
}
for _, f := range p.fns {
f(db)
}

View File

@ -121,4 +121,12 @@ func TestCount(t *testing.T) {
})
AssertEqual(t, users, expects)
var count9 int64
if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB {
fmt.Println("kdkdkdkdk")
return tx.Table("users")
}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 {
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
}
}

View File

@ -10,8 +10,8 @@ require (
gorm.io/driver/mysql v1.0.5
gorm.io/driver/postgres v1.0.8
gorm.io/driver/sqlite v1.1.4
gorm.io/driver/sqlserver v1.0.6
gorm.io/gorm v1.21.3
gorm.io/driver/sqlserver v1.0.7
gorm.io/gorm v1.21.4
)
replace gorm.io/gorm => ../