This commit is contained in:
Jinzhu 2017-02-05 13:31:31 +08:00
parent 23abd03a95
commit 1558522aaa
1 changed files with 50 additions and 50 deletions

100
main.go
View File

@ -11,21 +11,23 @@ import (
// DB contains information for current db connection // DB contains information for current db connection
type DB struct { type DB struct {
Value interface{} Value interface{}
Error error Error error
RowsAffected int64 RowsAffected int64
callbacks *Callback
// single db
db sqlCommon db sqlCommon
parent *DB blockGlobalUpdate bool
search *search
logMode int logMode int
logger logger logger logger
dialect Dialect search *search
singularTable bool
source string
values map[string]interface{} values map[string]interface{}
joinTableHandlers map[string]JoinTableHandler
blockGlobalUpdate bool // global db
parent *DB
callbacks *Callback
dialect Dialect
singularTable bool
} }
// Open initialize a new db connection, need to import driver first, e.g: // Open initialize a new db connection, need to import driver first, e.g:
@ -39,16 +41,13 @@ type DB struct {
// // import _ "github.com/jinzhu/gorm/dialects/postgres" // // import _ "github.com/jinzhu/gorm/dialects/postgres"
// // import _ "github.com/jinzhu/gorm/dialects/sqlite" // // import _ "github.com/jinzhu/gorm/dialects/sqlite"
// // import _ "github.com/jinzhu/gorm/dialects/mssql" // // import _ "github.com/jinzhu/gorm/dialects/mssql"
func Open(dialect string, args ...interface{}) (*DB, error) { func Open(dialect string, args ...interface{}) (db *DB, err error) {
var db DB
var err error
if len(args) == 0 { if len(args) == 0 {
err = errors.New("invalid database source") err = errors.New("invalid database source")
return nil, err return nil, err
} }
var source string var source string
var dbSQL sqlCommon var dbSQL *sql.DB
switch value := args[0].(type) { switch value := args[0].(type) {
case string: case string:
@ -60,44 +59,27 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
source = args[1].(string) source = args[1].(string)
} }
dbSQL, err = sql.Open(driver, source) dbSQL, err = sql.Open(driver, source)
case sqlCommon: case *sql.DB:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
dbSQL = value dbSQL = value
} }
db = DB{ db = &DB{
dialect: newDialect(dialect, dbSQL.(*sql.DB)),
logger: defaultLogger,
callbacks: DefaultCallback,
source: source,
values: map[string]interface{}{},
db: dbSQL, db: dbSQL,
logger: defaultLogger,
values: map[string]interface{}{},
callbacks: DefaultCallback,
dialect: newDialect(dialect, dbSQL),
} }
db.parent = &db db.parent = db
if err == nil { if err == nil {
err = db.DB().Ping() // Send a ping to make sure the database connection is alive. // Send a ping to make sure the database connection is alive.
if err != nil { if err = db.DB().Ping(); err != nil {
db.DB().Close() db.DB().Close()
} }
} }
return
return &db, err
}
// Close close current db connection
func (s *DB) Close() error {
return s.parent.db.(*sql.DB).Close()
}
// DB get `*sql.DB` from current connection
func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
}
// Dialect get dialect
func (s *DB) Dialect() Dialect {
return s.parent.dialect
} }
// New clone a new db connection without search conditions // New clone a new db connection without search conditions
@ -108,11 +90,17 @@ func (s *DB) New() *DB {
return clone return clone
} }
// NewScope create a scope for current operation // Close close current db connection
func (s *DB) NewScope(value interface{}) *Scope { func (s *DB) Close() error {
dbClone := s.clone() if db, ok := s.parent.db.(*sql.DB); ok {
dbClone.Value = value return db.Close()
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} }
return errors.New("can't close current db")
}
// DB get `*sql.DB` from current connection
func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
} }
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
@ -120,6 +108,11 @@ func (s *DB) CommonDB() sqlCommon {
return s.db return s.db
} }
// Dialect get dialect
func (s *DB) Dialect() Dialect {
return s.parent.dialect
}
// Callback return `Callbacks` container, you could add/change/delete callbacks with it // Callback return `Callbacks` container, you could add/change/delete callbacks with it
// db.Callback().Create().Register("update_created_at", updateCreated) // db.Callback().Create().Register("update_created_at", updateCreated)
// Refer https://jinzhu.github.io/gorm/development.html#callbacks // Refer https://jinzhu.github.io/gorm/development.html#callbacks
@ -161,6 +154,13 @@ func (s *DB) SingularTable(enable bool) {
s.parent.singularTable = enable s.parent.singularTable = enable
} }
// NewScope create a scope for current operation
func (s *DB) NewScope(value interface{}) *Scope {
dbClone := s.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
}
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
func (s *DB) Where(query interface{}, args ...interface{}) *DB { func (s *DB) Where(query interface{}, args ...interface{}) *DB {
return s.clone().search.Where(query, args...).db return s.clone().search.Where(query, args...).db
@ -691,7 +691,7 @@ func (s *DB) GetErrors() []error {
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Private Methods For *gorm.DB // Private Methods For DB
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
func (s *DB) clone() *DB { func (s *DB) clone() *DB {
@ -721,7 +721,7 @@ func (s *DB) clone() *DB {
} }
func (s *DB) print(v ...interface{}) { func (s *DB) print(v ...interface{}) {
s.logger.(logger).Print(v...) s.logger.Print(v...)
} }
func (s *DB) log(v ...interface{}) { func (s *DB) log(v ...interface{}) {