From 4555796b62fa679f3397d5201759e387f7d88a0c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Jun 2020 22:32:35 +0800 Subject: [PATCH] Refactor Execute callbacks --- callbacks.go | 50 ++++++++++++++++++++++++------------------------- finisher_api.go | 16 +++++++--------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/callbacks.go b/callbacks.go index e6cf29af..5e7933af 100644 --- a/callbacks.go +++ b/callbacks.go @@ -73,26 +73,26 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) { curTime := time.Now() + stmt := db.Statement db.RowsAffected = 0 - if stmt := db.Statement; stmt != nil { - if stmt.Model == nil { - stmt.Model = stmt.Dest - } - if stmt.Model != nil { - if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { - db.AddError(err) - } - } + if stmt.Model == nil { + stmt.Model = stmt.Dest + } - if stmt.Dest != nil { - stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) - for stmt.ReflectValue.Kind() == reflect.Ptr { - stmt.ReflectValue = stmt.ReflectValue.Elem() - } - if !stmt.ReflectValue.IsValid() { - db.AddError(fmt.Errorf("invalid value")) - } + if stmt.Model != nil { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { + db.AddError(err) + } + } + + if stmt.Dest != nil { + stmt.ReflectValue = reflect.ValueOf(stmt.Dest) + for stmt.ReflectValue.Kind() == reflect.Ptr { + stmt.ReflectValue = stmt.ReflectValue.Elem() + } + if !stmt.ReflectValue.IsValid() { + db.AddError(fmt.Errorf("invalid value")) } } @@ -100,16 +100,14 @@ func (p *processor) Execute(db *DB) { f(db) } - if stmt := db.Statement; stmt != nil { - db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { - return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected - }, db.Error) + db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected + }, db.Error) - if !stmt.DB.DryRun { - stmt.SQL.Reset() - stmt.Vars = nil - stmt.NamedVars = nil - } + if !stmt.DB.DryRun { + stmt.SQL.Reset() + stmt.Vars = nil + stmt.NamedVars = nil } } diff --git a/finisher_api.go b/finisher_api.go index 84890b51..fc21e490 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -51,7 +51,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { - tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ + tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) if len(conds) > 0 { @@ -65,7 +65,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { // Take return a record that match given conditions, the order will depend on the database implementation func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { - tx = db.getInstance().Limit(1) + tx = db.Limit(1) if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) } @@ -77,7 +77,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { // Last find last record that match given conditions, order by primary key func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { - tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ + tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, }) @@ -120,8 +120,7 @@ func (tx *DB) assignExprsToValue(exprs []clause.Expression) { } func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { - tx = db.getInstance() - if tx = tx.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { + if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignExprsToValue(where.Exprs) @@ -145,8 +144,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { } func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { - tx = db.getInstance() - if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) { + if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { tx.Error = nil if c, ok := tx.Statement.Clauses["WHERE"]; ok { @@ -168,7 +166,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { } return tx.Create(dest) - } else if len(tx.Statement.assigns) > 0 { + } else if len(db.Statement.assigns) > 0 { exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) assigns := map[string]interface{}{} for _, expr := range exprs { @@ -186,7 +184,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx.Model(dest).Updates(assigns) } - return + return db } // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update