From af01854d3ecae994322b18d71cafdec114de9d81 Mon Sep 17 00:00:00 2001 From: Tyler Stillwater Date: Mon, 10 Jun 2019 06:33:20 -0600 Subject: [PATCH] Add BeginTx for parity with sql.DB.BeginTx (#2227) --- interface.go | 6 +++++- main.go | 10 ++++++++-- main_test.go | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/interface.go b/interface.go index 55128f7f..fe649231 100644 --- a/interface.go +++ b/interface.go @@ -1,6 +1,9 @@ package gorm -import "database/sql" +import ( + "context" + "database/sql" +) // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. type SQLCommon interface { @@ -12,6 +15,7 @@ type SQLCommon interface { type sqlDb interface { Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } type sqlTx interface { diff --git a/main.go b/main.go index 906b7f41..994d1618 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" @@ -503,11 +504,16 @@ func (s *DB) Debug() *DB { return s.clone().LogMode(true) } -// Begin begin a transaction +// Begin begins a transaction func (s *DB) Begin() *DB { + return s.BeginTx(context.Background(), &sql.TxOptions{}) +} + +// BeginTX begins a transaction with options +func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { c := s.clone() if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, opts) c.db = interface{}(tx).(SQLCommon) c.dialect.SetDB(c.db) diff --git a/main_test.go b/main_test.go index ee038cac..81ecf0fe 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -471,6 +472,40 @@ func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { } } +func TestTransactionReadonly(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect == "" { + dialect = "sqlite" + } + switch dialect { + case "mssql", "sqlite": + t.Skipf("%s does not support readonly transactions\n", dialect) + } + + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + tx.Commit() + + tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record") + } + + if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { + t.Errorf("Should return the underlying sql.Tx") + } + + u = User{Name: "transcation-2"} + if err := tx.Save(&u).Error; err == nil { + t.Errorf("Error should have been raised in a readonly transaction") + } + + tx.Rollback() +} + func TestRow(t *testing.T) { user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}