Avoid panic for invalid transaction, close #3271

This commit is contained in:
Jinzhu 2020-08-17 12:16:42 +08:00
parent 6834c25cec
commit 2a716e04e6
2 changed files with 24 additions and 2 deletions

View File

@ -445,7 +445,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
// Commit commit a transaction // Commit commit a transaction
func (db *DB) Commit() *DB { func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit()) db.AddError(committer.Commit())
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)
@ -456,7 +456,9 @@ func (db *DB) Commit() *DB {
// Rollback rollback a transaction // Rollback rollback a transaction
func (db *DB) Rollback() *DB { func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Rollback()) db.AddError(committer.Rollback())
}
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)
} }

View File

@ -1,6 +1,7 @@
package tests_test package tests_test
import ( import (
"context"
"errors" "errors"
"testing" "testing"
@ -57,6 +58,25 @@ func TestTransaction(t *testing.T) {
} }
} }
func TestCancelTransaction(t *testing.T) {
ctx := context.Background()
ctx, cancelFunc := context.WithCancel(ctx)
cancelFunc()
user := *GetUser("cancel_transaction", Config{})
DB.Create(&user)
err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var result User
tx.First(&result, user.ID)
return nil
})
if err == nil {
t.Fatalf("Transaction should get error when using cancelled context")
}
}
func TestTransactionWithBlock(t *testing.T) { func TestTransactionWithBlock(t *testing.T) {
assertPanic := func(f func()) { assertPanic := func(f func()) {
defer func() { defer func() {