diff --git a/dialect.go b/dialect.go index de72b79a..e879588b 100644 --- a/dialect.go +++ b/dialect.go @@ -14,7 +14,7 @@ type Dialect interface { GetName() string // SetDB set db for dialect - SetDB(db *sql.DB) + SetDB(db SQLCommon) // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 BindVar(i int) string @@ -50,7 +50,7 @@ type Dialect interface { var dialectsMap = map[string]Dialect{} -func newDialect(name string, db *sql.DB) Dialect { +func newDialect(name string, db SQLCommon) Dialect { if value, ok := dialectsMap[name]; ok { dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) dialect.SetDB(db) diff --git a/dialect_common.go b/dialect_common.go index 601afd4c..1554151c 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -1,7 +1,6 @@ package gorm import ( - "database/sql" "fmt" "reflect" "regexp" @@ -15,7 +14,7 @@ type DefaultForeignKeyNamer struct { } type commonDialect struct { - db *sql.DB + db SQLCommon DefaultForeignKeyNamer } @@ -27,7 +26,7 @@ func (commonDialect) GetName() string { return "common" } -func (s *commonDialect) SetDB(db *sql.DB) { +func (s *commonDialect) SetDB(db SQLCommon) { s.db = db } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 7c685c9f..c3c81aa2 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,7 +1,6 @@ package mssql import ( - "database/sql" "fmt" "reflect" "strconv" @@ -24,7 +23,7 @@ func init() { } type mssql struct { - db *sql.DB + db gorm.SQLCommon gorm.DefaultForeignKeyNamer } @@ -32,7 +31,7 @@ func (mssql) GetName() string { return "mssql" } -func (s *mssql) SetDB(db *sql.DB) { +func (s *mssql) SetDB(db gorm.SQLCommon) { s.db = db } diff --git a/interface.go b/interface.go index 7b02aa66..55128f7f 100644 --- a/interface.go +++ b/interface.go @@ -2,7 +2,8 @@ package gorm import "database/sql" -type sqlCommon interface { +// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. +type SQLCommon interface { Exec(query string, args ...interface{}) (sql.Result, error) Prepare(query string) (*sql.Stmt, error) Query(query string, args ...interface{}) (*sql.Rows, error) diff --git a/main.go b/main.go index 9ea10909..9ae560a1 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,7 @@ type DB struct { RowsAffected int64 // single db - db sqlCommon + db SQLCommon blockGlobalUpdate bool logMode int logger logger @@ -47,7 +47,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { return nil, err } var source string - var dbSQL *sql.DB + var dbSQL SQLCommon switch value := args[0].(type) { case string: @@ -59,8 +59,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { source = args[1].(string) } dbSQL, err = sql.Open(driver, source) - case *sql.DB: - source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() + case SQLCommon: dbSQL = value } @@ -90,21 +89,27 @@ func (s *DB) New() *DB { return clone } -// Close close current db connection +type closer interface { + Close() error +} + +// Close close current db connection. If database connection is not an io.Closer, returns an error. func (s *DB) Close() error { - if db, ok := s.parent.db.(*sql.DB); ok { + if db, ok := s.parent.db.(closer); ok { return db.Close() } return errors.New("can't close current db") } // DB get `*sql.DB` from current connection +// If the underlying database connection is not a *sql.DB, returns nil func (s *DB) DB() *sql.DB { - return s.db.(*sql.DB) + db, _ := s.db.(*sql.DB) + return db } // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. -func (s *DB) CommonDB() sqlCommon { +func (s *DB) CommonDB() SQLCommon { return s.db } @@ -449,7 +454,7 @@ func (s *DB) Begin() *DB { c := s.clone() if db, ok := c.db.(sqlDb); ok { tx, err := db.Begin() - c.db = interface{}(tx).(sqlCommon) + c.db = interface{}(tx).(SQLCommon) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) diff --git a/scope.go b/scope.go index 45f7185f..86fd1d42 100644 --- a/scope.go +++ b/scope.go @@ -58,7 +58,7 @@ func (scope *Scope) NewDB() *DB { } // SQLDB return *sql.DB -func (scope *Scope) SQLDB() sqlCommon { +func (scope *Scope) SQLDB() SQLCommon { return scope.db.db } @@ -391,7 +391,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) { func (scope *Scope) Begin() *Scope { if db, ok := scope.SQLDB().(sqlDb); ok { if tx, err := db.Begin(); err == nil { - scope.db.db = interface{}(tx).(sqlCommon) + scope.db.db = interface{}(tx).(SQLCommon) scope.InstanceSet("gorm:started_transaction", true) } }