Add DB.Tx() method to provice access to the underlying sql.Tx instance.

This commit is contained in:
Timothy Stranex 2014-03-16 18:24:32 +02:00
parent d232c69369
commit a336f51444
2 changed files with 17 additions and 0 deletions

12
main.go
View File

@ -28,10 +28,22 @@ func Open(driver, source string) (DB, error) {
return db, err return db, err
} }
// Return the underlying sql.DB instance.
//
// If called inside a transaction, it will panic.
// Use Tx() instead in this case.
func (s *DB) DB() *sql.DB { func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB) return s.db.(*sql.DB)
} }
// Return the underlying sql.Tx instance.
//
// If called outside of a transaction, it will panic.
// Use DB() instead in this case.
func (s *DB) Tx() *sql.Tx {
return s.db.(*sql.Tx)
}
func (s *DB) Callback() *callback { func (s *DB) Callback() *callback {
s.parent.callback = s.parent.callback.clone() s.parent.callback = s.parent.callback.clone()
return s.parent.callback return s.parent.callback

View File

@ -1542,6 +1542,11 @@ func TestTransaction(t *testing.T) {
t.Errorf("Should find saved record, but got", err) t.Errorf("Should find saved record, but got", err)
} }
sql_tx := tx.Tx() // This shouldn't panic.
if sql_tx == nil {
t.Errorf("Should return the underlying sql.Tx, but got nil")
}
tx.Rollback() tx.Rollback()
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {