Refactoring API for plugin system

This commit is contained in:
Jinzhu 2015-02-26 12:35:33 +08:00
parent 087b7083ad
commit ce72988e96
10 changed files with 50 additions and 43 deletions

View File

@ -56,7 +56,7 @@ func Create(scope *Scope) {
// execute create sql // execute create sql
if scope.Dialect().SupportLastInsertId() { if scope.Dialect().SupportLastInsertId() {
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err := result.LastInsertId() id, err := result.LastInsertId()
if scope.Err(err) == nil { if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected() scope.db.RowsAffected, _ = result.RowsAffected()
@ -67,10 +67,10 @@ func Create(scope *Scope) {
} }
} else { } else {
if primaryField == nil { if primaryField == nil {
if results, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); err != nil { if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil {
scope.db.RowsAffected, _ = results.RowsAffected() scope.db.RowsAffected, _ = results.RowsAffected()
} }
} else if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil { } else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
scope.db.RowsAffected = 1 scope.db.RowsAffected = 1
} }
} }

View File

@ -40,7 +40,7 @@ func Query(scope *Scope) {
scope.prepareQuerySql() scope.prepareQuerySql()
if !scope.HasError() { if !scope.HasError() {
rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...) rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
scope.db.RowsAffected = 0 scope.db.RowsAffected = 0
if scope.Err(err) != nil { if scope.Err(err) != nil {

View File

@ -90,7 +90,7 @@ func (s *commonDialect) HasTable(scope *Scope, tableName string) bool {
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v", newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v",
newScope.AddToVars(tableName), newScope.AddToVars(tableName),
newScope.AddToVars(s.databaseName(scope)))) newScope.AddToVars(s.databaseName(scope))))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0 return count > 0
} }
@ -102,7 +102,7 @@ func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName str
newScope.AddToVars(tableName), newScope.AddToVars(tableName),
newScope.AddToVars(columnName), newScope.AddToVars(columnName),
)) ))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0 return count > 0
} }

23
main.go
View File

@ -84,10 +84,25 @@ func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB) return s.db.(*sql.DB)
} }
func (s *DB) New() *DB { // NewScope create scope for callbacks, including DB's search information
clone := s.clone() func (db *DB) NewScope(value interface{}) *Scope {
clone.search = nil dbClone := db.clone()
return clone dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
}
func (s *DB) FreshDB() *DB {
newDB := &DB{
dialect: s.dialect,
logger: s.logger,
callback: s.parent.callback.clone(),
source: s.source,
values: map[string]interface{}{},
db: s.db,
ModelStructs: map[reflect.Type]*ModelStruct{},
}
newDB.parent = newDB
return newDB
} }
// CommonDB Return the underlying sql.DB or sql.Tx instance. // CommonDB Return the underlying sql.DB or sql.Tx instance.

View File

@ -92,7 +92,7 @@ func (s *mssql) HasTable(scope *Scope, tableName string) bool {
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_catalog = %v", newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_catalog = %v",
newScope.AddToVars(tableName), newScope.AddToVars(tableName),
newScope.AddToVars(s.databaseName(scope)))) newScope.AddToVars(s.databaseName(scope))))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0 return count > 0
} }
@ -104,7 +104,7 @@ func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) boo
newScope.AddToVars(tableName), newScope.AddToVars(tableName),
newScope.AddToVars(columnName), newScope.AddToVars(columnName),
)) ))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0 return count > 0
} }

View File

@ -90,7 +90,7 @@ func (s *mysql) HasTable(scope *Scope, tableName string) bool {
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v", newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v",
newScope.AddToVars(tableName), newScope.AddToVars(tableName),
newScope.AddToVars(s.databaseName(scope)))) newScope.AddToVars(s.databaseName(scope))))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0 return count > 0
} }
@ -102,7 +102,7 @@ func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) boo
newScope.AddToVars(tableName), newScope.AddToVars(tableName),
newScope.AddToVars(columnName), newScope.AddToVars(columnName),
)) ))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0 return count > 0
} }

View File

@ -83,7 +83,7 @@ func (s *postgres) HasTable(scope *Scope, tableName string) bool {
var count int var count int
newScope := scope.New(nil) newScope := scope.New(nil)
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v and table_type = 'BASE TABLE'", newScope.AddToVars(tableName))) newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v and table_type = 'BASE TABLE'", newScope.AddToVars(tableName)))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0 return count > 0
} }
@ -94,7 +94,7 @@ func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string)
newScope.AddToVars(tableName), newScope.AddToVars(tableName),
newScope.AddToVars(columnName), newScope.AddToVars(columnName),
)) ))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0 return count > 0
} }

View File

