mirror of https://github.com/go-gorm/gorm.git
Add BeginTx for parity with sql.DB.BeginTx (#2227)
This commit is contained in:
parent
ac78f05986
commit
af01854d3e
|
@ -1,6 +1,9 @@
|
|||
package gorm
|
||||
|
||||
import "database/sql"
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
||||
type SQLCommon interface {
|
||||
|
@ -12,6 +15,7 @@ type SQLCommon interface {
|
|||
|
||||
type sqlDb interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type sqlTx interface {
|
||||
|
|
10
main.go
10
main.go
|
@ -1,6 +1,7 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -503,11 +504,16 @@ func (s *DB) Debug() *DB {
|
|||
return s.clone().LogMode(true)
|
||||
}
|
||||
|
||||
// Begin begin a transaction
|
||||
// Begin begins a transaction
|
||||
func (s *DB) Begin() *DB {
|
||||
return s.BeginTx(context.Background(), &sql.TxOptions{})
|
||||
}
|
||||
|
||||
// BeginTX begins a transaction with options
|
||||
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
|
||||
c := s.clone()
|
||||
if db, ok := c.db.(sqlDb); ok && db != nil {
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, opts)
|
||||
c.db = interface{}(tx).(SQLCommon)
|
||||
|
||||
c.dialect.SetDB(c.db)
|
||||
|
|
35
main_test.go
35
main_test.go
|
@ -1,6 +1,7 @@
|
|||
package gorm_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
|
@ -471,6 +472,40 @@ func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTransactionReadonly(t *testing.T) {
|
||||
dialect := os.Getenv("GORM_DIALECT")
|
||||
if dialect == "" {
|
||||
dialect = "sqlite"
|
||||
}
|
||||
switch dialect {
|
||||
case "mssql", "sqlite":
|
||||
t.Skipf("%s does not support readonly transactions\n", dialect)
|
||||
}
|
||||
|
||||
tx := DB.Begin()
|
||||
u := User{Name: "transcation"}
|
||||
if err := tx.Save(&u).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
|
||||
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
||||
t.Errorf("Should find saved record")
|
||||
}
|
||||
|
||||
if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
|
||||
t.Errorf("Should return the underlying sql.Tx")
|
||||
}
|
||||
|
||||
u = User{Name: "transcation-2"}
|
||||
if err := tx.Save(&u).Error; err == nil {
|
||||
t.Errorf("Error should have been raised in a readonly transaction")
|
||||
}
|
||||
|
||||
tx.Rollback()
|
||||
}
|
||||
|
||||
func TestRow(t *testing.T) {
|
||||
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||
|
|
Loading…
Reference in New Issue