mirror of https://github.com/go-gorm/gorm.git
Add TxConnPoolBeginner and Tx interface
This commit is contained in:
parent
e2e802b837
commit
996b96e812
|
@ -3,3 +3,4 @@ documents
|
||||||
coverage.txt
|
coverage.txt
|
||||||
_book
|
_book
|
||||||
.idea
|
.idea
|
||||||
|
vendor
|
|
@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
|
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
|
||||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||||
|
@ -603,6 +604,8 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
|
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
|
} else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok {
|
||||||
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
} else {
|
} else {
|
||||||
err = ErrInvalidTransaction
|
err = ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,12 +50,25 @@ type ConnPoolBeginner interface {
|
||||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TxConnPoolBeginner tx conn pool beginner
|
||||||
|
type TxConnPoolBeginner interface {
|
||||||
|
BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
|
||||||
|
}
|
||||||
|
|
||||||
// TxCommitter tx committer
|
// TxCommitter tx committer
|
||||||
type TxCommitter interface {
|
type TxCommitter interface {
|
||||||
Commit() error
|
Commit() error
|
||||||
Rollback() error
|
Rollback() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tx sql.Tx interface
|
||||||
|
type Tx interface {
|
||||||
|
ConnPool
|
||||||
|
Commit() error
|
||||||
|
Rollback() error
|
||||||
|
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
// Valuer gorm valuer interface
|
// Valuer gorm valuer interface
|
||||||
type Valuer interface {
|
type Valuer interface {
|
||||||
GormValue(context.Context, *DB) clause.Expr
|
GormValue(context.Context, *DB) clause.Expr
|
||||||
|
|
|
@ -73,6 +73,9 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
||||||
if beginner, ok := db.ConnPool.(TxBeginner); ok {
|
if beginner, ok := db.ConnPool.(TxBeginner); ok {
|
||||||
tx, err := beginner.BeginTx(ctx, opt)
|
tx, err := beginner.BeginTx(ctx, opt)
|
||||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||||
|
} else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok {
|
||||||
|
tx, err := beginner.BeginTx(ctx, opt)
|
||||||
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||||
}
|
}
|
||||||
return nil, ErrInvalidTransaction
|
return nil, ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
@ -115,7 +118,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
|
||||||
}
|
}
|
||||||
|
|
||||||
type PreparedStmtTX struct {
|
type PreparedStmtTX struct {
|
||||||
*sql.Tx
|
Tx
|
||||||
PreparedStmtDB *PreparedStmtDB
|
PreparedStmtDB *PreparedStmtDB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -151,7 +154,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
||||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
|
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.PreparedStmtDB.Mux.Lock()
|
tx.PreparedStmtDB.Mux.Lock()
|
||||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||||
|
|
|
@ -0,0 +1,181 @@
|
||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/driver/mysql"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type wrapperTx struct {
|
||||||
|
*sql.Tx
|
||||||
|
conn *wrapperConnPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.PrepareContext(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.QueryContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.QueryRowContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type wrapperConnPool struct {
|
||||||
|
db *sql.DB
|
||||||
|
got []string
|
||||||
|
expect []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) Ping() error {
|
||||||
|
return c.db.Ping()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
|
||||||
|
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||||
|
// return c.db.BeginTx(ctx, opts)
|
||||||
|
// }
|
||||||
|
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
|
||||||
|
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) {
|
||||||
|
tx, err := c.db.BeginTx(ctx, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &wrapperTx{Tx: tx, conn: c}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.PrepareContext(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.QueryContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.QueryRowContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnPoolWrapper(t *testing.T) {
|
||||||
|
dialect := os.Getenv("GORM_DIALECT")
|
||||||
|
if dialect != "mysql" {
|
||||||
|
t.SkipNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
dbDSN := os.Getenv("GORM_DSN")
|
||||||
|
if dbDSN == "" {
|
||||||
|
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
|
||||||
|
}
|
||||||
|
nativeDB, err := sql.Open("mysql", dbDSN)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Should open db success, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &wrapperConnPool{
|
||||||
|
db: nativeDB,
|
||||||
|
expect: []string{
|
||||||
|
"SELECT VERSION()",
|
||||||
|
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if !reflect.DeepEqual(conn.got, conn.expect) {
|
||||||
|
t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{
|
||||||
|
SlowThreshold: 200 * time.Millisecond,
|
||||||
|
LogLevel: logger.Info,
|
||||||
|
IgnoreRecordNotFoundError: false,
|
||||||
|
Colorful: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Should open db success, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := db.Begin()
|
||||||
|
user := *GetUser("transaction", Config{})
|
||||||
|
|
||||||
|
if err = tx.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user1 := *GetUser("transaction1-1", Config{})
|
||||||
|
|
||||||
|
if err = tx.Save(&user1).Error; err != nil {
|
||||||
|
t.Fatalf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
|
||||||
|
t.Fatalf("Should return the underlying sql.Tx")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Rollback()
|
||||||
|
|
||||||
|
if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
|
||||||
|
t.Fatalf("Should not find record after rollback, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
txDB := db.Where("fake_name = ?", "fake_name")
|
||||||
|
tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
|
||||||
|
user2 := *GetUser("transaction-2", Config{})
|
||||||
|
if err = tx2.Save(&user2).Error; err != nil {
|
||||||
|
t.Fatalf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx2.Commit()
|
||||||
|
|
||||||
|
if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
|
||||||
|
t.Fatalf("Should be able to find committed record, but got %v", err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue