Merge pull request #1384 from ansel1/master

Replace all use of *sql.DB with sqlCommon
This commit is contained in:
Jinzhu 2017-03-16 22:15:57 +08:00 committed by GitHub
commit 8b058a707f
6 changed files with 24 additions and 20 deletions

View File

@ -14,7 +14,7 @@ type Dialect interface {
GetName() string GetName() string
// SetDB set db for dialect // SetDB set db for dialect
SetDB(db *sql.DB) SetDB(db SQLCommon)
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
BindVar(i int) string BindVar(i int) string
@ -50,7 +50,7 @@ type Dialect interface {
var dialectsMap = map[string]Dialect{} var dialectsMap = map[string]Dialect{}
func newDialect(name string, db *sql.DB) Dialect { func newDialect(name string, db SQLCommon) Dialect {
if value, ok := dialectsMap[name]; ok { if value, ok := dialectsMap[name]; ok {
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
dialect.SetDB(db) dialect.SetDB(db)

View File

@ -1,7 +1,6 @@
package gorm package gorm
import ( import (
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
@ -15,7 +14,7 @@ type DefaultForeignKeyNamer struct {
} }
type commonDialect struct { type commonDialect struct {
db *sql.DB db SQLCommon
DefaultForeignKeyNamer DefaultForeignKeyNamer
} }
@ -27,7 +26,7 @@ func (commonDialect) GetName() string {
return "common" return "common"
} }
func (s *commonDialect) SetDB(db *sql.DB) { func (s *commonDialect) SetDB(db SQLCommon) {
s.db = db s.db = db
} }

View File

@ -1,7 +1,6 @@
package mssql package mssql
import ( import (
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
@ -24,7 +23,7 @@ func init() {
} }
type mssql struct { type mssql struct {
db *sql.DB db gorm.SQLCommon
gorm.DefaultForeignKeyNamer gorm.DefaultForeignKeyNamer
} }
@ -32,7 +31,7 @@ func (mssql) GetName() string {
return "mssql" return "mssql"
} }
func (s *mssql) SetDB(db *sql.DB) { func (s *mssql) SetDB(db gorm.SQLCommon) {
s.db = db s.db = db
} }

View File

@ -2,7 +2,8 @@ package gorm
import "database/sql" import "database/sql"
type sqlCommon interface { // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error) Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error) Query(query string, args ...interface{}) (*sql.Rows, error)

23
main.go
View File

@ -16,7 +16,7 @@ type DB struct {
RowsAffected int64 RowsAffected int64
// single db // single db
db sqlCommon db SQLCommon
blockGlobalUpdate bool blockGlobalUpdate bool
logMode int logMode int
logger logger logger logger
@ -47,7 +47,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
return nil, err return nil, err
} }
var source string var source string
var dbSQL *sql.DB var dbSQL SQLCommon
switch value := args[0].(type) { switch value := args[0].(type) {
case string: case string:
@ -59,8 +59,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
source = args[1].(string) source = args[1].(string)
} }
dbSQL, err = sql.Open(driver, source) dbSQL, err = sql.Open(driver, source)
case *sql.DB: case SQLCommon:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
dbSQL = value dbSQL = value
} }
@ -90,21 +89,27 @@ func (s *DB) New() *DB {
return clone return clone
} }
// Close close current db connection type closer interface {
Close() error
}
// Close close current db connection. If database connection is not an io.Closer, returns an error.
func (s *DB) Close() error { func (s *DB) Close() error {
if db, ok := s.parent.db.(*sql.DB); ok { if db, ok := s.parent.db.(closer); ok {
return db.Close() return db.Close()
} }
return errors.New("can't close current db") return errors.New("can't close current db")
} }
// DB get `*sql.DB` from current connection // DB get `*sql.DB` from current connection
// If the underlying database connection is not a *sql.DB, returns nil
func (s *DB) DB() *sql.DB { func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB) db, _ := s.db.(*sql.DB)
return db
} }
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
func (s *DB) CommonDB() sqlCommon { func (s *DB) CommonDB() SQLCommon {
return s.db return s.db
} }
@ -449,7 +454,7 @@ func (s *DB) Begin() *DB {
c := s.clone() c := s.clone()
if db, ok := c.db.(sqlDb); ok { if db, ok := c.db.(sqlDb); ok {
tx, err := db.Begin() tx, err := db.Begin()
c.db = interface{}(tx).(sqlCommon) c.db = interface{}(tx).(SQLCommon)
c.AddError(err) c.AddError(err)
} else { } else {
c.AddError(ErrCantStartTransaction) c.AddError(ErrCantStartTransaction)

View File

@ -58,7 +58,7 @@ func (scope *Scope) NewDB() *DB {
} }
// SQLDB return *sql.DB // SQLDB return *sql.DB
func (scope *Scope) SQLDB() sqlCommon { func (scope *Scope) SQLDB() SQLCommon {
return scope.db.db return scope.db.db
} }
@ -391,7 +391,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
func (scope *Scope) Begin() *Scope { func (scope *Scope) Begin() *Scope {
if db, ok := scope.SQLDB().(sqlDb); ok { if db, ok := scope.SQLDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil { if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon) scope.db.db = interface{}(tx).(SQLCommon)
scope.InstanceSet("gorm:started_transaction", true) scope.InstanceSet("gorm:started_transaction", true)
} }
} }