Merge pull request #1384 from ansel1/master

Replace all use of *sql.DB with sqlCommon
This commit is contained in:
Jinzhu 2017-03-16 22:15:57 +08:00 committed by GitHub
commit 8b058a707f
6 changed files with 24 additions and 20 deletions

View File

@ -14,7 +14,7 @@ type Dialect interface {
GetName() string
// SetDB set db for dialect
SetDB(db *sql.DB)
SetDB(db SQLCommon)
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
BindVar(i int) string
@ -50,7 +50,7 @@ type Dialect interface {
var dialectsMap = map[string]Dialect{}
func newDialect(name string, db *sql.DB) Dialect {
func newDialect(name string, db SQLCommon) Dialect {
if value, ok := dialectsMap[name]; ok {
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
dialect.SetDB(db)

View File

@ -1,7 +1,6 @@
package gorm
import (
"database/sql"
"fmt"
"reflect"
"regexp"
@ -15,7 +14,7 @@ type DefaultForeignKeyNamer struct {
}
type commonDialect struct {
db *sql.DB
db SQLCommon
DefaultForeignKeyNamer
}
@ -27,7 +26,7 @@ func (commonDialect) GetName() string {
return "common"
}
func (s *commonDialect) SetDB(db *sql.DB) {
func (s *commonDialect) SetDB(db SQLCommon) {
s.db = db
}

View File

@ -1,7 +1,6 @@
package mssql
import (
"database/sql"
"fmt"
"reflect"
"strconv"
@ -24,7 +23,7 @@ func init() {
}
type mssql struct {
db *sql.DB
db gorm.SQLCommon
gorm.DefaultForeignKeyNamer
}
@ -32,7 +31,7 @@ func (mssql) GetName() string {
return "mssql"
}
func (s *mssql) SetDB(db *sql.DB) {
func (s *mssql) SetDB(db gorm.SQLCommon) {
s.db = db
}

View File

@ -2,7 +2,8 @@ package gorm
import "database/sql"
type sqlCommon interface {
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)

23
main.go
View File

@ -16,7 +16,7 @@ type DB struct {
RowsAffected int64
// single db
db sqlCommon
db SQLCommon
blockGlobalUpdate bool
logMode int
logger logger
@ -47,7 +47,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
return nil, err
}
var source string
var dbSQL *sql.DB
var dbSQL SQLCommon
switch value := args[0].(type) {
case string:
@ -59,8 +59,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
source = args[1].(string)
}
dbSQL, err = sql.Open(driver, source)
case *sql.DB:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
case SQLCommon:
dbSQL = value
}
@ -90,21 +89,27 @@ func (s *DB) New() *DB {
return clone
}
// Close close current db connection
type closer interface {
Close() error
}
// Close close current db connection. If database connection is not an io.Closer, returns an error.
func (s *DB) Close() error {
if db, ok := s.parent.db.(*sql.DB); ok {
if db, ok := s.parent.db.(closer); ok {
return db.Close()
}
return errors.New("can't close current db")
}
// DB get `*sql.DB` from current connection
// If the underlying database connection is not a *sql.DB, returns nil
func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
db, _ := s.db.(*sql.DB)
return db
}
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
func (s *DB) CommonDB() sqlCommon {
func (s *DB) CommonDB() SQLCommon {
return s.db
}
@ -449,7 +454,7 @@ func (s *DB) Begin() *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok {
tx, err := db.Begin()
c.db = interface{}(tx).(sqlCommon)
c.db = interface{}(tx).(SQLCommon)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)

View File

@ -58,7 +58,7 @@ func (scope *Scope) NewDB() *DB {
}
// SQLDB return *sql.DB
func (scope *Scope) SQLDB() sqlCommon {
func (scope *Scope) SQLDB() SQLCommon {
return scope.db.db
}
@ -391,7 +391,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
func (scope *Scope) Begin() *Scope {
if db, ok := scope.SQLDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon)
scope.db.db = interface{}(tx).(SQLCommon)
scope.InstanceSet("gorm:started_transaction", true)
}
}