This commit is contained in:
Jinzhu 2017-02-05 13:31:31 +08:00
parent 23abd03a95
commit 1558522aaa
1 changed files with 50 additions and 50 deletions

100
main.go
View File

@ -11,21 +11,23 @@ import (
// DB contains information for current db connection
type DB struct {
Value interface{}
Error error
RowsAffected int64
callbacks *Callback
Value interface{}
Error error
RowsAffected int64
// single db
db sqlCommon
parent *DB
search *search
blockGlobalUpdate bool
logMode int
logger logger
dialect Dialect
singularTable bool
source string
search *search
values map[string]interface{}
joinTableHandlers map[string]JoinTableHandler
blockGlobalUpdate bool
// global db
parent *DB
callbacks *Callback
dialect Dialect
singularTable bool
}
// Open initialize a new db connection, need to import driver first, e.g:
@ -39,16 +41,13 @@ type DB struct {
// // import _ "github.com/jinzhu/gorm/dialects/postgres"
// // import _ "github.com/jinzhu/gorm/dialects/sqlite"
// // import _ "github.com/jinzhu/gorm/dialects/mssql"
func Open(dialect string, args ...interface{}) (*DB, error) {
var db DB
var err error
func Open(dialect string, args ...interface{}) (db *DB, err error) {
if len(args) == 0 {
err = errors.New("invalid database source")
return nil, err
}
var source string
var dbSQL sqlCommon
var dbSQL *sql.DB
switch value := args[0].(type) {
case string:
@ -60,44 +59,27 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
source = args[1].(string)
}
dbSQL, err = sql.Open(driver, source)
case sqlCommon:
case *sql.DB:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
dbSQL = value
}
db = DB{
dialect: newDialect(dialect, dbSQL.(*sql.DB)),
logger: defaultLogger,
callbacks: DefaultCallback,
source: source,
values: map[string]interface{}{},
db = &DB{
db: dbSQL,
logger: defaultLogger,
values: map[string]interface{}{},
callbacks: DefaultCallback,
dialect: newDialect(dialect, dbSQL),
}
db.parent = &db
db.parent = db
if err == nil {
err = db.DB().Ping() // Send a ping to make sure the database connection is alive.
if err != nil {
// Send a ping to make sure the database connection is alive.
if err = db.DB().Ping(); err != nil {
db.DB().Close()
}
}
return &db, err
}
// Close close current db connection
func (s *DB) Close() error {
return s.parent.db.(*sql.DB).Close()
}
// DB get `*sql.DB` from current connection
func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
}
// Dialect get dialect
func (s *DB) Dialect() Dialect {
return s.parent.dialect
return
}
// New clone a new db connection without search conditions
@ -108,11 +90,17 @@ func (s *DB) New() *DB {
return clone
}
// NewScope create a scope for current operation
func (s *DB) NewScope(value interface{}) *Scope {
dbClone := s.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
// Close close current db connection
func (s *DB) Close() error {
if db, ok := s.parent.db.(*sql.DB); ok {
return db.Close()
}
return errors.New("can't close current db")
}
// DB get `*sql.DB` from current connection
func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
}
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
@ -120,6 +108,11 @@ func (s *DB) CommonDB() sqlCommon {
return s.db
}
// Dialect get dialect
func (s *DB) Dialect() Dialect {
return s.parent.dialect
}
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
// db.Callback().Create().Register("update_created_at", updateCreated)
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
@ -161,6 +154,13 @@ func (s *DB) SingularTable(enable bool) {
s.parent.singularTable = enable
}
// NewScope create a scope for current operation
func (s *DB) NewScope(value interface{}) *Scope {
dbClone := s.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
}
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
return s.clone().search.Where(query, args...).db
@ -691,7 +691,7 @@ func (s *DB) GetErrors() []error {
}
////////////////////////////////////////////////////////////////////////////////
// Private Methods For *gorm.DB
// Private Methods For DB
////////////////////////////////////////////////////////////////////////////////
func (s *DB) clone() *DB {
@ -721,7 +721,7 @@ func (s *DB) clone() *DB {
}
func (s *DB) print(v ...interface{}) {
s.logger.(logger).Print(v...)
s.logger.Print(v...)
}
func (s *DB) log(v ...interface{}) {