mirror of https://github.com/go-gorm/gorm.git
Refactoring API for plugin system
This commit is contained in:
parent
087b7083ad
commit
ce72988e96
|
@ -56,7 +56,7 @@ func Create(scope *Scope) {
|
|||
|
||||
// execute create sql
|
||||
if scope.Dialect().SupportLastInsertId() {
|
||||
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
id, err := result.LastInsertId()
|
||||
if scope.Err(err) == nil {
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
|
@ -67,10 +67,10 @@ func Create(scope *Scope) {
|
|||
}
|
||||
} else {
|
||||
if primaryField == nil {
|
||||
if results, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); err != nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil {
|
||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
||||
}
|
||||
} else if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
|
||||
} else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
|
||||
scope.db.RowsAffected = 1
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func Query(scope *Scope) {
|
|||
scope.prepareQuerySql()
|
||||
|
||||
if !scope.HasError() {
|
||||
rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...)
|
||||
rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
||||
scope.db.RowsAffected = 0
|
||||
|
||||
if scope.Err(err) != nil {
|
||||
|
|
|
@ -90,7 +90,7 @@ func (s *commonDialect) HasTable(scope *Scope, tableName string) bool {
|
|||
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v",
|
||||
newScope.AddToVars(tableName),
|
||||
newScope.AddToVars(s.databaseName(scope))))
|
||||
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
@ -102,7 +102,7 @@ func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName str
|
|||
newScope.AddToVars(tableName),
|
||||
newScope.AddToVars(columnName),
|
||||
))
|
||||
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
|
23
main.go
23
main.go
|
@ -84,10 +84,25 @@ func (s *DB) DB() *sql.DB {
|
|||
return s.db.(*sql.DB)
|
||||
}
|
||||
|
||||
func (s *DB) New() *DB {
|
||||
clone := s.clone()
|
||||
clone.search = nil
|
||||
return clone
|
||||
// NewScope create scope for callbacks, including DB's search information
|
||||
func (db *DB) NewScope(value interface{}) *Scope {
|
||||
dbClone := db.clone()
|
||||
dbClone.Value = value
|
||||
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
|
||||
}
|
||||
|
||||
func (s *DB) FreshDB() *DB {
|
||||
newDB := &DB{
|
||||
dialect: s.dialect,
|
||||
logger: s.logger,
|
||||
callback: s.parent.callback.clone(),
|
||||
source: s.source,
|
||||
values: map[string]interface{}{},
|
||||
db: s.db,
|
||||
ModelStructs: map[reflect.Type]*ModelStruct{},
|
||||
}
|
||||
newDB.parent = newDB
|
||||
return newDB
|
||||
}
|
||||
|
||||
// CommonDB Return the underlying sql.DB or sql.Tx instance.
|
||||
|
|
4
mssql.go
4
mssql.go
|
@ -92,7 +92,7 @@ func (s *mssql) HasTable(scope *Scope, tableName string) bool {
|
|||
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_catalog = %v",
|
||||
newScope.AddToVars(tableName),
|
||||
newScope.AddToVars(s.databaseName(scope))))
|
||||
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
@ -104,7 +104,7 @@ func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) boo
|
|||
newScope.AddToVars(tableName),
|
||||
newScope.AddToVars(columnName),
|
||||
))
|
||||
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
|
4
mysql.go
4
mysql.go
|
@ -90,7 +90,7 @@ func (s *mysql) HasTable(scope *Scope, tableName string) bool {
|
|||
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v",
|
||||
newScope.AddToVars(tableName),
|
||||
newScope.AddToVars(s.databaseName(scope))))
|
||||
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
@ -102,7 +102,7 @@ func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) boo
|
|||
newScope.AddToVars(tableName),
|
||||
newScope.AddToVars(columnName),
|
||||
))
|
||||
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ func (s *postgres) HasTable(scope *Scope, tableName string) bool {
|
|||
var count int
|
||||
newScope := scope.New(nil)
|
||||
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v and table_type = 'BASE TABLE'", newScope.AddToVars(tableName)))
|
||||
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
@ -94,7 +94,7 @@ func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string)
|
|||
newScope.AddToVars(tableName),
|
||||
newScope.AddToVars(columnName),
|
||||
))
|
||||
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
|
38
scope.go
38
scope.go
|
@ -33,13 +33,6 @@ func (scope *Scope) IndirectValue() reflect.Value {
|
|||
return *scope.indirectValue
|
||||
}
|
||||
|
||||
// NewScope create scope for callbacks, including DB's search information
|
||||
func (db *DB) NewScope(value interface{}) *Scope {
|
||||
dbClone := db.clone()
|
||||
dbClone.Value = value
|
||||
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
|
||||
}
|
||||
|
||||
func (scope *Scope) NeedPtr() *Scope {
|
||||
reflectKind := reflect.ValueOf(scope.Value).Kind()
|
||||
if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
|
||||
|
@ -52,16 +45,21 @@ func (scope *Scope) NeedPtr() *Scope {
|
|||
|
||||
// New create a new Scope without search information
|
||||
func (scope *Scope) New(value interface{}) *Scope {
|
||||
return &Scope{db: scope.db, Search: &search{}, Value: value}
|
||||
return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
|
||||
}
|
||||
|
||||
// NewDB create a new DB without search information
|
||||
func (scope *Scope) NewDB() *DB {
|
||||
return scope.db.New()
|
||||
if scope.db != nil {
|
||||
db := scope.db.clone()
|
||||
db.search = nil
|
||||
return db
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DB get *sql.DB
|
||||
func (scope *Scope) DB() sqlCommon {
|
||||
// SqlDB return *sql.DB
|
||||
func (scope *Scope) SqlDB() sqlCommon {
|
||||
return scope.db.db
|
||||
}
|
||||
|
||||
|
@ -73,9 +71,8 @@ func (scope *Scope) SkipLeft() {
|
|||
// Quote used to quote database column name according to database dialect
|
||||
func (scope *Scope) Quote(str string) string {
|
||||
if strings.Index(str, ".") != -1 {
|
||||
strs := strings.Split(str, ".")
|
||||
newStrs := []string{}
|
||||
for _, str := range strs {
|
||||
for _, str := range strings.Split(str, ".") {
|
||||
newStrs = append(newStrs, scope.Dialect().Quote(str))
|
||||
}
|
||||
return strings.Join(newStrs, ".")
|
||||
|
@ -176,13 +173,13 @@ func (scope *Scope) CallMethod(name string, checkError bool) {
|
|||
case func(s *Scope):
|
||||
f(scope)
|
||||
case func(s *DB):
|
||||
f(scope.db.New())
|
||||
f(scope.NewDB())
|
||||
case func() error:
|
||||
scope.Err(f())
|
||||
case func(s *Scope) error:
|
||||
scope.Err(f(scope))
|
||||
case func(s *DB) error:
|
||||
scope.Err(f(scope.db.New()))
|
||||
scope.Err(f(scope.NewDB()))
|
||||
default:
|
||||
scope.Err(fmt.Errorf("unsupported function %v", name))
|
||||
}
|
||||
|
@ -229,12 +226,7 @@ func (scope *Scope) QuotedTableName() string {
|
|||
return scope.Search.TableName
|
||||
}
|
||||
|
||||
keys := strings.Split(scope.TableName(), ".")
|
||||
for i, v := range keys {
|
||||
keys[i] = scope.Quote(v)
|
||||
}
|
||||
return strings.Join(keys, ".")
|
||||
|
||||
return scope.Quote(scope.TableName())
|
||||
}
|
||||
|
||||
// CombinedConditionSql get combined condition sql
|
||||
|
@ -263,7 +255,7 @@ func (scope *Scope) Exec() *Scope {
|
|||
defer scope.Trace(NowFunc())
|
||||
|
||||
if !scope.HasError() {
|
||||
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if count, err := result.RowsAffected(); err == nil {
|
||||
scope.db.RowsAffected = count
|
||||
}
|
||||
|
@ -308,7 +300,7 @@ func (scope *Scope) Trace(t time.Time) {
|
|||
|
||||
// Begin start a transaction
|
||||
func (scope *Scope) Begin() *Scope {
|
||||
if db, ok := scope.DB().(sqlDb); ok {
|
||||
if db, ok := scope.SqlDB().(sqlDb); ok {
|
||||
if tx, err := db.Begin(); err == nil {
|
||||
scope.db.db = interface{}(tx).(sqlCommon)
|
||||
scope.InstanceSet("gorm:started_transaction", true)
|
||||
|
|
|
@ -338,13 +338,13 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
|||
func (scope *Scope) row() *sql.Row {
|
||||
defer scope.Trace(NowFunc())
|
||||
scope.prepareQuerySql()
|
||||
return scope.DB().QueryRow(scope.Sql, scope.SqlVars...)
|
||||
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||
defer scope.Trace(NowFunc())
|
||||
scope.prepareQuerySql()
|
||||
return scope.DB().Query(scope.Sql, scope.SqlVars...)
|
||||
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) initialize() *Scope {
|
||||
|
|
|
@ -70,13 +70,13 @@ func (s *sqlite3) Quote(key string) string {
|
|||
|
||||
func (s *sqlite3) HasTable(scope *Scope, tableName string) bool {
|
||||
var count int
|
||||
scope.DB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count)
|
||||
scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||
var count int
|
||||
scope.DB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count)
|
||||
scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue