diff --git a/finisher_api.go b/finisher_api.go index 857f9419..2e7e5f4e 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -186,7 +186,11 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { - if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs) @@ -197,7 +201,6 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { if len(tx.Statement.attrs) > 0 { tx.assignInterfacesToValue(tx.Statement.attrs...) } - tx.Error = nil } // initialize with attrs, conds @@ -208,9 +211,11 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { } func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { - if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { - tx.Error = nil + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs)