Add BeginTx for parity with sql.DB.BeginTx (#2227)

This commit is contained in:
Tyler Stillwater 2019-06-10 06:33:20 -06:00 committed by Jinzhu
parent cf9b85ed90
commit fec06da6a3
3 changed files with 48 additions and 3 deletions

View File

@ -1,6 +1,9 @@
package gorm
import "database/sql"
import (
"context"
"database/sql"
)
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
@ -12,6 +15,7 @@ type SQLCommon interface {
type sqlDb interface {
Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
type sqlTx interface {

10
main.go
View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"database/sql"
"errors"
"fmt"
@ -503,11 +504,16 @@ func (s *DB) Debug() *DB {
return s.clone().LogMode(true)
}
// Begin begin a transaction
// Begin begins a transaction
func (s *DB) Begin() *DB {
return s.BeginTx(context.Background(), &sql.TxOptions{})
}
// BeginTX begins a transaction with options
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, opts)
c.db = interface{}(tx).(SQLCommon)
c.dialect.SetDB(c.db)

View File

@ -1,6 +1,7 @@
package gorm_test
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
@ -471,6 +472,40 @@ func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) {
}
}
func TestTransactionReadonly(t *testing.T) {
dialect := os.Getenv("GORM_DIALECT")
if dialect == "" {
dialect = "sqlite"
}
switch dialect {
case "mssql", "sqlite":
t.Skipf("%s does not support readonly transactions\n", dialect)
}
tx := DB.Begin()
u := User{Name: "transcation"}
if err := tx.Save(&u).Error; err != nil {
t.Errorf("No error should raise")
}
tx.Commit()
tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record")
}
if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
t.Errorf("Should return the underlying sql.Tx")
}
u = User{Name: "transcation-2"}
if err := tx.Save(&u).Error; err == nil {
t.Errorf("Error should have been raised in a readonly transaction")
}
tx.Rollback()
}
func TestRow(t *testing.T) {
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}