Refactor Execute callbacks

This commit is contained in:
Jinzhu 2020-06-08 22:32:35 +08:00
parent 9f19378304
commit 4555796b62
2 changed files with 31 additions and 35 deletions

View File

@ -73,26 +73,26 @@ func (cs *callbacks) Raw() *processor {
func (p *processor) Execute(db *DB) { func (p *processor) Execute(db *DB) {
curTime := time.Now() curTime := time.Now()
stmt := db.Statement
db.RowsAffected = 0 db.RowsAffected = 0
if stmt := db.Statement; stmt != nil {
if stmt.Model == nil {
stmt.Model = stmt.Dest
}
if stmt.Model != nil { if stmt.Model == nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { stmt.Model = stmt.Dest
db.AddError(err) }
}
}
if stmt.Dest != nil { if stmt.Model != nil {
stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
for stmt.ReflectValue.Kind() == reflect.Ptr { db.AddError(err)
stmt.ReflectValue = stmt.ReflectValue.Elem() }
} }
if !stmt.ReflectValue.IsValid() {
db.AddError(fmt.Errorf("invalid value")) if stmt.Dest != nil {
} stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
for stmt.ReflectValue.Kind() == reflect.Ptr {
stmt.ReflectValue = stmt.ReflectValue.Elem()
}
if !stmt.ReflectValue.IsValid() {
db.AddError(fmt.Errorf("invalid value"))
} }
} }
@ -100,16 +100,14 @@ func (p *processor) Execute(db *DB) {
f(db) f(db)
} }
if stmt := db.Statement; stmt != nil { db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error)
}, db.Error)
if !stmt.DB.DryRun { if !stmt.DB.DryRun {
stmt.SQL.Reset() stmt.SQL.Reset()
stmt.Vars = nil stmt.Vars = nil
stmt.NamedVars = nil stmt.NamedVars = nil
}
} }
} }

View File

@ -51,7 +51,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
// First find first record that match given conditions, order by primary key // First find first record that match given conditions, order by primary key
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}) })
if len(conds) > 0 { if len(conds) > 0 {
@ -65,7 +65,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
// Take return a record that match given conditions, the order will depend on the database implementation // Take return a record that match given conditions, the order will depend on the database implementation
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1) tx = db.Limit(1)
if len(conds) > 0 { if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
} }
@ -77,7 +77,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
// Last find last record that match given conditions, order by primary key // Last find last record that match given conditions, order by primary key
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Desc: true, Desc: true,
}) })
@ -120,8 +120,7 @@ func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
} }
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
if tx = tx.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
if c, ok := tx.Statement.Clauses["WHERE"]; ok { if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
tx.assignExprsToValue(where.Exprs) tx.assignExprsToValue(where.Exprs)
@ -145,8 +144,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
} }
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) {
tx.Error = nil tx.Error = nil
if c, ok := tx.Statement.Clauses["WHERE"]; ok { if c, ok := tx.Statement.Clauses["WHERE"]; ok {
@ -168,7 +166,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
} }
return tx.Create(dest) return tx.Create(dest)
} else if len(tx.Statement.assigns) > 0 { } else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
assigns := map[string]interface{}{} assigns := map[string]interface{}{}
for _, expr := range exprs { for _, expr := range exprs {
@ -186,7 +184,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.Model(dest).Updates(assigns) return tx.Model(dest).Updates(assigns)
} }
return return db
} }
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update