diff --git a/.gitignore b/.gitignore index e1b9ecea..45505cc9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ documents coverage.txt _book .idea +vendor \ No newline at end of file diff --git a/finisher_api.go b/finisher_api.go index f994ec31..5d49ddf9 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } } } + // FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ @@ -603,6 +604,8 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) } else { err = ErrInvalidTransaction } diff --git a/interfaces.go b/interfaces.go index 44a85cb5..ed7112f2 100644 --- a/interfaces.go +++ b/interfaces.go @@ -50,12 +50,25 @@ type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } +// TxConnPoolBeginner tx conn pool beginner +type TxConnPoolBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) +} + // TxCommitter tx committer type TxCommitter interface { Commit() error Rollback() error } +// Tx sql.Tx interface +type Tx interface { + ConnPool + Commit() error + Rollback() error + StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt +} + // Valuer gorm valuer interface type Valuer interface { GormValue(context.Context, *DB) clause.Expr diff --git a/prepare_stmt.go b/prepare_stmt.go index 88bec4e9..94282fad 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -73,6 +73,9 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn if beginner, ok := db.ConnPool.(TxBeginner); ok { tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err + } else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { + tx, err := beginner.BeginTx(ctx, opt) + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } return nil, ErrInvalidTransaction } @@ -115,7 +118,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg } type PreparedStmtTX struct { - *sql.Tx + Tx PreparedStmtDB *PreparedStmtDB } @@ -151,7 +154,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() diff --git a/tests/connpool_test.go b/tests/connpool_test.go new file mode 100644 index 00000000..3713ad7c --- /dev/null +++ b/tests/connpool_test.go @@ -0,0 +1,181 @@ +package tests_test + +import ( + "context" + "database/sql" + "log" + "os" + "reflect" + "testing" + "time" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" + . "gorm.io/gorm/utils/tests" +) + +type wrapperTx struct { + *sql.Tx + conn *wrapperConnPool +} + +func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.PrepareContext(ctx, query) +} + +func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.ExecContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryRowContext(ctx, query, args...) +} + +type wrapperConnPool struct { + db *sql.DB + got []string + expect []string +} + +func (c *wrapperConnPool) Ping() error { + return c.db.Ping() +} + +// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. +// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { +// return c.db.BeginTx(ctx, opts) +// } +// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. +func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { + tx, err := c.db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &wrapperTx{Tx: tx, conn: c}, nil +} + +func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.got = append(c.got, query) + return c.db.PrepareContext(ctx, query) +} + +func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.got = append(c.got, query) + return c.db.ExecContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.got = append(c.got, query) + return c.db.QueryContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.got = append(c.got, query) + return c.db.QueryRowContext(ctx, query, args...) +} + +func TestConnPoolWrapper(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect != "mysql" { + t.SkipNow() + } + + dbDSN := os.Getenv("GORM_DSN") + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + } + nativeDB, err := sql.Open("mysql", dbDSN) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + conn := &wrapperConnPool{ + db: nativeDB, + expect: []string{ + "SELECT VERSION()", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + }, + } + + defer func() { + if !reflect.DeepEqual(conn.got, conn.expect) { + t.Errorf("expect %#v but got %#v", conn.expect, conn.got) + } + }() + + l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: logger.Info, + IgnoreRecordNotFoundError: false, + Colorful: true, + }) + + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + tx := db.Begin() + user := *GetUser("transaction", Config{}) + + if err = tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + user1 := *GetUser("transaction1-1", Config{}) + + if err = tx.Save(&user1).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { + t.Fatalf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil { + t.Fatalf("Should not find record after rollback, but got %v", err) + } + + txDB := db.Where("fake_name = ?", "fake_name") + tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() + user2 := *GetUser("transaction-2", Config{}) + if err = tx2.Save(&user2).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + tx2.Commit() + + if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should be able to find committed record, but got %v", err) + } +}