From ce72988e964bdeedb03282dd398d1066eccba0b1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 26 Feb 2015 12:35:33 +0800 Subject: [PATCH] Refactoring API for plugin system --- callback_create.go | 6 +++--- callback_query.go | 2 +- common_dialect.go | 4 ++-- main.go | 23 +++++++++++++++++++---- mssql.go | 4 ++-- mysql.go | 4 ++-- postgres.go | 4 ++-- scope.go | 38 +++++++++++++++----------------------- scope_private.go | 4 ++-- sqlite3.go | 4 ++-- 10 files changed, 50 insertions(+), 43 deletions(-) diff --git a/callback_create.go b/callback_create.go index 6e32922f..1c41c03f 100644 --- a/callback_create.go +++ b/callback_create.go @@ -56,7 +56,7 @@ func Create(scope *Scope) { // execute create sql if scope.Dialect().SupportLastInsertId() { - if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { + if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { id, err := result.LastInsertId() if scope.Err(err) == nil { scope.db.RowsAffected, _ = result.RowsAffected() @@ -67,10 +67,10 @@ func Create(scope *Scope) { } } else { if primaryField == nil { - if results, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); err != nil { + if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil { scope.db.RowsAffected, _ = results.RowsAffected() } - } else if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil { + } else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil { scope.db.RowsAffected = 1 } } diff --git a/callback_query.go b/callback_query.go index b9ca58e2..e7e388b4 100644 --- a/callback_query.go +++ b/callback_query.go @@ -40,7 +40,7 @@ func Query(scope *Scope) { scope.prepareQuerySql() if !scope.HasError() { - rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...) + rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...) scope.db.RowsAffected = 0 if scope.Err(err) != nil { diff --git a/common_dialect.go b/common_dialect.go index 7013df06..e1a54c29 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -90,7 +90,7 @@ func (s *commonDialect) HasTable(scope *Scope, tableName string) bool { newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v", newScope.AddToVars(tableName), newScope.AddToVars(s.databaseName(scope)))) - newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) return count > 0 } @@ -102,7 +102,7 @@ func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName str newScope.AddToVars(tableName), newScope.AddToVars(columnName), )) - newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) return count > 0 } diff --git a/main.go b/main.go index 49eeec2d..f0787d20 100644 --- a/main.go +++ b/main.go @@ -84,10 +84,25 @@ func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } -func (s *DB) New() *DB { - clone := s.clone() - clone.search = nil - return clone +// NewScope create scope for callbacks, including DB's search information +func (db *DB) NewScope(value interface{}) *Scope { + dbClone := db.clone() + dbClone.Value = value + return &Scope{db: dbClone, Search: dbClone.search, Value: value} +} + +func (s *DB) FreshDB() *DB { + newDB := &DB{ + dialect: s.dialect, + logger: s.logger, + callback: s.parent.callback.clone(), + source: s.source, + values: map[string]interface{}{}, + db: s.db, + ModelStructs: map[reflect.Type]*ModelStruct{}, + } + newDB.parent = newDB + return newDB } // CommonDB Return the underlying sql.DB or sql.Tx instance. diff --git a/mssql.go b/mssql.go index 1d1562ea..30f9b824 100644 --- a/mssql.go +++ b/mssql.go @@ -92,7 +92,7 @@ func (s *mssql) HasTable(scope *Scope, tableName string) bool { newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_catalog = %v", newScope.AddToVars(tableName), newScope.AddToVars(s.databaseName(scope)))) - newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) return count > 0 } @@ -104,7 +104,7 @@ func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) boo newScope.AddToVars(tableName), newScope.AddToVars(columnName), )) - newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) return count > 0 } diff --git a/mysql.go b/mysql.go index 7d9758a1..ec575336 100644 --- a/mysql.go +++ b/mysql.go @@ -90,7 +90,7 @@ func (s *mysql) HasTable(scope *Scope, tableName string) bool { newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v", newScope.AddToVars(tableName), newScope.AddToVars(s.databaseName(scope)))) - newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) return count > 0 } @@ -102,7 +102,7 @@ func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) boo newScope.AddToVars(tableName), newScope.AddToVars(columnName), )) - newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) return count > 0 } diff --git a/postgres.go b/postgres.go index 3654ddd7..1ffbe4aa 100644 --- a/postgres.go +++ b/postgres.go @@ -83,7 +83,7 @@ func (s *postgres) HasTable(scope *Scope, tableName string) bool { var count int newScope := scope.New(nil) newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v and table_type = 'BASE TABLE'", newScope.AddToVars(tableName))) - newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) return count > 0 } @@ -94,7 +94,7 @@ func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) newScope.AddToVars(tableName), newScope.AddToVars(columnName), )) - newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) return count > 0 } diff --git a/scope.go b/scope.go index 1191a637..f88bf6dd 100644 --- a/scope.go +++ b/scope.go @@ -33,13 +33,6 @@ func (scope *Scope) IndirectValue() reflect.Value { return *scope.indirectValue } -// NewScope create scope for callbacks, including DB's search information -func (db *DB) NewScope(value interface{}) *Scope { - dbClone := db.clone() - dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search, Value: value} -} - func (scope *Scope) NeedPtr() *Scope { reflectKind := reflect.ValueOf(scope.Value).Kind() if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) { @@ -52,16 +45,21 @@ func (scope *Scope) NeedPtr() *Scope { // New create a new Scope without search information func (scope *Scope) New(value interface{}) *Scope { - return &Scope{db: scope.db, Search: &search{}, Value: value} + return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} } // NewDB create a new DB without search information func (scope *Scope) NewDB() *DB { - return scope.db.New() + if scope.db != nil { + db := scope.db.clone() + db.search = nil + return db + } + return nil } -// DB get *sql.DB -func (scope *Scope) DB() sqlCommon { +// SqlDB return *sql.DB +func (scope *Scope) SqlDB() sqlCommon { return scope.db.db } @@ -73,9 +71,8 @@ func (scope *Scope) SkipLeft() { // Quote used to quote database column name according to database dialect func (scope *Scope) Quote(str string) string { if strings.Index(str, ".") != -1 { - strs := strings.Split(str, ".") newStrs := []string{} - for _, str := range strs { + for _, str := range strings.Split(str, ".") { newStrs = append(newStrs, scope.Dialect().Quote(str)) } return strings.Join(newStrs, ".") @@ -176,13 +173,13 @@ func (scope *Scope) CallMethod(name string, checkError bool) { case func(s *Scope): f(scope) case func(s *DB): - f(scope.db.New()) + f(scope.NewDB()) case func() error: scope.Err(f()) case func(s *Scope) error: scope.Err(f(scope)) case func(s *DB) error: - scope.Err(f(scope.db.New())) + scope.Err(f(scope.NewDB())) default: scope.Err(fmt.Errorf("unsupported function %v", name)) } @@ -229,12 +226,7 @@ func (scope *Scope) QuotedTableName() string { return scope.Search.TableName } - keys := strings.Split(scope.TableName(), ".") - for i, v := range keys { - keys[i] = scope.Quote(v) - } - return strings.Join(keys, ".") - + return scope.Quote(scope.TableName()) } // CombinedConditionSql get combined condition sql @@ -263,7 +255,7 @@ func (scope *Scope) Exec() *Scope { defer scope.Trace(NowFunc()) if !scope.HasError() { - if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { + if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { if count, err := result.RowsAffected(); err == nil { scope.db.RowsAffected = count } @@ -308,7 +300,7 @@ func (scope *Scope) Trace(t time.Time) { // Begin start a transaction func (scope *Scope) Begin() *Scope { - if db, ok := scope.DB().(sqlDb); ok { + if db, ok := scope.SqlDB().(sqlDb); ok { if tx, err := db.Begin(); err == nil { scope.db.db = interface{}(tx).(sqlCommon) scope.InstanceSet("gorm:started_transaction", true) diff --git a/scope_private.go b/scope_private.go index 6636ee83..e831bf4f 100644 --- a/scope_private.go +++ b/scope_private.go @@ -338,13 +338,13 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore func (scope *Scope) row() *sql.Row { defer scope.Trace(NowFunc()) scope.prepareQuerySql() - return scope.DB().QueryRow(scope.Sql, scope.SqlVars...) + return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...) } func (scope *Scope) rows() (*sql.Rows, error) { defer scope.Trace(NowFunc()) scope.prepareQuerySql() - return scope.DB().Query(scope.Sql, scope.SqlVars...) + return scope.SqlDB().Query(scope.Sql, scope.SqlVars...) } func (scope *Scope) initialize() *Scope { diff --git a/sqlite3.go b/sqlite3.go index c92e2cdb..7eb071d0 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -70,13 +70,13 @@ func (s *sqlite3) Quote(key string) string { func (s *sqlite3) HasTable(scope *Scope, tableName string) bool { var count int - scope.DB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count) + scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count) return count > 0 } func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - scope.DB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count) + scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count) return count > 0 }