Merge pull request #600 from jaytaylor/jay/current_database

`CurrentDatabase' implementation.
This commit is contained in:
Jinzhu 2015-08-12 22:03:49 +08:00
commit 7fcb3e889f
9 changed files with 84 additions and 30 deletions

View File

@ -3,7 +3,6 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"time" "time"
) )
@ -69,24 +68,21 @@ func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key) return fmt.Sprintf(`"%s"`, key)
} }
func (commonDialect) databaseName(scope *Scope) string {
from := strings.LastIndex(scope.db.parent.source, "/") + 1
to := strings.LastIndex(scope.db.parent.source, "?")
if to == -1 {
to = len(scope.db.parent.source)
}
return scope.db.parent.source[from:to]
}
func (c commonDialect) HasTable(scope *Scope, tableName string) bool { func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
var count int var (
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count) count int
databaseName = c.CurrentDatabase(scope)
)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, databaseName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int var (
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count) count int
databaseName = c.CurrentDatabase(scope)
)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count)
return count > 0 return count > 0
} }
@ -99,3 +95,8 @@ func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string)
func (commonDialect) RemoveIndex(scope *Scope, indexName string) { func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
} }
func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
return
}

View File

@ -17,6 +17,7 @@ type Dialect interface {
HasColumn(scope *Scope, tableName string, columnName string) bool HasColumn(scope *Scope, tableName string, columnName string) bool
HasIndex(scope *Scope, tableName string, indexName string) bool HasIndex(scope *Scope, tableName string, indexName string) bool
RemoveIndex(scope *Scope, indexName string) RemoveIndex(scope *Scope, indexName string)
CurrentDatabase(scope *Scope) string
} }
func NewDialect(driver string) Dialect { func NewDialect(driver string) Dialect {

View File

@ -76,3 +76,8 @@ func (foundation) HasIndex(scope *Scope, tableName string, indexName string) boo
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (foundation) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT CURRENT_SCHEMA").Row().Scan(&name))
return
}

View File

@ -429,6 +429,14 @@ func (s *DB) RemoveIndex(indexName string) *DB {
return scope.db return scope.db
} }
func (s *DB) CurrentDatabase() string {
var (
scope = s.clone().NewScope(s.Value)
name = s.dialect.CurrentDatabase(scope)
)
return name
}
/* /*
Add foreign key to the given scope Add foreign key to the given scope

View File

@ -3,7 +3,6 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"time" "time"
) )
@ -51,26 +50,21 @@ func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
} }
func (mssql) databaseName(scope *Scope) string {
dbStr := strings.Split(scope.db.parent.source, ";")
for _, value := range dbStr {
s := strings.Split(value, "=")
if s[0] == "database" {
return s[1]
}
}
return ""
}
func (s mssql) HasTable(scope *Scope, tableName string) bool { func (s mssql) HasTable(scope *Scope, tableName string) bool {
var count int var (
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count) count int
databaseName = s.CurrentDatabase(scope)
)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int var (
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) count int
databaseName = s.CurrentDatabase(scope)
)
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count)
return count > 0 return count > 0
} }
@ -79,3 +73,8 @@ func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (mssql) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name))
return
}

View File

@ -63,3 +63,8 @@ func (mysql) Quote(key string) string {
func (mysql) SelectFromDummyTable() string { func (mysql) SelectFromDummyTable() string {
return "FROM DUAL" return "FROM DUAL"
} }
func (mysql) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
return
}

View File

@ -85,6 +85,11 @@ func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool
return count > 0 return count > 0
} }
func (postgres) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name))
return
}
var hstoreType = reflect.TypeOf(Hstore{}) var hstoreType = reflect.TypeOf(Hstore{})
type Hstore map[string]*string type Hstore map[string]*string

View File

@ -579,3 +579,15 @@ func TestSelectWithArrayInput(t *testing.T) {
t.Errorf("Should have selected both age and name") t.Errorf("Should have selected both age and name")
} }
} }
func TestCurrentDatabase(t *testing.T) {
DB.LogMode(true)
databaseName := DB.CurrentDatabase()
if err := DB.Error; err != nil {
t.Errorf("Problem getting current db name: %s", err)
}
if databaseName == "" {
t.Errorf("Current db name returned empty; this should never happen!")
}
t.Logf("Got current db name: %v", databaseName)
}

View File

@ -61,3 +61,21 @@ func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (sqlite3) RemoveIndex(scope *Scope, indexName string) { func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
} }
func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
i int
)
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
return
}
if pointers[1] != nil {
name = *pointers[1]
}
return
}