Simplify dialect definitions

This commit is contained in:
Jinzhu 2015-03-17 10:40:42 +08:00
parent faae729b20
commit a0848909c2
5 changed files with 46 additions and 147 deletions

View File

@ -9,19 +9,19 @@ import (
type commonDialect struct{} type commonDialect struct{}
func (s *commonDialect) BinVar(i int) string { func (commonDialect) BinVar(i int) string {
return "?" return "$$" // ?
} }
func (s *commonDialect) SupportLastInsertId() bool { func (commonDialect) SupportLastInsertId() bool {
return true return true
} }
func (s *commonDialect) HasTop() bool { func (commonDialect) HasTop() bool {
return false return false
} }
func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "BOOLEAN" return "BOOLEAN"
@ -57,19 +57,19 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool)
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
} }
func (s *commonDialect) ReturningStr(tableName, key string) string { func (commonDialect) ReturningStr(tableName, key string) string {
return "" return ""
} }
func (s *commonDialect) SelectFromDummyTable() string { func (commonDialect) SelectFromDummyTable() string {
return "" return ""
} }
func (s *commonDialect) Quote(key string) string { func (commonDialect) Quote(key string) string {
return fmt.Sprintf("`%s`", key) return fmt.Sprintf(`"%s"`, key)
} }
func (s *commonDialect) databaseName(scope *Scope) string { func (commonDialect) databaseName(scope *Scope) string {
from := strings.Index(scope.db.parent.source, "/") + 1 from := strings.Index(scope.db.parent.source, "/") + 1
to := strings.Index(scope.db.parent.source, "?") to := strings.Index(scope.db.parent.source, "?")
if to == -1 { if to == -1 {
@ -78,24 +78,24 @@ func (s *commonDialect) databaseName(scope *Scope) string {
return scope.db.parent.source[from:to] return scope.db.parent.source[from:to]
} }
func (s *commonDialect) HasTable(scope *Scope, tableName string) bool { func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
var count int var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *commonDialect) RemoveIndex(scope *Scope, indexName string) { func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
} }

View File

@ -7,21 +7,15 @@ import (
"time" "time"
) )
type mssql struct{} type mssql struct {
commonDialect
func (s *mssql) BinVar(i int) string {
return "$$" // ?
} }
func (s *mssql) SupportLastInsertId() bool { func (mssql) HasTop() bool {
return true return true
} }
func (s *mssql) HasTop() bool { func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
return true
}
func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "bit" return "bit"
@ -57,19 +51,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
} }
func (s *mssql) ReturningStr(tableName, key string) string { func (mssql) databaseName(scope *Scope) string {
return ""
}
func (s *mssql) SelectFromDummyTable() string {
return ""
}
func (s *mssql) Quote(key string) string {
return fmt.Sprintf(" \"%s\"", key)
}
func (s *mssql) databaseName(scope *Scope) string {
dbStr := strings.Split(scope.db.parent.source, ";") dbStr := strings.Split(scope.db.parent.source, ";")
for _, value := range dbStr { for _, value := range dbStr {
s := strings.Split(value, "=") s := strings.Split(value, "=")
@ -80,24 +62,20 @@ func (s *mssql) databaseName(scope *Scope) string {
return "" return ""
} }
func (s *mssql) HasTable(scope *Scope, tableName string) bool { func (s mssql) HasTable(scope *Scope, tableName string) bool {
var count int var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int var count int
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int var count int
scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *mssql) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
}

View File

@ -3,25 +3,14 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"time" "time"
) )
type mysql struct{} type mysql struct {
commonDialect
func (s *mysql) BinVar(i int) string {
return "$$" // ?
} }
func (s *mysql) SupportLastInsertId() bool { func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
return true
}
func (s *mysql) HasTop() bool {
return false
}
func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"
@ -57,45 +46,10 @@ func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
} }
func (s *mysql) ReturningStr(tableName, key string) string { func (mysql) Quote(key string) string {
return ""
}
func (s *mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s *mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key) return fmt.Sprintf("`%s`", key)
} }
func (s *mysql) databaseName(scope *Scope) string { func (mysql) SelectFromDummyTable() string {
from := strings.Index(scope.db.parent.source, "/") + 1 return "FROM DUAL"
to := strings.Index(scope.db.parent.source, "?")
if to == -1 {
to = len(scope.db.parent.source)
}
return scope.db.parent.source[from:to]
}
func (s *mysql) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (s *mysql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
func (s *mysql) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
} }

View File

@ -11,21 +11,18 @@ import (
) )
type postgres struct { type postgres struct {
commonDialect
} }
func (s *postgres) BinVar(i int) string { func (postgres) BinVar(i int) string {
return fmt.Sprintf("$%v", i) return fmt.Sprintf("$%v", i)
} }
func (s *postgres) SupportLastInsertId() bool { func (postgres) SupportLastInsertId() bool {
return false return false
} }
func (s *postgres) HasTop() bool { func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
return false
}
func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"
@ -62,35 +59,27 @@ func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) stri
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
} }
func (s *postgres) ReturningStr(tableName, key string) string { func (s postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key) return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key)
} }
func (s *postgres) SelectFromDummyTable() string { func (postgres) HasTable(scope *Scope, tableName string) bool {
return ""
}
func (s *postgres) Quote(key string) string {
return fmt.Sprintf("\"%s\"", key)
}
func (s *postgres) HasTable(scope *Scope, tableName string) bool {
var count int var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { func (postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *postgres) RemoveIndex(scope *Scope, indexName string) { func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
} }
func (s *postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int var count int
scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count)
return count > 0 return count > 0

View File

@ -6,21 +6,11 @@ import (
"time" "time"
) )
type sqlite3 struct{} type sqlite3 struct {
commonDialect
func (s *sqlite3) BinVar(i int) string {
return "$$" // ?
} }
func (s *sqlite3) SupportLastInsertId() bool { func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
return true
}
func (s *sqlite3) HasTop() bool {
return false
}
func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "bool" return "bool"
@ -50,36 +40,24 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) strin
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
} }
func (s *sqlite3) ReturningStr(tableName, key string) string { func (sqlite3) HasTable(scope *Scope, tableName string) bool {
return ""
}
func (s *sqlite3) SelectFromDummyTable() string {
return ""
}
func (s *sqlite3) Quote(key string) string {
return fmt.Sprintf("\"%s\"", key)
}
func (s *sqlite3) HasTable(scope *Scope, tableName string) bool {
var count int var count int
scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count) scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { func (sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int var count int
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count) scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int var count int
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count) scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count)
return count > 0 return count > 0
} }
func (s *sqlite3) RemoveIndex(scope *Scope, indexName string) { func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
} }