From 45f1a9505168d7be2398830632d74e173fb2af3f Mon Sep 17 00:00:00 2001 From: Russ Egan Date: Tue, 14 Mar 2017 16:32:38 -0400 Subject: [PATCH] Replace all use of *sql.DB with sqlCommon Exporting sqlCommon as SQLCommon. This allows passing alternate implementations of the database connection, or wrapping the connection with middleware. This change didn't change any usages of the database variables. All usages were already only using the functions defined in SQLCommon. This does cause a breaking change in Dialect, since *sql.DB was referenced in the interface. --- dialect.go | 4 ++-- dialect_common.go | 5 ++--- dialects/mssql/mssql.go | 5 ++--- interface.go | 3 ++- main.go | 23 ++++++++++++++--------- scope.go | 4 ++-- 6 files changed, 24 insertions(+), 20 deletions(-) diff --git a/dialect.go b/dialect.go index de72b79a..e879588b 100644 --- a/dialect.go +++ b/dialect.go @@ -14,7 +14,7 @@ type Dialect interface { GetName() string // SetDB set db for dialect - SetDB(db *sql.DB) + SetDB(db SQLCommon) // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 BindVar(i int) string @@ -50,7 +50,7 @@ type Dialect interface { var dialectsMap = map[string]Dialect{} -func newDialect(name string, db *sql.DB) Dialect { +func newDialect(name string, db SQLCommon) Dialect { if value, ok := dialectsMap[name]; ok { dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) dialect.SetDB(db) diff --git a/dialect_common.go b/dialect_common.go index 601afd4c..1554151c 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -1,7 +1,6 @@ package gorm import ( - "database/sql" "fmt" "reflect" "regexp" @@ -15,7 +14,7 @@ type DefaultForeignKeyNamer struct { } type commonDialect struct { - db *sql.DB + db SQLCommon DefaultForeignKeyNamer } @@ -27,7 +26,7 @@ func (commonDialect) GetName() string { return "common" } -func (s *commonDialect) SetDB(db *sql.DB) { +func (s *commonDialect) SetDB(db SQLCommon) { s.db = db } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 7c685c9f..c3c81aa2 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,7 +1,6 @@ package mssql import ( - "database/sql" "fmt" "reflect" "strconv" @@ -24,7 +23,7 @@ func init() { } type mssql struct { - db *sql.DB + db gorm.SQLCommon gorm.DefaultForeignKeyNamer } @@ -32,7 +31,7 @@ func (mssql) GetName() string { return "mssql" } -func (s *mssql) SetDB(db *sql.DB) { +func (s *mssql) SetDB(db gorm.SQLCommon) { s.db = db } diff --git a/interface.go b/interface.go index 7b02aa66..55128f7f 100644 --- a/interface.go +++ b/interface.go @@ -2,7 +2,8 @@ package gorm import "database/sql" -type sqlCommon interface { +// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. +type SQLCommon interface { Exec(query string, args ...interface{}) (sql.Result, error) Prepare(query string) (*sql.Stmt, error) Query(query string, args ...interface{}) (*sql.Rows, error) diff --git a/main.go b/main.go index 9ea10909..9ae560a1 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,7 @@ type DB struct { RowsAffected int64 // single db - db sqlCommon + db SQLCommon blockGlobalUpdate bool logMode int logger logger @@ -47,7 +47,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { return nil, err } var source string - var dbSQL *sql.DB + var dbSQL SQLCommon switch value := args[0].(type) { case string: @@ -59,8 +59,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { source = args[1].(string) } dbSQL, err = sql.Open(driver, source) - case *sql.DB: - source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() + case SQLCommon: dbSQL = value } @@ -90,21 +89,27 @@ func (s *DB) New() *DB { return clone } -// Close close current db connection +type closer interface { + Close() error +} + +// Close close current db connection. If database connection is not an io.Closer, returns an error. func (s *DB) Close() error { - if db, ok := s.parent.db.(*sql.DB); ok { + if db, ok := s.parent.db.(closer); ok { return db.Close() } return errors.New("can't close current db") } // DB get `*sql.DB` from current connection +// If the underlying database connection is not a *sql.DB, returns nil func (s *DB) DB() *sql.DB { - return s.db.(*sql.DB) + db, _ := s.db.(*sql.DB) + return db } // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. -func (s *DB) CommonDB() sqlCommon { +func (s *DB) CommonDB() SQLCommon { return s.db } @@ -449,7 +454,7 @@ func (s *DB) Begin() *DB { c := s.clone() if db, ok := c.db.(sqlDb); ok { tx, err := db.Begin() - c.db = interface{}(tx).(sqlCommon) + c.db = interface{}(tx).(SQLCommon) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) diff --git a/scope.go b/scope.go index 45f7185f..86fd1d42 100644 --- a/scope.go +++ b/scope.go @@ -58,7 +58,7 @@ func (scope *Scope) NewDB() *DB { } // SQLDB return *sql.DB -func (scope *Scope) SQLDB() sqlCommon { +func (scope *Scope) SQLDB() SQLCommon { return scope.db.db } @@ -391,7 +391,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) { func (scope *Scope) Begin() *Scope { if db, ok := scope.SQLDB().(sqlDb); ok { if tx, err := db.Begin(); err == nil { - scope.db.db = interface{}(tx).(sqlCommon) + scope.db.db = interface{}(tx).(SQLCommon) scope.InstanceSet("gorm:started_transaction", true) } }