From 42448cb5d6f68cf2c3a61bf68b217c5bcfb5ab30 Mon Sep 17 00:00:00 2001 From: Timothy Stranex Date: Mon, 17 Mar 2014 12:08:44 +0200 Subject: [PATCH] Add DB.CommonDB() instead of DB.Tx(), as discussed in the PR thread. --- main.go | 15 +++++---------- main_test.go | 7 +++---- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/main.go b/main.go index f205b6c3..942880b5 100644 --- a/main.go +++ b/main.go @@ -28,20 +28,15 @@ func Open(driver, source string) (DB, error) { return db, err } -// Return the underlying sql.DB instance. -// -// If called inside a transaction, it will panic. -// Use Tx() instead in this case. func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } -// Return the underlying sql.Tx instance. -// -// If called outside of a transaction, it will panic. -// Use DB() instead in this case. -func (s *DB) Tx() *sql.Tx { - return s.db.(*sql.Tx) +// Return the underlying sql.DB or sql.Tx instance. +// Use of this method is discouraged. It's mainly intended to allow +// coexistence with legacy non-GORM code. +func (s *DB) CommonDB() sqlCommon { + return s.db } func (s *DB) Callback() *callback { diff --git a/main_test.go b/main_test.go index 14de6701..32d9ac62 100644 --- a/main_test.go +++ b/main_test.go @@ -1542,10 +1542,9 @@ func TestTransaction(t *testing.T) { t.Errorf("Should find saved record, but got", err) } - sql_tx := tx.Tx() // This shouldn't panic. - if sql_tx == nil { - t.Errorf("Should return the underlying sql.Tx, but got nil") - } + if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil { + t.Errorf("Should return the underlying sql.Tx") + } tx.Rollback()