Reworked CurrentDatabase API to return the name instead of `*gorm.DB'.

This commit is contained in:
Jay Taylor 2015-08-11 08:59:59 -07:00
parent 70725f9d77
commit beeb040c62
9 changed files with 33 additions and 29 deletions

View File

@ -71,9 +71,8 @@ func (commonDialect) Quote(key string) string {
func (c commonDialect) HasTable(scope *Scope, tableName string) bool { func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
var ( var (
count int count int
databaseName string databaseName = c.CurrentDatabase(scope)
) )
c.CurrentDatabase(scope, &databaseName)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, databaseName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, databaseName).Row().Scan(&count)
return count > 0 return count > 0
} }
@ -81,9 +80,8 @@ func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var ( var (
count int count int
databaseName string databaseName = c.CurrentDatabase(scope)
) )
c.CurrentDatabase(scope, &databaseName)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count)
return count > 0 return count > 0
} }
@ -98,6 +96,7 @@ func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
} }
func (commonDialect) CurrentDatabase(scope *Scope, name *string) { func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(name)) scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
return
} }

View File

@ -17,7 +17,7 @@ type Dialect interface {
HasColumn(scope *Scope, tableName string, columnName string) bool HasColumn(scope *Scope, tableName string, columnName string) bool
HasIndex(scope *Scope, tableName string, indexName string) bool HasIndex(scope *Scope, tableName string, indexName string) bool
RemoveIndex(scope *Scope, indexName string) RemoveIndex(scope *Scope, indexName string)
CurrentDatabase(scope *Scope, name *string) CurrentDatabase(scope *Scope) string
} }
func NewDialect(driver string) Dialect { func NewDialect(driver string) Dialect {

View File

@ -77,6 +77,7 @@ func (foundation) HasIndex(scope *Scope, tableName string, indexName string) boo
return count > 0 return count > 0
} }
func (foundation) CurrentDatabase(scope *Scope, name *string) { func (foundation) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT CURRENT_SCHEMA").Row().Scan(name)) scope.Err(scope.NewDB().Raw("SELECT CURRENT_SCHEMA").Row().Scan(&name))
return
} }

10
main.go
View File

@ -429,10 +429,12 @@ func (s *DB) RemoveIndex(indexName string) *DB {
return scope.db return scope.db
} }
func (s *DB) CurrentDatabase(name *string) *DB { func (s *DB) CurrentDatabase() string {
scope := s.clone().NewScope(s.Value) var (
s.dialect.CurrentDatabase(scope, name) scope = s.clone().NewScope(s.Value)
return scope.db name = s.dialect.CurrentDatabase(scope)
)
return name
} }
/* /*

View File

@ -53,9 +53,8 @@ func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (s mssql) HasTable(scope *Scope, tableName string) bool { func (s mssql) HasTable(scope *Scope, tableName string) bool {
var ( var (
count int count int
databaseName string databaseName = s.CurrentDatabase(scope)
) )
s.CurrentDatabase(scope, &databaseName)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName).Row().Scan(&count)
return count > 0 return count > 0
} }
@ -63,9 +62,8 @@ func (s mssql) HasTable(scope *Scope, tableName string) bool {
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var ( var (
count int count int
databaseName string databaseName = s.CurrentDatabase(scope)
) )
s.CurrentDatabase(scope, &databaseName)
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count)
return count > 0 return count > 0
} }
@ -76,6 +74,7 @@ func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
return count > 0 return count > 0
} }
func (mssql) CurrentDatabase(scope *Scope, name *string) { func (mssql) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(name)) scope.Err(scope.NewDB().Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name))
return
} }

View File

@ -64,6 +64,7 @@ func (mysql) SelectFromDummyTable() string {
return "FROM DUAL" return "FROM DUAL"
} }
func (mysql) CurrentDatabase(scope *Scope, name *string) { func (mysql) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(name)) scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
return
} }

View File

@ -85,8 +85,9 @@ func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool
return count > 0 return count > 0
} }
func (postgres) CurrentDatabase(scope *Scope, name *string) { func (postgres) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT CURRENT_DATABASE()").Row().Scan(name)) scope.Err(scope.NewDB().Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name))
return
} }
var hstoreType = reflect.TypeOf(Hstore{}) var hstoreType = reflect.TypeOf(Hstore{})

View File

@ -582,12 +582,12 @@ func TestSelectWithArrayInput(t *testing.T) {
func TestCurrentDatabase(t *testing.T) { func TestCurrentDatabase(t *testing.T) {
DB.LogMode(true) DB.LogMode(true)
var name string databaseName := DB.CurrentDatabase()
if err := DB.CurrentDatabase(&name).Error; err != nil { if err := DB.Error; err != nil {
t.Errorf("Problem getting current db name: %s", err) t.Errorf("Problem getting current db name: %s", err)
} }
if name == "" { if databaseName == "" {
t.Errorf("Current db name returned empty; this should never happen!") t.Errorf("Current db name returned empty; this should never happen!")
} }
t.Logf("Got current db name: %v", name) t.Logf("Got current db name: %v", databaseName)
} }

View File

@ -62,7 +62,7 @@ func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
} }
func (sqlite3) CurrentDatabase(scope *Scope, name *string) { func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
var ( var (
ifaces = make([]interface{}, 3) ifaces = make([]interface{}, 3)
pointers = make([]*string, 3) pointers = make([]*string, 3)
@ -75,6 +75,7 @@ func (sqlite3) CurrentDatabase(scope *Scope, name *string) {
return return
} }
if pointers[1] != nil { if pointers[1] != nil {
*name = *pointers[1] name = *pointers[1]
} }
return
} }