forked from mirror/gorm
Merge pull request #1384 from ansel1/master
Replace all use of *sql.DB with sqlCommon
This commit is contained in:
commit
8b058a707f
|
@ -14,7 +14,7 @@ type Dialect interface {
|
||||||
GetName() string
|
GetName() string
|
||||||
|
|
||||||
// SetDB set db for dialect
|
// SetDB set db for dialect
|
||||||
SetDB(db *sql.DB)
|
SetDB(db SQLCommon)
|
||||||
|
|
||||||
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||||
BindVar(i int) string
|
BindVar(i int) string
|
||||||
|
@ -50,7 +50,7 @@ type Dialect interface {
|
||||||
|
|
||||||
var dialectsMap = map[string]Dialect{}
|
var dialectsMap = map[string]Dialect{}
|
||||||
|
|
||||||
func newDialect(name string, db *sql.DB) Dialect {
|
func newDialect(name string, db SQLCommon) Dialect {
|
||||||
if value, ok := dialectsMap[name]; ok {
|
if value, ok := dialectsMap[name]; ok {
|
||||||
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
|
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
|
||||||
dialect.SetDB(db)
|
dialect.SetDB(db)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -15,7 +14,7 @@ type DefaultForeignKeyNamer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type commonDialect struct {
|
type commonDialect struct {
|
||||||
db *sql.DB
|
db SQLCommon
|
||||||
DefaultForeignKeyNamer
|
DefaultForeignKeyNamer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,7 +26,7 @@ func (commonDialect) GetName() string {
|
||||||
return "common"
|
return "common"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) SetDB(db *sql.DB) {
|
func (s *commonDialect) SetDB(db SQLCommon) {
|
||||||
s.db = db
|
s.db = db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package mssql
|
package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -24,7 +23,7 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
type mssql struct {
|
type mssql struct {
|
||||||
db *sql.DB
|
db gorm.SQLCommon
|
||||||
gorm.DefaultForeignKeyNamer
|
gorm.DefaultForeignKeyNamer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,7 +31,7 @@ func (mssql) GetName() string {
|
||||||
return "mssql"
|
return "mssql"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) SetDB(db *sql.DB) {
|
func (s *mssql) SetDB(db gorm.SQLCommon) {
|
||||||
s.db = db
|
s.db = db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,8 @@ package gorm
|
||||||
|
|
||||||
import "database/sql"
|
import "database/sql"
|
||||||
|
|
||||||
type sqlCommon interface {
|
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
||||||
|
type SQLCommon interface {
|
||||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||||
Prepare(query string) (*sql.Stmt, error)
|
Prepare(query string) (*sql.Stmt, error)
|
||||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||||
|
|
23
main.go
23
main.go
|
@ -16,7 +16,7 @@ type DB struct {
|
||||||
RowsAffected int64
|
RowsAffected int64
|
||||||
|
|
||||||
// single db
|
// single db
|
||||||
db sqlCommon
|
db SQLCommon
|
||||||
blockGlobalUpdate bool
|
blockGlobalUpdate bool
|
||||||
logMode int
|
logMode int
|
||||||
logger logger
|
logger logger
|
||||||
|
@ -47,7 +47,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var source string
|
var source string
|
||||||
var dbSQL *sql.DB
|
var dbSQL SQLCommon
|
||||||
|
|
||||||
switch value := args[0].(type) {
|
switch value := args[0].(type) {
|
||||||
case string:
|
case string:
|
||||||
|
@ -59,8 +59,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
||||||
source = args[1].(string)
|
source = args[1].(string)
|
||||||
}
|
}
|
||||||
dbSQL, err = sql.Open(driver, source)
|
dbSQL, err = sql.Open(driver, source)
|
||||||
case *sql.DB:
|
case SQLCommon:
|
||||||
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
|
||||||
dbSQL = value
|
dbSQL = value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,21 +89,27 @@ func (s *DB) New() *DB {
|
||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close close current db connection
|
type closer interface {
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close close current db connection. If database connection is not an io.Closer, returns an error.
|
||||||
func (s *DB) Close() error {
|
func (s *DB) Close() error {
|
||||||
if db, ok := s.parent.db.(*sql.DB); ok {
|
if db, ok := s.parent.db.(closer); ok {
|
||||||
return db.Close()
|
return db.Close()
|
||||||
}
|
}
|
||||||
return errors.New("can't close current db")
|
return errors.New("can't close current db")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DB get `*sql.DB` from current connection
|
// DB get `*sql.DB` from current connection
|
||||||
|
// If the underlying database connection is not a *sql.DB, returns nil
|
||||||
func (s *DB) DB() *sql.DB {
|
func (s *DB) DB() *sql.DB {
|
||||||
return s.db.(*sql.DB)
|
db, _ := s.db.(*sql.DB)
|
||||||
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
|
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
|
||||||
func (s *DB) CommonDB() sqlCommon {
|
func (s *DB) CommonDB() SQLCommon {
|
||||||
return s.db
|
return s.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -449,7 +454,7 @@ func (s *DB) Begin() *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
if db, ok := c.db.(sqlDb); ok {
|
if db, ok := c.db.(sqlDb); ok {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
c.db = interface{}(tx).(sqlCommon)
|
c.db = interface{}(tx).(SQLCommon)
|
||||||
c.AddError(err)
|
c.AddError(err)
|
||||||
} else {
|
} else {
|
||||||
c.AddError(ErrCantStartTransaction)
|
c.AddError(ErrCantStartTransaction)
|
||||||
|
|
4
scope.go
4
scope.go
|
@ -58,7 +58,7 @@ func (scope *Scope) NewDB() *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SQLDB return *sql.DB
|
// SQLDB return *sql.DB
|
||||||
func (scope *Scope) SQLDB() sqlCommon {
|
func (scope *Scope) SQLDB() SQLCommon {
|
||||||
return scope.db.db
|
return scope.db.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -391,7 +391,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
||||||
func (scope *Scope) Begin() *Scope {
|
func (scope *Scope) Begin() *Scope {
|
||||||
if db, ok := scope.SQLDB().(sqlDb); ok {
|
if db, ok := scope.SQLDB().(sqlDb); ok {
|
||||||
if tx, err := db.Begin(); err == nil {
|
if tx, err := db.Begin(); err == nil {
|
||||||
scope.db.db = interface{}(tx).(sqlCommon)
|
scope.db.db = interface{}(tx).(SQLCommon)
|
||||||
scope.InstanceSet("gorm:started_transaction", true)
|
scope.InstanceSet("gorm:started_transaction", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue