diff --git a/main.go b/main.go index 3b058231..6bd006d7 100644 --- a/main.go +++ b/main.go @@ -533,7 +533,9 @@ func (s *DB) Commit() *DB { func (s *DB) Rollback() *DB { var emptySQLTx *sql.Tx if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Rollback()) + if err := db.Rollback(); err != nil && err != sql.ErrTxDone { + s.AddError(err) + } } else { s.AddError(ErrInvalidTransaction) } diff --git a/main_test.go b/main_test.go index 14bf34ac..3d922dda 100644 --- a/main_test.go +++ b/main_test.go @@ -421,6 +421,22 @@ func TestTransaction(t *testing.T) { } } +func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Commit should not raise error") + } + + if err := tx.Rollback().Error; err != nil { + t.Errorf("Rollback should not raise error") + } +} + func TestRow(t *testing.T) { user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}