Replace all use of *sql.DB with sqlCommon

Exporting sqlCommon as SQLCommon.

This allows passing alternate implementations of the database connection, or wrapping the connection with middleware.  This change didn't change any usages of the database variables.  All usages were already only using the functions defined in SQLCommon.

This does cause a breaking change in Dialect, since *sql.DB was referenced in the interface.
This commit is contained in:
Russ Egan 2017-03-14 16:32:38 -04:00
parent 5409931a1b
commit 45f1a95051
6 changed files with 24 additions and 20 deletions

View File

@ -14,7 +14,7 @@ type Dialect interface {
GetName() string GetName() string
// SetDB set db for dialect // SetDB set db for dialect
SetDB(db *sql.DB) SetDB(db SQLCommon)
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
BindVar(i int) string BindVar(i int) string
@ -50,7 +50,7 @@ type Dialect interface {
var dialectsMap = map[string]Dialect{} var dialectsMap = map[string]Dialect{}
func newDialect(name string, db *sql.DB) Dialect { func newDialect(name string, db SQLCommon) Dialect {
if value, ok := dialectsMap[name]; ok { if value, ok := dialectsMap[name]; ok {
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
dialect.SetDB(db) dialect.SetDB(db)

View File

@ -1,7 +1,6 @@
package gorm package gorm
import ( import (
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
@ -15,7 +14,7 @@ type DefaultForeignKeyNamer struct {
} }
type commonDialect struct { type commonDialect struct {
db *sql.DB db SQLCommon
DefaultForeignKeyNamer DefaultForeignKeyNamer
} }
@ -27,7 +26,7 @@ func (commonDialect) GetName() string {
return "common" return "common"
} }
func (s *commonDialect) SetDB(db *sql.DB) { func (s *commonDialect) SetDB(db SQLCommon) {
s.db = db s.db = db
} }

View File

@ -1,7 +1,6 @@
package mssql package mssql
import ( import (
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
@ -24,7 +23,7 @@ func init() {
} }
type mssql struct { type mssql struct {
db *sql.DB db gorm.SQLCommon
gorm.DefaultForeignKeyNamer gorm.DefaultForeignKeyNamer
} }
@ -32,7 +31,7 @@ func (mssql) GetName() string {
return "mssql" return "mssql"
} }
func (s *mssql) SetDB(db *sql.DB) { func (s *mssql) SetDB(db gorm.SQLCommon) {
s.db = db s.db = db
} }

View File

@ -2,7 +2,8 @@ package gorm
import "database/sql" import "database/sql"
type sqlCommon interface { // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error) Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error) Query(query string, args ...interface{}) (*sql.Rows, error)

23
main.go
View File

@ -16,7 +16,7 @@ type DB struct {
RowsAffected int64 RowsAffected int64
// single db // single db
db sqlCommon db SQLCommon
blockGlobalUpdate bool blockGlobalUpdate bool
logMode int logMode int
logger logger logger logger
@ -47,7 +47,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
return nil, err return nil, err
} }
var source string var source string
var dbSQL *sql.DB var dbSQL SQLCommon
switch value := args[0].(type) { switch value := args[0].(type) {
case string: case string:
@ -59,8 +59,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
source = args[1].(string) source = args[1].(string)
} }
dbSQL, err = sql.Open(driver, source) dbSQL, err = sql.Open(driver, source)
case *sql.DB: case SQLCommon:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
dbSQL = value dbSQL = value
} }
@ -90,21 +89,27 @@ func (s *DB) New() *DB {
return clone return clone
} }
// Close close current db connection type closer interface {
Close() error
}
// Close close current db connection. If database connection is not an io.Closer, returns an error.
func (s *DB) Close() error { func (s *DB) Close() error {
if db, ok := s.parent.db.(*sql.DB); ok { if db, ok := s.parent.db.(closer); ok {
return db.Close() return db.Close()
} }
return errors.New("can't close current db") return errors.New("can't close current db")
} }
// DB get `*sql.DB` from current connection // DB get `*sql.DB` from current connection
// If the underlying database connection is not a *sql.DB, returns nil
func (s *DB) DB() *sql.DB { func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB) db, _ := s.db.(*sql.DB)
return db
} }
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
func (s *DB) CommonDB() sqlCommon { func (s *DB) CommonDB() SQLCommon {
return s.db return s.db
} }
@ -449,7 +454,7 @@ func (s *DB) Begin() *DB {
c := s.clone() c := s.clone()
if db, ok := c.db.(sqlDb); ok { if db, ok := c.db.(sqlDb); ok {
tx, err := db.Begin() tx, err := db.Begin()
c.db = interface{}(tx).(sqlCommon) c.db = interface{}(tx).(SQLCommon)
c.AddError(err) c.AddError(err)
} else { } else {
c.AddError(ErrCantStartTransaction) c.AddError(ErrCantStartTransaction)

View File

@ -58,7 +58,7 @@ func (scope *Scope) NewDB() *DB {
} }
// SQLDB return *sql.DB // SQLDB return *sql.DB
func (scope *Scope) SQLDB() sqlCommon { func (scope *Scope) SQLDB() SQLCommon {
return scope.db.db return scope.db.db
} }
@ -391,7 +391,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
func (scope *Scope) Begin() *Scope { func (scope *Scope) Begin() *Scope {
if db, ok := scope.SQLDB().(sqlDb); ok { if db, ok := scope.SQLDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil { if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon) scope.db.db = interface{}(tx).(SQLCommon)
scope.InstanceSet("gorm:started_transaction", true) scope.InstanceSet("gorm:started_transaction", true)
} }
} }