mirror of https://github.com/go-gorm/gorm.git
`CurrentDatabase' determines current dbname by querying the database.
Preserves the gorm-style query API.
This commit is contained in:
parent
d21eed4b66
commit
70725f9d77
|
@ -3,7 +3,6 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -69,24 +68,23 @@ func (commonDialect) Quote(key string) string {
|
||||||
return fmt.Sprintf(`"%s"`, key)
|
return fmt.Sprintf(`"%s"`, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) databaseName(scope *Scope) string {
|
|
||||||
from := strings.LastIndex(scope.db.parent.source, "/") + 1
|
|
||||||
to := strings.LastIndex(scope.db.parent.source, "?")
|
|
||||||
if to == -1 {
|
|
||||||
to = len(scope.db.parent.source)
|
|
||||||
}
|
|
||||||
return scope.db.parent.source[from:to]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
|
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
|
||||||
var count int
|
var (
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count)
|
count int
|
||||||
|
databaseName string
|
||||||
|
)
|
||||||
|
c.CurrentDatabase(scope, &databaseName)
|
||||||
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, databaseName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||||
var count int
|
var (
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count)
|
count int
|
||||||
|
databaseName string
|
||||||
|
)
|
||||||
|
c.CurrentDatabase(scope, &databaseName)
|
||||||
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,3 +97,7 @@ func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string)
|
||||||
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
|
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
|
||||||
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
|
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (commonDialect) CurrentDatabase(scope *Scope, name *string) {
|
||||||
|
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(name))
|
||||||
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ type Dialect interface {
|
||||||
HasColumn(scope *Scope, tableName string, columnName string) bool
|
HasColumn(scope *Scope, tableName string, columnName string) bool
|
||||||
HasIndex(scope *Scope, tableName string, indexName string) bool
|
HasIndex(scope *Scope, tableName string, indexName string) bool
|
||||||
RemoveIndex(scope *Scope, indexName string)
|
RemoveIndex(scope *Scope, indexName string)
|
||||||
|
CurrentDatabase(scope *Scope, name *string)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDialect(driver string) Dialect {
|
func NewDialect(driver string) Dialect {
|
||||||
|
|
|
@ -76,3 +76,7 @@ func (foundation) HasIndex(scope *Scope, tableName string, indexName string) boo
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (foundation) CurrentDatabase(scope *Scope, name *string) {
|
||||||
|
scope.Err(scope.NewDB().Raw("SELECT CURRENT_SCHEMA").Row().Scan(name))
|
||||||
|
}
|
||||||
|
|
6
main.go
6
main.go
|
@ -429,6 +429,12 @@ func (s *DB) RemoveIndex(indexName string) *DB {
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DB) CurrentDatabase(name *string) *DB {
|
||||||
|
scope := s.clone().NewScope(s.Value)
|
||||||
|
s.dialect.CurrentDatabase(scope, name)
|
||||||
|
return scope.db
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Add foreign key to the given scope
|
Add foreign key to the given scope
|
||||||
|
|
||||||
|
|
32
mssql.go
32
mssql.go
|
@ -3,7 +3,6 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -51,26 +50,23 @@ func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) databaseName(scope *Scope) string {
|
|
||||||
dbStr := strings.Split(scope.db.parent.source, ";")
|
|
||||||
for _, value := range dbStr {
|
|
||||||
s := strings.Split(value, "=")
|
|
||||||
if s[0] == "database" {
|
|
||||||
return s[1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s mssql) HasTable(scope *Scope, tableName string) bool {
|
func (s mssql) HasTable(scope *Scope, tableName string) bool {
|
||||||
var count int
|
var (
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
|
count int
|
||||||
|
databaseName string
|
||||||
|
)
|
||||||
|
s.CurrentDatabase(scope, &databaseName)
|
||||||
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||||
var count int
|
var (
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
|
count int
|
||||||
|
databaseName string
|
||||||
|
)
|
||||||
|
s.CurrentDatabase(scope, &databaseName)
|
||||||
|
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,3 +75,7 @@ func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mssql) CurrentDatabase(scope *Scope, name *string) {
|
||||||
|
scope.Err(scope.NewDB().Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(name))
|
||||||
|
}
|
||||||
|
|
4
mysql.go
4
mysql.go
|
@ -63,3 +63,7 @@ func (mysql) Quote(key string) string {
|
||||||
func (mysql) SelectFromDummyTable() string {
|
func (mysql) SelectFromDummyTable() string {
|
||||||
return "FROM DUAL"
|
return "FROM DUAL"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mysql) CurrentDatabase(scope *Scope, name *string) {
|
||||||
|
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(name))
|
||||||
|
}
|
||||||
|
|
|
@ -85,6 +85,10 @@ func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (postgres) CurrentDatabase(scope *Scope, name *string) {
|
||||||
|
scope.Err(scope.NewDB().Raw("SELECT CURRENT_DATABASE()").Row().Scan(name))
|
||||||
|
}
|
||||||
|
|
||||||
var hstoreType = reflect.TypeOf(Hstore{})
|
var hstoreType = reflect.TypeOf(Hstore{})
|
||||||
|
|
||||||
type Hstore map[string]*string
|
type Hstore map[string]*string
|
||||||
|
|
|
@ -579,3 +579,15 @@ func TestSelectWithArrayInput(t *testing.T) {
|
||||||
t.Errorf("Should have selected both age and name")
|
t.Errorf("Should have selected both age and name")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCurrentDatabase(t *testing.T) {
|
||||||
|
DB.LogMode(true)
|
||||||
|
var name string
|
||||||
|
if err := DB.CurrentDatabase(&name).Error; err != nil {
|
||||||
|
t.Errorf("Problem getting current db name: %s", err)
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
t.Errorf("Current db name returned empty; this should never happen!")
|
||||||
|
}
|
||||||
|
t.Logf("Got current db name: %v", name)
|
||||||
|
}
|
||||||
|
|
17
sqlite3.go
17
sqlite3.go
|
@ -61,3 +61,20 @@ func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||||
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
|
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
|
||||||
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sqlite3) CurrentDatabase(scope *Scope, name *string) {
|
||||||
|
var (
|
||||||
|
ifaces = make([]interface{}, 3)
|
||||||
|
pointers = make([]*string, 3)
|
||||||
|
i int
|
||||||
|
)
|
||||||
|
for i = 0; i < 3; i++ {
|
||||||
|
ifaces[i] = &pointers[i]
|
||||||
|
}
|
||||||
|
if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if pointers[1] != nil {
|
||||||
|
*name = *pointers[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue