diff --git a/prepare_stmt.go b/prepare_stmt.go index de7e2a26..14a6aaec 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -99,6 +99,20 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (tx *PreparedStmtTX) Commit() error { + if tx.Tx != nil { + return tx.Tx.Commit() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) Rollback() error { + if tx.Tx != nil { + return tx.Tx.Rollback() + } + return ErrInvalidTransaction +} + func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { diff --git a/tests/tests_test.go b/tests/tests_test.go index 192160a0..cb73d267 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -21,7 +21,7 @@ var DB *gorm.DB func init() { var err error if DB, err = OpenTestConnection(); err != nil { - log.Printf("failed to connect database, got error %v\n", err) + log.Printf("failed to connect database, got error %v", err) os.Exit(1) } else { sqlDB, err := DB.DB() @@ -30,7 +30,7 @@ func init() { } if err != nil { - log.Printf("failed to connect database, got error %v\n", err) + log.Printf("failed to connect database, got error %v", err) } RunMigrations() diff --git a/tests/transaction_test.go b/tests/transaction_test.go index aea151d9..334600b8 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -282,3 +282,24 @@ func TestNestedTransactionWithBlock(t *testing.T) { t.Fatalf("Should find saved record") } } + +func TestTransactionOnClosedConn(t *testing.T) { + DB, err := OpenTestConnection() + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + rawDB, _ := DB.DB() + rawDB.Close() + + if err := DB.Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } + + if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } +}