mirror of https://github.com/go-gorm/gorm.git
Fix transaction on closed conn when using prepared statement, close #3380
This commit is contained in:
parent
3cd81ff646
commit
dd0d74fad0
|
@ -99,6 +99,20 @@ type PreparedStmtTX struct {
|
|||
PreparedStmtDB *PreparedStmtDB
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Commit() error {
|
||||
if tx.Tx != nil {
|
||||
return tx.Tx.Commit()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Rollback() error {
|
||||
if tx.Tx != nil {
|
||||
return tx.Tx.Rollback()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
|
||||
if err == nil {
|
||||
|
|
|
@ -21,7 +21,7 @@ var DB *gorm.DB
|
|||
func init() {
|
||||
var err error
|
||||
if DB, err = OpenTestConnection(); err != nil {
|
||||
log.Printf("failed to connect database, got error %v\n", err)
|
||||
log.Printf("failed to connect database, got error %v", err)
|
||||
os.Exit(1)
|
||||
} else {
|
||||
sqlDB, err := DB.DB()
|
||||
|
@ -30,7 +30,7 @@ func init() {
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("failed to connect database, got error %v\n", err)
|
||||
log.Printf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
RunMigrations()
|
||||
|
|
|
@ -282,3 +282,24 @@ func TestNestedTransactionWithBlock(t *testing.T) {
|
|||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionOnClosedConn(t *testing.T) {
|
||||
DB, err := OpenTestConnection()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
rawDB, _ := DB.DB()
|
||||
rawDB.Close()
|
||||
|
||||
if err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
return nil
|
||||
}); err == nil {
|
||||
t.Errorf("should returns error when commit with closed conn, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error {
|
||||
return nil
|
||||
}); err == nil {
|
||||
t.Errorf("should returns error when commit with closed conn, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue