Add TxConnPoolBeginner and Tx interface

This commit is contained in:
lianghuan 2022-02-28 17:12:09 +08:00 committed by Jinzhu
parent e2e802b837
commit 996b96e812
5 changed files with 203 additions and 2 deletions

1
.gitignore vendored
View File

@ -3,3 +3,4 @@ documents
coverage.txt coverage.txt
_book _book
.idea .idea
vendor

View File

@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
} }
} }
} }
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) // FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{ queryTx := db.Limit(1).Order(clause.OrderByColumn{
@ -603,6 +604,8 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else { } else {
err = ErrInvalidTransaction err = ErrInvalidTransaction
} }

View File

@ -50,12 +50,25 @@ type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
} }
// TxConnPoolBeginner tx conn pool beginner
type TxConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
}
// TxCommitter tx committer // TxCommitter tx committer
type TxCommitter interface { type TxCommitter interface {
Commit() error Commit() error
Rollback() error Rollback() error
} }
// Tx sql.Tx interface
type Tx interface {
ConnPool
Commit() error
Rollback() error
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
}
// Valuer gorm valuer interface // Valuer gorm valuer interface
type Valuer interface { type Valuer interface {
GormValue(context.Context, *DB) clause.Expr GormValue(context.Context, *DB) clause.Expr

View File

@ -73,6 +73,9 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
if beginner, ok := db.ConnPool.(TxBeginner); ok { if beginner, ok := db.ConnPool.(TxBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt) tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
} else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
} }
return nil, ErrInvalidTransaction return nil, ErrInvalidTransaction
} }
@ -115,7 +118,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
} }
type PreparedStmtTX struct { type PreparedStmtTX struct {
*sql.Tx Tx
PreparedStmtDB *PreparedStmtDB PreparedStmtDB *PreparedStmtDB
} }
@ -151,7 +154,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil { if err == nil {
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if err != nil { if err != nil {
tx.PreparedStmtDB.Mux.Lock() tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock() defer tx.PreparedStmtDB.Mux.Unlock()

181
tests/connpool_test.go Normal file
View File

@ -0,0 +1,181 @@
package tests_test
import (
"context"
"database/sql"
"log"
"os"
"reflect"
"testing"
"time"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
. "gorm.io/gorm/utils/tests"
)
type wrapperTx struct {
*sql.Tx
conn *wrapperConnPool
}
func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.PrepareContext(ctx, query)
}
func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.ExecContext(ctx, query, args...)
}
func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.QueryContext(ctx, query, args...)
}
func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
c.conn.got = append(c.conn.got, query)
return c.Tx.QueryRowContext(ctx, query, args...)
}
type wrapperConnPool struct {
db *sql.DB
got []string
expect []string
}
func (c *wrapperConnPool) Ping() error {
return c.db.Ping()
}
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
// return c.db.BeginTx(ctx, opts)
// }
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) {
tx, err := c.db.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &wrapperTx{Tx: tx, conn: c}, nil
}
func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
c.got = append(c.got, query)
return c.db.PrepareContext(ctx, query)
}
func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
c.got = append(c.got, query)
return c.db.ExecContext(ctx, query, args...)
}
func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
c.got = append(c.got, query)
return c.db.QueryContext(ctx, query, args...)
}
func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
c.got = append(c.got, query)
return c.db.QueryRowContext(ctx, query, args...)
}
func TestConnPoolWrapper(t *testing.T) {
dialect := os.Getenv("GORM_DIALECT")
if dialect != "mysql" {
t.SkipNow()
}
dbDSN := os.Getenv("GORM_DSN")
if dbDSN == "" {
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
}
nativeDB, err := sql.Open("mysql", dbDSN)
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}
conn := &wrapperConnPool{
db: nativeDB,
expect: []string{
"SELECT VERSION()",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
},
}
defer func() {
if !reflect.DeepEqual(conn.got, conn.expect) {
t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
}
}()
l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
Colorful: true,
})
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l})
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}
tx := db.Begin()
user := *GetUser("transaction", Config{})
if err = tx.Save(&user).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
user1 := *GetUser("transaction1-1", Config{})
if err = tx.Save(&user1).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
t.Fatalf("Should return the underlying sql.Tx")
}
tx.Rollback()
if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
t.Fatalf("Should not find record after rollback, but got %v", err)
}
txDB := db.Where("fake_name = ?", "fake_name")
tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
user2 := *GetUser("transaction-2", Config{})
if err = tx2.Save(&user2).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
tx2.Commit()
if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
t.Fatalf("Should be able to find committed record, but got %v", err)
}
}