diff --git a/callbacks.go b/callbacks.go index f835e504..c060ea70 100644 --- a/callbacks.go +++ b/callbacks.go @@ -2,7 +2,6 @@ package gorm import ( "context" - "database/sql" "errors" "fmt" "reflect" @@ -16,13 +15,12 @@ import ( func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ - "create": {db: db}, - "query": {db: db}, - "update": {db: db}, - "delete": {db: db}, - "row": {db: db}, - "raw": {db: db}, - "transaction": {db: db}, + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, }, } } @@ -74,29 +72,6 @@ func (cs *callbacks) Raw() *processor { return cs.processors["raw"] } -func (cs *callbacks) Transaction() *processor { - return cs.processors["transaction"] -} - -func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { - var err error - - switch beginner := tx.Statement.ConnPool.(type) { - case TxBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - case ConnPoolBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - default: - err = ErrInvalidTransaction - } - - if err != nil { - _ = tx.AddError(err) - } - - return tx -} - func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { diff --git a/finisher_api.go b/finisher_api.go index 3e406c1c..7a3f27ba 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -619,13 +619,27 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { // clone statement tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) opt *sql.TxOptions + err error ) if len(opts) > 0 { opt = opts[0] } - return tx.callbacks.Transaction().Begin(tx, opt) + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + case ConnPoolBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + default: + err = ErrInvalidTransaction + } + + if err != nil { + tx.AddError(err) + } + + return tx } // Commit commit a transaction