Refactoring API for plugin system

This commit is contained in:
Jinzhu 2015-02-26 12:35:33 +08:00
parent 087b7083ad
commit ce72988e96
10 changed files with 50 additions and 43 deletions

View File

@ -56,7 +56,7 @@ func Create(scope *Scope) {
// execute create sql
if scope.Dialect().SupportLastInsertId() {
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err := result.LastInsertId()
if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected()
@ -67,10 +67,10 @@ func Create(scope *Scope) {
}
} else {
if primaryField == nil {
if results, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); err != nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil {
scope.db.RowsAffected, _ = results.RowsAffected()
}
} else if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
} else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
scope.db.RowsAffected = 1
}
}

View File

@ -40,7 +40,7 @@ func Query(scope *Scope) {
scope.prepareQuerySql()
if !scope.HasError() {
rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...)
rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
scope.db.RowsAffected = 0
if scope.Err(err) != nil {

View File

@ -90,7 +90,7 @@ func (s *commonDialect) HasTable(scope *Scope, tableName string) bool {
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v",
newScope.AddToVars(tableName),
newScope.AddToVars(s.databaseName(scope))))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}
@ -102,7 +102,7 @@ func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName str
newScope.AddToVars(tableName),
newScope.AddToVars(columnName),
))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

23
main.go
View File

@ -84,10 +84,25 @@ func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
}
func (s *DB) New() *DB {
clone := s.clone()
clone.search = nil
return clone
// NewScope create scope for callbacks, including DB's search information
func (db *DB) NewScope(value interface{}) *Scope {
dbClone := db.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
}
func (s *DB) FreshDB() *DB {
newDB := &DB{
dialect: s.dialect,
logger: s.logger,
callback: s.parent.callback.clone(),
source: s.source,
values: map[string]interface{}{},
db: s.db,
ModelStructs: map[reflect.Type]*ModelStruct{},
}
newDB.parent = newDB
return newDB
}
// CommonDB Return the underlying sql.DB or sql.Tx instance.

View File

@ -92,7 +92,7 @@ func (s *mssql) HasTable(scope *Scope, tableName string) bool {
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_catalog = %v",
newScope.AddToVars(tableName),
newScope.AddToVars(s.databaseName(scope))))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}
@ -104,7 +104,7 @@ func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) boo
newScope.AddToVars(tableName),
newScope.AddToVars(columnName),
))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

View File

@ -90,7 +90,7 @@ func (s *mysql) HasTable(scope *Scope, tableName string) bool {
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v",
newScope.AddToVars(tableName),
newScope.AddToVars(s.databaseName(scope))))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}
@ -102,7 +102,7 @@ func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) boo
newScope.AddToVars(tableName),
newScope.AddToVars(columnName),
))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

View File

@ -83,7 +83,7 @@ func (s *postgres) HasTable(scope *Scope, tableName string) bool {
var count int
newScope := scope.New(nil)
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v and table_type = 'BASE TABLE'", newScope.AddToVars(tableName)))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}
@ -94,7 +94,7 @@ func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string)
newScope.AddToVars(tableName),
newScope.AddToVars(columnName),
))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

View File

@ -33,13 +33,6 @@ func (scope *Scope) IndirectValue() reflect.Value {
return *scope.indirectValue
}
// NewScope create scope for callbacks, including DB's search information
func (db *DB) NewScope(value interface{}) *Scope {
dbClone := db.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
}
func (scope *Scope) NeedPtr() *Scope {
reflectKind := reflect.ValueOf(scope.Value).Kind()
if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
@ -52,16 +45,21 @@ func (scope *Scope) NeedPtr() *Scope {
// New create a new Scope without search information
func (scope *Scope) New(value interface{}) *Scope {
return &Scope{db: scope.db, Search: &search{}, Value: value}
return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
}
// NewDB create a new DB without search information
func (scope *Scope) NewDB() *DB {
return scope.db.New()
if scope.db != nil {
db := scope.db.clone()
db.search = nil
return db
}
return nil
}
// DB get *sql.DB
func (scope *Scope) DB() sqlCommon {
// SqlDB return *sql.DB
func (scope *Scope) SqlDB() sqlCommon {
return scope.db.db
}
@ -73,9 +71,8 @@ func (scope *Scope) SkipLeft() {
// Quote used to quote database column name according to database dialect
func (scope *Scope) Quote(str string) string {
if strings.Index(str, ".") != -1 {
strs := strings.Split(str, ".")
newStrs := []string{}
for _, str := range strs {
for _, str := range strings.Split(str, ".") {
newStrs = append(newStrs, scope.Dialect().Quote(str))
}
return strings.Join(newStrs, ".")
@ -176,13 +173,13 @@ func (scope *Scope) CallMethod(name string, checkError bool) {
case func(s *Scope):
f(scope)
case func(s *DB):
f(scope.db.New())
f(scope.NewDB())
case func() error:
scope.Err(f())
case func(s *Scope) error:
scope.Err(f(scope))
case func(s *DB) error:
scope.Err(f(scope.db.New()))
scope.Err(f(scope.NewDB()))
default:
scope.Err(fmt.Errorf("unsupported function %v", name))
}
@ -229,12 +226,7 @@ func (scope *Scope) QuotedTableName() string {
return scope.Search.TableName
}
keys := strings.Split(scope.TableName(), ".")
for i, v := range keys {
keys[i] = scope.Quote(v)
}
return strings.Join(keys, ".")
return scope.Quote(scope.TableName())
}
// CombinedConditionSql get combined condition sql
@ -263,7 +255,7 @@ func (scope *Scope) Exec() *Scope {
defer scope.Trace(NowFunc())
if !scope.HasError() {
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); err == nil {
scope.db.RowsAffected = count
}
@ -308,7 +300,7 @@ func (scope *Scope) Trace(t time.Time) {
// Begin start a transaction
func (scope *Scope) Begin() *Scope {
if db, ok := scope.DB().(sqlDb); ok {
if db, ok := scope.SqlDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon)
scope.InstanceSet("gorm:started_transaction", true)

View File

@ -338,13 +338,13 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
func (scope *Scope) row() *sql.Row {
defer scope.Trace(NowFunc())
scope.prepareQuerySql()
return scope.DB().QueryRow(scope.Sql, scope.SqlVars...)
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
}
func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.Trace(NowFunc())
scope.prepareQuerySql()
return scope.DB().Query(scope.Sql, scope.SqlVars...)
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
}
func (scope *Scope) initialize() *Scope {

View File

@ -70,13 +70,13 @@ func (s *sqlite3) Quote(key string) string {
func (s *sqlite3) HasTable(scope *Scope, tableName string) bool {
var count int
scope.DB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count)
scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count)
return count > 0
}
func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.DB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count)
scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count)
return count > 0
}