From 70725f9d774ea9436916c98b4dc1b7dfbe6f43fc Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Thu, 6 Aug 2015 12:37:26 -0700 Subject: [PATCH 1/2] `CurrentDatabase' determines current dbname by querying the database. Preserves the gorm-style query API. --- common_dialect.go | 30 ++++++++++++++++-------------- dialect.go | 1 + foundation.go | 4 ++++ main.go | 6 ++++++ mssql.go | 32 ++++++++++++++++---------------- mysql.go | 4 ++++ postgres.go | 4 ++++ query_test.go | 12 ++++++++++++ sqlite3.go | 17 +++++++++++++++++ 9 files changed, 80 insertions(+), 30 deletions(-) diff --git a/common_dialect.go b/common_dialect.go index 3b646869..d7546ede 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -3,7 +3,6 @@ package gorm import ( "fmt" "reflect" - "strings" "time" ) @@ -69,24 +68,23 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } -func (commonDialect) databaseName(scope *Scope) string { - from := strings.LastIndex(scope.db.parent.source, "/") + 1 - to := strings.LastIndex(scope.db.parent.source, "?") - if to == -1 { - to = len(scope.db.parent.source) - } - return scope.db.parent.source[from:to] -} - func (c commonDialect) HasTable(scope *Scope, tableName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count) + var ( + count int + databaseName string + ) + c.CurrentDatabase(scope, &databaseName) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, databaseName).Row().Scan(&count) return count > 0 } func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count) + var ( + count int + databaseName string + ) + c.CurrentDatabase(scope, &databaseName) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count) return count > 0 } @@ -99,3 +97,7 @@ func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) func (commonDialect) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } + +func (commonDialect) CurrentDatabase(scope *Scope, name *string) { + scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(name)) +} diff --git a/dialect.go b/dialect.go index f3221075..5d17b545 100644 --- a/dialect.go +++ b/dialect.go @@ -17,6 +17,7 @@ type Dialect interface { HasColumn(scope *Scope, tableName string, columnName string) bool HasIndex(scope *Scope, tableName string, indexName string) bool RemoveIndex(scope *Scope, indexName string) + CurrentDatabase(scope *Scope, name *string) } func NewDialect(driver string) Dialect { diff --git a/foundation.go b/foundation.go index a9c8f500..8f68c720 100644 --- a/foundation.go +++ b/foundation.go @@ -76,3 +76,7 @@ func (foundation) HasIndex(scope *Scope, tableName string, indexName string) boo scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) return count > 0 } + +func (foundation) CurrentDatabase(scope *Scope, name *string) { + scope.Err(scope.NewDB().Raw("SELECT CURRENT_SCHEMA").Row().Scan(name)) +} diff --git a/main.go b/main.go index 30802205..7005f6b5 100644 --- a/main.go +++ b/main.go @@ -429,6 +429,12 @@ func (s *DB) RemoveIndex(indexName string) *DB { return scope.db } +func (s *DB) CurrentDatabase(name *string) *DB { + scope := s.clone().NewScope(s.Value) + s.dialect.CurrentDatabase(scope, name) + return scope.db +} + /* Add foreign key to the given scope diff --git a/mssql.go b/mssql.go index c44541c7..b2633292 100644 --- a/mssql.go +++ b/mssql.go @@ -3,7 +3,6 @@ package gorm import ( "fmt" "reflect" - "strings" "time" ) @@ -51,26 +50,23 @@ func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) } -func (mssql) databaseName(scope *Scope) string { - dbStr := strings.Split(scope.db.parent.source, ";") - for _, value := range dbStr { - s := strings.Split(value, "=") - if s[0] == "database" { - return s[1] - } - } - return "" -} - func (s mssql) HasTable(scope *Scope, tableName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count) + var ( + count int + databaseName string + ) + s.CurrentDatabase(scope, &databaseName) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName).Row().Scan(&count) return count > 0 } func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) + var ( + count int + databaseName string + ) + s.CurrentDatabase(scope, &databaseName) + scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count) return count > 0 } @@ -79,3 +75,7 @@ func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count) return count > 0 } + +func (mssql) CurrentDatabase(scope *Scope, name *string) { + scope.Err(scope.NewDB().Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(name)) +} diff --git a/mysql.go b/mysql.go index a5e4a459..025f87a5 100644 --- a/mysql.go +++ b/mysql.go @@ -63,3 +63,7 @@ func (mysql) Quote(key string) string { func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } + +func (mysql) CurrentDatabase(scope *Scope, name *string) { + scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(name)) +} diff --git a/postgres.go b/postgres.go index 4218e1ba..7374af9c 100644 --- a/postgres.go +++ b/postgres.go @@ -85,6 +85,10 @@ func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool return count > 0 } +func (postgres) CurrentDatabase(scope *Scope, name *string) { + scope.Err(scope.NewDB().Raw("SELECT CURRENT_DATABASE()").Row().Scan(name)) +} + var hstoreType = reflect.TypeOf(Hstore{}) type Hstore map[string]*string diff --git a/query_test.go b/query_test.go index 580d06c4..bd13ec18 100644 --- a/query_test.go +++ b/query_test.go @@ -579,3 +579,15 @@ func TestSelectWithArrayInput(t *testing.T) { t.Errorf("Should have selected both age and name") } } + +func TestCurrentDatabase(t *testing.T) { + DB.LogMode(true) + var name string + if err := DB.CurrentDatabase(&name).Error; err != nil { + t.Errorf("Problem getting current db name: %s", err) + } + if name == "" { + t.Errorf("Current db name returned empty; this should never happen!") + } + t.Logf("Got current db name: %v", name) +} diff --git a/sqlite3.go b/sqlite3.go index afe70e3a..a73d2379 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -61,3 +61,20 @@ func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { func (sqlite3) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) } + +func (sqlite3) CurrentDatabase(scope *Scope, name *string) { + var ( + ifaces = make([]interface{}, 3) + pointers = make([]*string, 3) + i int + ) + for i = 0; i < 3; i++ { + ifaces[i] = &pointers[i] + } + if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil { + return + } + if pointers[1] != nil { + *name = *pointers[1] + } +} From beeb040c62ae25ca2f1b7698546f5d18c1bfbf85 Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Tue, 11 Aug 2015 08:59:59 -0700 Subject: [PATCH 2/2] Reworked CurrentDatabase API to return the name instead of `*gorm.DB'. --- common_dialect.go | 11 +++++------ dialect.go | 2 +- foundation.go | 5 +++-- main.go | 10 ++++++---- mssql.go | 11 +++++------ mysql.go | 5 +++-- postgres.go | 5 +++-- query_test.go | 8 ++++---- sqlite3.go | 5 +++-- 9 files changed, 33 insertions(+), 29 deletions(-) diff --git a/common_dialect.go b/common_dialect.go index d7546ede..5613f419 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -71,9 +71,8 @@ func (commonDialect) Quote(key string) string { func (c commonDialect) HasTable(scope *Scope, tableName string) bool { var ( count int - databaseName string + databaseName = c.CurrentDatabase(scope) ) - c.CurrentDatabase(scope, &databaseName) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, databaseName).Row().Scan(&count) return count > 0 } @@ -81,9 +80,8 @@ func (c commonDialect) HasTable(scope *Scope, tableName string) bool { func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { var ( count int - databaseName string + databaseName = c.CurrentDatabase(scope) ) - c.CurrentDatabase(scope, &databaseName) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count) return count > 0 } @@ -98,6 +96,7 @@ func (commonDialect) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } -func (commonDialect) CurrentDatabase(scope *Scope, name *string) { - scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(name)) +func (commonDialect) CurrentDatabase(scope *Scope) (name string) { + scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name)) + return } diff --git a/dialect.go b/dialect.go index 5d17b545..926f8a11 100644 --- a/dialect.go +++ b/dialect.go @@ -17,7 +17,7 @@ type Dialect interface { HasColumn(scope *Scope, tableName string, columnName string) bool HasIndex(scope *Scope, tableName string, indexName string) bool RemoveIndex(scope *Scope, indexName string) - CurrentDatabase(scope *Scope, name *string) + CurrentDatabase(scope *Scope) string } func NewDialect(driver string) Dialect { diff --git a/foundation.go b/foundation.go index 8f68c720..427f8d19 100644 --- a/foundation.go +++ b/foundation.go @@ -77,6 +77,7 @@ func (foundation) HasIndex(scope *Scope, tableName string, indexName string) boo return count > 0 } -func (foundation) CurrentDatabase(scope *Scope, name *string) { - scope.Err(scope.NewDB().Raw("SELECT CURRENT_SCHEMA").Row().Scan(name)) +func (foundation) CurrentDatabase(scope *Scope) (name string) { + scope.Err(scope.NewDB().Raw("SELECT CURRENT_SCHEMA").Row().Scan(&name)) + return } diff --git a/main.go b/main.go index 7005f6b5..56727e35 100644 --- a/main.go +++ b/main.go @@ -429,10 +429,12 @@ func (s *DB) RemoveIndex(indexName string) *DB { return scope.db } -func (s *DB) CurrentDatabase(name *string) *DB { - scope := s.clone().NewScope(s.Value) - s.dialect.CurrentDatabase(scope, name) - return scope.db +func (s *DB) CurrentDatabase() string { + var ( + scope = s.clone().NewScope(s.Value) + name = s.dialect.CurrentDatabase(scope) + ) + return name } /* diff --git a/mssql.go b/mssql.go index b2633292..22b36ba2 100644 --- a/mssql.go +++ b/mssql.go @@ -53,9 +53,8 @@ func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { func (s mssql) HasTable(scope *Scope, tableName string) bool { var ( count int - databaseName string + databaseName = s.CurrentDatabase(scope) ) - s.CurrentDatabase(scope, &databaseName) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName).Row().Scan(&count) return count > 0 } @@ -63,9 +62,8 @@ func (s mssql) HasTable(scope *Scope, tableName string) bool { func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { var ( count int - databaseName string + databaseName = s.CurrentDatabase(scope) ) - s.CurrentDatabase(scope, &databaseName) scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count) return count > 0 } @@ -76,6 +74,7 @@ func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { return count > 0 } -func (mssql) CurrentDatabase(scope *Scope, name *string) { - scope.Err(scope.NewDB().Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(name)) +func (mssql) CurrentDatabase(scope *Scope) (name string) { + scope.Err(scope.NewDB().Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)) + return } diff --git a/mysql.go b/mysql.go index 025f87a5..66c4a3a0 100644 --- a/mysql.go +++ b/mysql.go @@ -64,6 +64,7 @@ func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } -func (mysql) CurrentDatabase(scope *Scope, name *string) { - scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(name)) +func (mysql) CurrentDatabase(scope *Scope) (name string) { + scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name)) + return } diff --git a/postgres.go b/postgres.go index 7374af9c..00527dcb 100644 --- a/postgres.go +++ b/postgres.go @@ -85,8 +85,9 @@ func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool return count > 0 } -func (postgres) CurrentDatabase(scope *Scope, name *string) { - scope.Err(scope.NewDB().Raw("SELECT CURRENT_DATABASE()").Row().Scan(name)) +func (postgres) CurrentDatabase(scope *Scope) (name string) { + scope.Err(scope.NewDB().Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name)) + return } var hstoreType = reflect.TypeOf(Hstore{}) diff --git a/query_test.go b/query_test.go index bd13ec18..ff022dbb 100644 --- a/query_test.go +++ b/query_test.go @@ -582,12 +582,12 @@ func TestSelectWithArrayInput(t *testing.T) { func TestCurrentDatabase(t *testing.T) { DB.LogMode(true) - var name string - if err := DB.CurrentDatabase(&name).Error; err != nil { + databaseName := DB.CurrentDatabase() + if err := DB.Error; err != nil { t.Errorf("Problem getting current db name: %s", err) } - if name == "" { + if databaseName == "" { t.Errorf("Current db name returned empty; this should never happen!") } - t.Logf("Got current db name: %v", name) + t.Logf("Got current db name: %v", databaseName) } diff --git a/sqlite3.go b/sqlite3.go index a73d2379..c2ebffb4 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -62,7 +62,7 @@ func (sqlite3) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) } -func (sqlite3) CurrentDatabase(scope *Scope, name *string) { +func (sqlite3) CurrentDatabase(scope *Scope) (name string) { var ( ifaces = make([]interface{}, 3) pointers = make([]*string, 3) @@ -75,6 +75,7 @@ func (sqlite3) CurrentDatabase(scope *Scope, name *string) { return } if pointers[1] != nil { - *name = *pointers[1] + name = *pointers[1] } + return }