From 6a6dfdae72574e931ea4f0737637308ef2c34b8f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 Apr 2022 17:16:48 +0800 Subject: [PATCH] Refactor FirstOrCreate, FirstOrInit --- finisher_api.go | 24 ++++++++++++------------ tests/go.mod | 7 +++---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 0bd8f7d9..663d532b 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -290,7 +290,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { + if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs) @@ -312,25 +312,26 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { - queryTx := db.Limit(1).Order(clause.OrderByColumn{ + tx = db.getInstance() + queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if tx = queryTx.Find(dest, conds...); tx.Error == nil { - if tx.RowsAffected == 0 { - if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if result := queryTx.Find(dest, conds...); result.Error == nil { + if result.RowsAffected == 0 { + if c, ok := result.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignInterfacesToValue(where.Exprs) + result.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds - if len(tx.Statement.attrs) > 0 { - tx.assignInterfacesToValue(tx.Statement.attrs...) + if len(db.Statement.attrs) > 0 { + result.assignInterfacesToValue(db.Statement.attrs...) } // initialize with attrs, conds - if len(tx.Statement.assigns) > 0 { - tx.assignInterfacesToValue(tx.Statement.assigns...) + if len(db.Statement.assigns) > 0 { + result.assignInterfacesToValue(db.Statement.assigns...) } return tx.Create(dest) @@ -351,8 +352,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx.Model(dest).Updates(assigns) } else { - // can not use Find RowsAffected - tx.RowsAffected = 0 + tx.Error = result.Error } } return tx diff --git a/tests/go.mod b/tests/go.mod index 0a3f85f9..6a2cf22f 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,13 +7,12 @@ require ( github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.5 - github.com/mattn/go-sqlite3 v1.14.12 // indirect golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect gorm.io/driver/mysql v1.3.3 - gorm.io/driver/postgres v1.3.4 - gorm.io/driver/sqlite v1.3.1 + gorm.io/driver/postgres v1.3.5 + gorm.io/driver/sqlite v1.3.2 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.3 + gorm.io/gorm v1.23.4 ) replace gorm.io/gorm => ../