forked from mirror/gorm
Merge pull request #1384 from ansel1/master
Replace all use of *sql.DB with sqlCommon
This commit is contained in:
commit
8b058a707f
|
@ -14,7 +14,7 @@ type Dialect interface {
|
|||
GetName() string
|
||||
|
||||
// SetDB set db for dialect
|
||||
SetDB(db *sql.DB)
|
||||
SetDB(db SQLCommon)
|
||||
|
||||
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||
BindVar(i int) string
|
||||
|
@ -50,7 +50,7 @@ type Dialect interface {
|
|||
|
||||
var dialectsMap = map[string]Dialect{}
|
||||
|
||||
func newDialect(name string, db *sql.DB) Dialect {
|
||||
func newDialect(name string, db SQLCommon) Dialect {
|
||||
if value, ok := dialectsMap[name]; ok {
|
||||
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
|
||||
dialect.SetDB(db)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
|
@ -15,7 +14,7 @@ type DefaultForeignKeyNamer struct {
|
|||
}
|
||||
|
||||
type commonDialect struct {
|
||||
db *sql.DB
|
||||
db SQLCommon
|
||||
DefaultForeignKeyNamer
|
||||
}
|
||||
|
||||
|
@ -27,7 +26,7 @@ func (commonDialect) GetName() string {
|
|||
return "common"
|
||||
}
|
||||
|
||||
func (s *commonDialect) SetDB(db *sql.DB) {
|
||||
func (s *commonDialect) SetDB(db SQLCommon) {
|
||||
s.db = db
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
@ -24,7 +23,7 @@ func init() {
|
|||
}
|
||||
|
||||
type mssql struct {
|
||||
db *sql.DB
|
||||
db gorm.SQLCommon
|
||||
gorm.DefaultForeignKeyNamer
|
||||
}
|
||||
|
||||
|
@ -32,7 +31,7 @@ func (mssql) GetName() string {
|
|||
return "mssql"
|
||||
}
|
||||
|
||||
func (s *mssql) SetDB(db *sql.DB) {
|
||||
func (s *mssql) SetDB(db gorm.SQLCommon) {
|
||||
s.db = db
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,8 @@ package gorm
|
|||
|
||||
import "database/sql"
|
||||
|
||||
type sqlCommon interface {
|
||||
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
||||
type SQLCommon interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
|
|
23
main.go
23
main.go
|
@ -16,7 +16,7 @@ type DB struct {
|
|||
RowsAffected int64
|
||||
|
||||
// single db
|
||||
db sqlCommon
|
||||
db SQLCommon
|
||||
blockGlobalUpdate bool
|
||||
logMode int
|
||||
logger logger
|
||||
|
@ -47,7 +47,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||
return nil, err
|
||||
}
|
||||
var source string
|
||||
var dbSQL *sql.DB
|
||||
var dbSQL SQLCommon
|
||||
|
||||
switch value := args[0].(type) {
|
||||
case string:
|
||||
|
@ -59,8 +59,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||
source = args[1].(string)
|
||||
}
|
||||
dbSQL, err = sql.Open(driver, source)
|
||||
case *sql.DB:
|
||||
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
||||
case SQLCommon:
|
||||
dbSQL = value
|
||||
}
|
||||
|
||||
|
@ -90,21 +89,27 @@ func (s *DB) New() *DB {
|
|||
return clone
|
||||
}
|
||||
|
||||
// Close close current db connection
|
||||
type closer interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Close close current db connection. If database connection is not an io.Closer, returns an error.
|
||||
func (s *DB) Close() error {
|
||||
if db, ok := s.parent.db.(*sql.DB); ok {
|
||||
if db, ok := s.parent.db.(closer); ok {
|
||||
return db.Close()
|
||||
}
|
||||
return errors.New("can't close current db")
|
||||
}
|
||||
|
||||
// DB get `*sql.DB` from current connection
|
||||
// If the underlying database connection is not a *sql.DB, returns nil
|
||||
func (s *DB) DB() *sql.DB {
|
||||
return s.db.(*sql.DB)
|
||||
db, _ := s.db.(*sql.DB)
|
||||
return db
|
||||
}
|
||||
|
||||
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
|
||||
func (s *DB) CommonDB() sqlCommon {
|
||||
func (s *DB) CommonDB() SQLCommon {
|
||||
return s.db
|
||||
}
|
||||
|
||||
|
@ -449,7 +454,7 @@ func (s *DB) Begin() *DB {
|
|||
c := s.clone()
|
||||
if db, ok := c.db.(sqlDb); ok {
|
||||
tx, err := db.Begin()
|
||||
c.db = interface{}(tx).(sqlCommon)
|
||||
c.db = interface{}(tx).(SQLCommon)
|
||||
c.AddError(err)
|
||||
} else {
|
||||
c.AddError(ErrCantStartTransaction)
|
||||
|
|
4
scope.go
4
scope.go
|
@ -58,7 +58,7 @@ func (scope *Scope) NewDB() *DB {
|
|||
}
|
||||
|
||||
// SQLDB return *sql.DB
|
||||
func (scope *Scope) SQLDB() sqlCommon {
|
||||
func (scope *Scope) SQLDB() SQLCommon {
|
||||
return scope.db.db
|
||||
}
|
||||
|
||||
|
@ -391,7 +391,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
|||
func (scope *Scope) Begin() *Scope {
|
||||
if db, ok := scope.SQLDB().(sqlDb); ok {
|
||||
if tx, err := db.Begin(); err == nil {
|
||||
scope.db.db = interface{}(tx).(sqlCommon)
|
||||
scope.db.db = interface{}(tx).(SQLCommon)
|
||||
scope.InstanceSet("gorm:started_transaction", true)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue