diff --git a/common_dialect.go b/common_dialect.go index 956121ff..19708e50 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -71,9 +71,8 @@ func (commonDialect) Quote(key string) string { func (c commonDialect) HasTable(scope *Scope, tableName string) bool { var ( count int - databaseName string + databaseName = c.CurrentDatabase(scope) ) - c.CurrentDatabase(scope, &databaseName) c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, databaseName) return count > 0 } @@ -81,9 +80,8 @@ func (c commonDialect) HasTable(scope *Scope, tableName string) bool { func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { var ( count int - databaseName string + databaseName = c.CurrentDatabase(scope) ) - c.CurrentDatabase(scope, &databaseName) c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) return count > 0 } @@ -110,6 +108,7 @@ func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) } -func (commonDialect) CurrentDatabase(scope *Scope, name *string) { - scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(name)) +func (commonDialect) CurrentDatabase(scope *Scope) (name string) { + scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name)) + return } diff --git a/dialect.go b/dialect.go index 5d17b545..926f8a11 100644 --- a/dialect.go +++ b/dialect.go @@ -17,7 +17,7 @@ type Dialect interface { HasColumn(scope *Scope, tableName string, columnName string) bool HasIndex(scope *Scope, tableName string, indexName string) bool RemoveIndex(scope *Scope, indexName string) - CurrentDatabase(scope *Scope, name *string) + CurrentDatabase(scope *Scope) string } func NewDialect(driver string) Dialect { diff --git a/foundation.go b/foundation.go index 0413360e..07968ad3 100644 --- a/foundation.go +++ b/foundation.go @@ -77,6 +77,7 @@ func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) b return count > 0 } -func (s foundation) CurrentDatabase(scope *Scope, name *string) { - s.RawScanString(scope, name, "SELECT CURRENT_SCHEMA") +func (s foundation) CurrentDatabase(scope *Scope) (name string) { + s.RawScanString(scope, &name, "SELECT CURRENT_SCHEMA") + return } diff --git a/main.go b/main.go index 7e58e592..df3c29b4 100644 --- a/main.go +++ b/main.go @@ -434,10 +434,12 @@ func (s *DB) RemoveIndex(indexName string) *DB { return scope.db } -func (s *DB) CurrentDatabase(name *string) *DB { - scope := s.clone().NewScope(s.Value) - s.dialect.CurrentDatabase(scope, name) - return scope.db +func (s *DB) CurrentDatabase() string { + var ( + scope = s.clone().NewScope(s.Value) + name = s.dialect.CurrentDatabase(scope) + ) + return name } /* diff --git a/mssql.go b/mssql.go index 49d17fed..a9bd1e52 100644 --- a/mssql.go +++ b/mssql.go @@ -53,9 +53,8 @@ func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { func (s mssql) HasTable(scope *Scope, tableName string) bool { var ( count int - databaseName string + databaseName = s.CurrentDatabase(scope) ) - s.CurrentDatabase(scope, &databaseName) s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName) return count > 0 } @@ -63,9 +62,8 @@ func (s mssql) HasTable(scope *Scope, tableName string) bool { func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { var ( count int - databaseName string + databaseName = s.CurrentDatabase(scope) ) - s.CurrentDatabase(scope, &databaseName) s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) return count > 0 } @@ -76,6 +74,7 @@ func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { return count > 0 } -func (s mssql) CurrentDatabase(scope *Scope, name *string) { - s.RawScanString(scope, name, "SELECT DB_NAME() AS [Current Database]") +func (s mssql) CurrentDatabase(scope *Scope) (name string) { + s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]") + return } diff --git a/mysql.go b/mysql.go index 7fadb5af..9e1d56d3 100644 --- a/mysql.go +++ b/mysql.go @@ -64,6 +64,7 @@ func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } -func (s mysql) CurrentDatabase(scope *Scope, name *string) { - s.RawScanString(scope, name, "SELECT DATABASE()") +func (s mysql) CurrentDatabase(scope *Scope) (name string) { + s.RawScanString(scope, &name, "SELECT DATABASE()") + return } diff --git a/postgres.go b/postgres.go index 9d1d776c..357f9842 100644 --- a/postgres.go +++ b/postgres.go @@ -85,8 +85,9 @@ func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) boo return count > 0 } -func (s postgres) CurrentDatabase(scope *Scope, name *string) { - s.RawScanString(scope, name, "SELECT CURRENT_DATABASE()") +func (s postgres) CurrentDatabase(scope *Scope) (name string) { + s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()") + return } var hstoreType = reflect.TypeOf(Hstore{}) diff --git a/query_test.go b/query_test.go index bd13ec18..ff022dbb 100644 --- a/query_test.go +++ b/query_test.go @@ -582,12 +582,12 @@ func TestSelectWithArrayInput(t *testing.T) { func TestCurrentDatabase(t *testing.T) { DB.LogMode(true) - var name string - if err := DB.CurrentDatabase(&name).Error; err != nil { + databaseName := DB.CurrentDatabase() + if err := DB.Error; err != nil { t.Errorf("Problem getting current db name: %s", err) } - if name == "" { + if databaseName == "" { t.Errorf("Current db name returned empty; this should never happen!") } - t.Logf("Got current db name: %v", name) + t.Logf("Got current db name: %v", databaseName) } diff --git a/sqlite3.go b/sqlite3.go index 96183b77..4584642e 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -62,7 +62,7 @@ func (sqlite3) RemoveIndex(scope *Scope, indexName string) { scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) } -func (sqlite3) CurrentDatabase(scope *Scope, name *string) { +func (sqlite3) CurrentDatabase(scope *Scope) (name string) { var ( ifaces = make([]interface{}, 3) pointers = make([]*string, 3) @@ -75,6 +75,7 @@ func (sqlite3) CurrentDatabase(scope *Scope, name *string) { return } if pointers[1] != nil { - *name = *pointers[1] + name = *pointers[1] } + return }