From 5ccd76f76cf21722289615333a0b2a8615d95ed9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 23:28:35 +0800 Subject: [PATCH] Setup Transaction --- association.go | 4 ++++ callbacks/query.go | 5 +++-- finisher_api.go | 56 +++++++++++++++++++++++++++++++++------------- interfaces.go | 9 ++++++++ logger/logger.go | 1 + 5 files changed, 57 insertions(+), 18 deletions(-) diff --git a/association.go b/association.go index 17f8f4a5..14bc54b6 100644 --- a/association.go +++ b/association.go @@ -3,3 +3,7 @@ package gorm // Association Mode contains some helper methods to handle relationship things easily. type Association struct { } + +func (db *DB) Association(column string) *Association { + return nil +} diff --git a/callbacks/query.go b/callbacks/query.go index d8785057..baacbd24 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -11,12 +11,13 @@ func Query(db *gorm.DB) { if db.Statement.SQL.String() == "" { db.Statement.AddClauseIfNotExists(clause.Select{}) db.Statement.AddClauseIfNotExists(clause.From{}) - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } - _, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.AddError(err) + _ = rows + // scan rows } func Preload(db *gorm.DB) { diff --git a/finisher_api.go b/finisher_api.go index 2c5d4f65..72c3d2aa 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -23,6 +23,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { + // TODO handle where tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, @@ -35,12 +36,18 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { // Take return a record that match given conditions, the order will depend on the database implementation func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = out + tx.callbacks.Query().Execute(tx) return } // Last find last record that match given conditions, order by primary key func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance() + tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + tx.Statement.Dest = out + tx.callbacks.Query().Execute(tx) return } @@ -88,21 +95,12 @@ func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { return } -func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) { - tx = db.getInstance() - return -} - //Preloads only preloads relations, don`t touch out func (db *DB) Preloads(out interface{}) (tx *DB) { tx = db.getInstance() return } -func (db *DB) Association(column string) *Association { - return nil -} - func (db *DB) Count(value interface{}) (tx *DB) { tx = db.getInstance() return @@ -130,6 +128,7 @@ func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { return nil } +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true tx := db.Begin(opts...) @@ -150,21 +149,46 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er return } +// Begin begins a transaction func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { tx = db.getInstance() + if beginner, ok := tx.DB.(TxBeginner); ok { + var opt *sql.TxOptions + var err error + if len(opts) > 0 { + opt = opts[0] + } + + if tx.DB, err = beginner.BeginTx(db.Context, opt); err != nil { + tx.AddError(err) + } + } else { + tx.AddError(ErrInvalidTransaction) + } return } -func (db *DB) Commit() (tx *DB) { - tx = db.getInstance() - return +// Commit commit a transaction +func (db *DB) Commit() *DB { + if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil { + db.AddError(comminter.Commit()) + } else { + db.AddError(ErrInvalidTransaction) + } + return db } -func (db *DB) Rollback() (tx *DB) { - tx = db.getInstance() - return +// Rollback rollback a transaction +func (db *DB) Rollback() *DB { + if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil { + db.AddError(comminter.Rollback()) + } else { + db.AddError(ErrInvalidTransaction) + } + return db } +// Exec execute raw sql func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} diff --git a/interfaces.go b/interfaces.go index 21563b7d..f0d14dd8 100644 --- a/interfaces.go +++ b/interfaces.go @@ -25,6 +25,15 @@ type CommonDB interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } +type TxBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +type TxCommiter interface { + Commit() error + Rollback() error +} + type BeforeCreateInterface interface { BeforeCreate(*DB) } diff --git a/logger/logger.go b/logger/logger.go index 568ddd57..d3b97b9d 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -53,6 +53,7 @@ type Interface interface { var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 100 * time.Millisecond, + LogLevel: Warn, Colorful: true, })