diff --git a/main.go b/main.go index 7ba904be..9ea10909 100644 --- a/main.go +++ b/main.go @@ -11,21 +11,23 @@ import ( // DB contains information for current db connection type DB struct { - Value interface{} - Error error - RowsAffected int64 - callbacks *Callback + Value interface{} + Error error + RowsAffected int64 + + // single db db sqlCommon - parent *DB - search *search + blockGlobalUpdate bool logMode int logger logger - dialect Dialect - singularTable bool - source string + search *search values map[string]interface{} - joinTableHandlers map[string]JoinTableHandler - blockGlobalUpdate bool + + // global db + parent *DB + callbacks *Callback + dialect Dialect + singularTable bool } // Open initialize a new db connection, need to import driver first, e.g: @@ -39,16 +41,13 @@ type DB struct { // // import _ "github.com/jinzhu/gorm/dialects/postgres" // // import _ "github.com/jinzhu/gorm/dialects/sqlite" // // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (*DB, error) { - var db DB - var err error - +func Open(dialect string, args ...interface{}) (db *DB, err error) { if len(args) == 0 { err = errors.New("invalid database source") return nil, err } var source string - var dbSQL sqlCommon + var dbSQL *sql.DB switch value := args[0].(type) { case string: @@ -60,44 +59,27 @@ func Open(dialect string, args ...interface{}) (*DB, error) { source = args[1].(string) } dbSQL, err = sql.Open(driver, source) - case sqlCommon: + case *sql.DB: source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() dbSQL = value } - db = DB{ - dialect: newDialect(dialect, dbSQL.(*sql.DB)), - logger: defaultLogger, - callbacks: DefaultCallback, - source: source, - values: map[string]interface{}{}, + db = &DB{ db: dbSQL, + logger: defaultLogger, + values: map[string]interface{}{}, + callbacks: DefaultCallback, + dialect: newDialect(dialect, dbSQL), } - db.parent = &db + db.parent = db if err == nil { - err = db.DB().Ping() // Send a ping to make sure the database connection is alive. - if err != nil { + // Send a ping to make sure the database connection is alive. + if err = db.DB().Ping(); err != nil { db.DB().Close() } } - - return &db, err -} - -// Close close current db connection -func (s *DB) Close() error { - return s.parent.db.(*sql.DB).Close() -} - -// DB get `*sql.DB` from current connection -func (s *DB) DB() *sql.DB { - return s.db.(*sql.DB) -} - -// Dialect get dialect -func (s *DB) Dialect() Dialect { - return s.parent.dialect + return } // New clone a new db connection without search conditions @@ -108,11 +90,17 @@ func (s *DB) New() *DB { return clone } -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} +// Close close current db connection +func (s *DB) Close() error { + if db, ok := s.parent.db.(*sql.DB); ok { + return db.Close() + } + return errors.New("can't close current db") +} + +// DB get `*sql.DB` from current connection +func (s *DB) DB() *sql.DB { + return s.db.(*sql.DB) } // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. @@ -120,6 +108,11 @@ func (s *DB) CommonDB() sqlCommon { return s.db } +// Dialect get dialect +func (s *DB) Dialect() Dialect { + return s.parent.dialect +} + // Callback return `Callbacks` container, you could add/change/delete callbacks with it // db.Callback().Create().Register("update_created_at", updateCreated) // Refer https://jinzhu.github.io/gorm/development.html#callbacks @@ -161,6 +154,13 @@ func (s *DB) SingularTable(enable bool) { s.parent.singularTable = enable } +// NewScope create a scope for current operation +func (s *DB) NewScope(value interface{}) *Scope { + dbClone := s.clone() + dbClone.Value = value + return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} +} + // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db @@ -691,7 +691,7 @@ func (s *DB) GetErrors() []error { } //////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.DB +// Private Methods For DB //////////////////////////////////////////////////////////////////////////////// func (s *DB) clone() *DB { @@ -721,7 +721,7 @@ func (s *DB) clone() *DB { } func (s *DB) print(v ...interface{}) { - s.logger.(logger).Print(v...) + s.logger.Print(v...) } func (s *DB) log(v ...interface{}) {