@ -33,13 +33,6 @@ func (scope *Scope) IndirectValue() reflect.Value {
return *scope.indirectValue return *scope.indirectValue
} }
// NewScope create scope for callbacks, including DB's search information
func (db *DB) NewScope(value interface{}) *Scope {
dbClone := db.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
}
func (scope *Scope) NeedPtr() *Scope { func (scope *Scope) NeedPtr() *Scope {
reflectKind := reflect.ValueOf(scope.Value).Kind() reflectKind := reflect.ValueOf(scope.Value).Kind()
if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) { if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
@ -52,16 +45,21 @@ func (scope *Scope) NeedPtr() *Scope {
// New create a new Scope without search information // New create a new Scope without search information
func (scope *Scope) New(value interface{}) *Scope { func (scope *Scope) New(value interface{}) *Scope {
return &Scope{db: scope.db, Search: &search{}, Value: value} return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
} }
// NewDB create a new DB without search information // NewDB create a new DB without search information
func (scope *Scope) NewDB() *DB { func (scope *Scope) NewDB() *DB {
return scope.db.New() if scope.db != nil {
db := scope.db.clone()
db.search = nil
return db
}
return nil
} }
// DB get *sql.DB // SqlDB return *sql.DB
func (scope *Scope) DB() sqlCommon { func (scope *Scope) SqlDB() sqlCommon {
return scope.db.db return scope.db.db
} }
@ -73,9 +71,8 @@ func (scope *Scope) SkipLeft() {
// Quote used to quote database column name according to database dialect // Quote used to quote database column name according to database dialect
func (scope *Scope) Quote(str string) string { func (scope *Scope) Quote(str string) string {
if strings.Index(str, ".") != -1 { if strings.Index(str, ".") != -1 {
strs := strings.Split(str, ".")
newStrs := []string{} newStrs := []string{}
for _, str := range strs { for _, str := range strings.Split(str, ".") {
newStrs = append(newStrs, scope.Dialect().Quote(str)) newStrs = append(newStrs, scope.Dialect().Quote(str))
} }
return strings.Join(newStrs, ".") return strings.Join(newStrs, ".")
@ -176,13 +173,13 @@ func (scope *Scope) CallMethod(name string, checkError bool) {
case func(s *Scope): case func(s *Scope):
f(scope) f(scope)
case func(s *DB): case func(s *DB):
f(scope.db.New()) f(scope.NewDB())
case func() error: case func() error:
scope.Err(f()) scope.Err(f())
case func(s *Scope) error: case func(s *Scope) error:
scope.Err(f(scope)) scope.Err(f(scope))
case func(s *DB) error: case func(s *DB) error:
scope.Err(f(scope.db.New())) scope.Err(f(scope.NewDB()))
default: default:
scope.Err(fmt.Errorf("unsupported function %v", name)) scope.Err(fmt.Errorf("unsupported function %v", name))
} }
@ -229,12 +226,7 @@ func (scope *Scope) QuotedTableName() string {
return scope.Search.TableName return scope.Search.TableName
} }
keys := strings.Split(scope.TableName(), ".") return scope.Quote(scope.TableName())
for i, v := range keys {
keys[i] = scope.Quote(v)
}
return strings.Join(keys, ".")
} }
// CombinedConditionSql get combined condition sql // CombinedConditionSql get combined condition sql
@ -263,7 +255,7 @@ func (scope *Scope) Exec() *Scope {
defer scope.Trace(NowFunc()) defer scope.Trace(NowFunc())
if !scope.HasError() { if !scope.HasError() {
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); err == nil { if count, err := result.RowsAffected(); err == nil {
scope.db.RowsAffected = count scope.db.RowsAffected = count
} }
@ -308,7 +300,7 @@ func (scope *Scope) Trace(t time.Time) {
// Begin start a transaction // Begin start a transaction
func (scope *Scope) Begin() *Scope { func (scope *Scope) Begin() *Scope {
if db, ok := scope.DB().(sqlDb); ok { if db, ok := scope.SqlDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil { if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon) scope.db.db = interface{}(tx).(sqlCommon)
scope.InstanceSet("gorm:started_transaction", true) scope.InstanceSet("gorm:started_transaction", true)

View File

@ -338,13 +338,13 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
func (scope *Scope) row() *sql.Row { func (scope *Scope) row() *sql.Row {
defer scope.Trace(NowFunc()) defer scope.Trace(NowFunc())
scope.prepareQuerySql() scope.prepareQuerySql()
return scope.DB().QueryRow(scope.Sql, scope.SqlVars...) return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
} }
func (scope *Scope) rows() (*sql.Rows, error) { func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.Trace(NowFunc()) defer scope.Trace(NowFunc())
scope.prepareQuerySql() scope.prepareQuerySql()
return scope.DB().Query(scope.Sql, scope.SqlVars...) return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
} }
func (scope *Scope) initialize() *Scope { func (scope *Scope) initialize() *Scope {

View File

@ -70,13 +70,13 @@ func (s *sqlite3) Quote(key string) string {
func (s *sqlite3) HasTable(scope *Scope, tableName string) bool { func (s *sqlite3) HasTable(scope *Scope, tableName string) bool {
var count int var count int
scope.DB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count) scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count)
return count > 0 return count > 0
} }
func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int var count int
scope.DB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count) scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count)
return count > 0 return count > 0
} }