Refactor dialect

This commit is contained in:
Jinzhu 2016-02-15 14:09:24 +08:00
parent 6546ec3b5e
commit 4e8370e18b
11 changed files with 94 additions and 104 deletions

View File

@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) {
DB.AutoMigrate(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
if !scope.Dialect().HasColumn(scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}

View File

@ -17,8 +17,7 @@ func TestDdlErrors(t *testing.T) {
}
}()
DB.HasTable("foobarbaz")
if DB.Error == nil {
if err := DB.Find(&User{}).Error; err == nil {
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
}
}

View File

@ -10,6 +10,9 @@ import (
// Dialect interface contains behaviors that differ across SQL database
type Dialect interface {
// SetDB set db for dialect
SetDB(db *sql.DB)
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
BindVar(i int) string
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
@ -18,13 +21,13 @@ type Dialect interface {
DataTypeOf(field *StructField) string
// HasIndex check has index or not
HasIndex(scope *Scope, tableName string, indexName string) bool
HasIndex(tableName string, indexName string) bool
// RemoveIndex remove index
RemoveIndex(scope *Scope, indexName string)
RemoveIndex(tableName string, indexName string) error
// HasTable check has table or not
HasTable(scope *Scope, tableName string) bool
HasTable(tableName string) bool
// HasColumn check has column or not
HasColumn(scope *Scope, tableName string, columnName string) bool
HasColumn(tableName string, columnName string) bool
// LimitAndOffsetSQL return generate SQL with limit and offset, as mssql has special case
LimitAndOffsetSQL(limit, offset int) string
@ -36,12 +39,17 @@ type Dialect interface {
var dialectsMap = map[string]Dialect{}
func newDialect(name string) Dialect {
if dialect, ok := dialectsMap[name]; ok {
func newDialect(name string, db *sql.DB) Dialect {
if value, ok := dialectsMap[name]; ok {
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
dialect.SetDB(db)
return dialect
}
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
return &commonDialect{}
commontDialect := &commonDialect{}
commontDialect.SetDB(db)
return commontDialect
}
// RegisterDialect register new dialect

View File

@ -1,18 +1,25 @@
package gorm
import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
)
type commonDialect struct{}
type commonDialect struct {
db *sql.DB
}
func init() {
RegisterDialect("common", &commonDialect{})
}
func (s *commonDialect) SetDB(db *sql.DB) {
s.db = db
}
func (commonDialect) BindVar(i int) string {
return "$$" // ?
}
@ -73,51 +80,31 @@ func (commonDialect) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var (
count int
databaseName = c.currentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count)
return count > 0
}
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
return err
}
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
var (
count int
databaseName = c.currentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count)
return count > 0
}
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var (
count int
databaseName = c.currentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
// RawScanInt scans the first column of the first row into the `scan' int pointer.
// This function captures raw query errors and propagates them to the original scope.
func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
}
// RawScanString scans the first column of the first row into the `scan' string pointer.
// This function captures raw query errors and propagates them to the original scope.
func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) {
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
}
func (commonDialect) currentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
func (s commonDialect) currentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}

View File

@ -67,32 +67,31 @@ func (mssql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
return count > 0
}
func (s mssql) HasTable(scope *Scope, tableName string) bool {
var (
count int
databaseName = s.currentDatabase(scope)
)
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
func (s mssql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mssql) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.currentDatabase()).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var (
count int
databaseName = s.currentDatabase(scope)
)
s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) currentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
func (s mssql) currentDatabase() (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
}

View File

@ -88,8 +88,13 @@ func (mysql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mysql) currentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DATABASE()")
func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mysql) currentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}

View File

@ -77,30 +77,26 @@ func (postgres) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
return count > 0
}
func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}
func (s postgres) HasTable(scope *Scope, tableName string) bool {
func (s postgres) HasTable(tableName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
return count > 0
}
func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
return count > 0
}
func (s postgres) currentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
func (s postgres) currentDatabase() (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
return
}

View File

@ -65,29 +65,25 @@ func (sqlite3) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0
}
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}
func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
func (s sqlite3) HasTable(tableName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName)
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName).Scan(&count)
return count > 0
}
func (sqlite3) currentDatabase(scope *Scope) (name string) {
func (s sqlite3) currentDatabase() (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
@ -96,7 +92,7 @@ func (sqlite3) currentDatabase(scope *Scope) (name string) {
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {

View File

@ -62,7 +62,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
}
db = DB{
dialect: newDialect(dialect),
dialect: newDialect(dialect, dbSql.(*sql.DB)),
logger: defaultLogger,
callbacks: defaultCallback,
source: source,
@ -430,7 +430,7 @@ func (s *DB) HasTable(value interface{}) bool {
tableName = scope.TableName()
}
has := scope.Dialect().HasTable(scope, tableName)
has := scope.Dialect().HasTable(tableName)
s.AddError(scope.db.Error)
return has
}
@ -531,7 +531,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination)
field.Relationship.JoinTableHandler = handler
if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
if table := handler.Table(s); scope.Dialect().HasTable(table) {
s.Table(table).AutoMigrate(handler)
}
}

View File

@ -31,7 +31,7 @@ func TestIndexes(t *testing.T) {
}
scope := DB.NewScope(&Email{})
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
t.Errorf("Email should have index idx_email_email")
}
@ -39,7 +39,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
t.Errorf("Email's index idx_email_email should be deleted")
}
@ -47,7 +47,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err)
}
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
@ -55,7 +55,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
@ -63,7 +63,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err)
}
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
@ -85,7 +85,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
@ -117,11 +117,11 @@ func TestAutoMigration(t *testing.T) {
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
scope := DB.NewScope(&BigEmail{})
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") {
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") {
if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") {
t.Errorf("Failed to create index")
}

View File

@ -515,7 +515,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
if !scope.Dialect().HasTable(scope, joinTable) {
if !scope.Dialect().HasTable(joinTable) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes, primaryKeys []string
@ -586,7 +586,7 @@ func (scope *Scope) dropTable() *Scope {
}
func (scope *Scope) dropTableIfExists() *Scope {
if scope.Dialect().HasTable(scope, scope.TableName()) {
if scope.Dialect().HasTable(scope.TableName()) {
scope.dropTable()
}
return scope
@ -601,7 +601,7 @@ func (scope *Scope) dropColumn(column string) {
}
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) {
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
return
}
@ -626,18 +626,18 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
}
func (scope *Scope) removeIndex(indexName string) {
scope.Dialect().RemoveIndex(scope, indexName)
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
}
func (scope *Scope) autoMigrate() *Scope {
tableName := scope.TableName()
quotedTableName := scope.QuotedTableName()
if !scope.Dialect().HasTable(scope, tableName) {
if !scope.Dialect().HasTable(tableName) {
scope.createTable()
} else {
for _, field := range scope.GetModelStruct().StructFields {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if !scope.Dialect().HasColumn(tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.Dialect().DataTypeOf(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()