forked from mirror/gorm
Fix create new db sessions in scopes
This commit is contained in:
parent
a480bd8545
commit
92c3ba9dcc
|
@ -72,7 +72,7 @@ func (cs *callbacks) Raw() *processor {
|
|||
return cs.processors["raw"]
|
||||
}
|
||||
|
||||
func (p *processor) Execute(db *DB) {
|
||||
func (p *processor) Execute(db *DB) *DB {
|
||||
// call scopes
|
||||
for len(db.Statement.scopes) > 0 {
|
||||
scopes := db.Statement.scopes
|
||||
|
@ -142,6 +142,8 @@ func (p *processor) Execute(db *DB) {
|
|||
if resetBuildClauses {
|
||||
stmt.BuildClauses = nil
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (p *processor) Get(name string) func(*DB) {
|
||||
|
|
|
@ -21,8 +21,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
|
|||
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Create().Execute(tx)
|
||||
}
|
||||
|
||||
// CreateInBatches insert the value in batches into database
|
||||
|
@ -64,7 +63,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
|||
default:
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
tx = tx.callbacks.Create().Execute(tx)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -80,13 +79,12 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
|
||||
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
|
||||
}
|
||||
tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
|
||||
tx = tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
|
||||
case reflect.Struct:
|
||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||
for _, pf := range tx.Statement.Schema.PrimaryFields {
|
||||
if _, isZero := pf.ValueOf(reflectValue); isZero {
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Create().Execute(tx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -99,7 +97,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||
tx.Statement.Selects = append(tx.Statement.Selects, "*")
|
||||
}
|
||||
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
tx = tx.callbacks.Update().Execute(tx)
|
||||
|
||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
||||
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
||||
|
@ -124,8 +122,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
}
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// Take return a record that match given conditions, the order will depend on the database implementation
|
||||
|
@ -138,8 +135,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
}
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// Last find last record that match given conditions, order by primary key
|
||||
|
@ -155,8 +151,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
}
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// Find find records that match given conditions
|
||||
|
@ -168,8 +163,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
}
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// FindInBatches find records in batches
|
||||
|
@ -334,32 +328,28 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
func (db *DB) Update(column string, value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = map[string]interface{}{column: value}
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
func (db *DB) Updates(values interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = values
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = map[string]interface{}{column: value}
|
||||
tx.Statement.SkipHooks = true
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = values
|
||||
tx.Statement.SkipHooks = true
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||
|
@ -371,8 +361,7 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
|
|||
}
|
||||
}
|
||||
tx.Statement.Dest = value
|
||||
tx.callbacks.Delete().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Delete().Execute(tx)
|
||||
}
|
||||
|
||||
func (db *DB) Count(count *int64) (tx *DB) {
|
||||
|
@ -428,7 +417,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
|||
}
|
||||
|
||||
tx.Statement.Dest = count
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
tx = tx.callbacks.Query().Execute(tx)
|
||||
if tx.RowsAffected != 1 {
|
||||
*count = tx.RowsAffected
|
||||
}
|
||||
|
@ -437,7 +426,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
|||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
tx := db.getInstance().InstanceSet("rows", false)
|
||||
tx.callbacks.Row().Execute(tx)
|
||||
tx = tx.callbacks.Row().Execute(tx)
|
||||
row, ok := tx.Statement.Dest.(*sql.Row)
|
||||
if !ok && tx.DryRun {
|
||||
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
|
||||
|
@ -447,7 +436,7 @@ func (db *DB) Row() *sql.Row {
|
|||
|
||||
func (db *DB) Rows() (*sql.Rows, error) {
|
||||
tx := db.getInstance().InstanceSet("rows", true)
|
||||
tx.callbacks.Row().Execute(tx)
|
||||
tx = tx.callbacks.Row().Execute(tx)
|
||||
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
||||
if !ok && tx.DryRun && tx.Error == nil {
|
||||
tx.Error = ErrDryRunModeUnsupported
|
||||
|
@ -505,8 +494,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
|||
})
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||
|
@ -644,6 +632,5 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
|||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
}
|
||||
|
||||
tx.callbacks.Raw().Execute(tx)
|
||||
return
|
||||
return tx.callbacks.Raw().Execute(tx)
|
||||
}
|
||||
|
|
|
@ -54,4 +54,12 @@ func TestScopes(t *testing.T) {
|
|||
if db.Find(&User{}).Statement.Table != "custom_table" {
|
||||
t.Errorf("failed to call Scopes")
|
||||
}
|
||||
|
||||
result := DB.Scopes(NameIn1And2, func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Session(&gorm.Session{})
|
||||
}).Find(&users1)
|
||||
|
||||
if result.RowsAffected != 2 {
|
||||
t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue