Simplify dialect definitions

This commit is contained in:
Jinzhu 2015-03-17 10:40:42 +08:00
parent faae729b20
commit a0848909c2
5 changed files with 46 additions and 147 deletions

View File

@ -9,19 +9,19 @@ import (
type commonDialect struct{}
func (s *commonDialect) BinVar(i int) string {
return "?"
func (commonDialect) BinVar(i int) string {
return "$$" // ?
}
func (s *commonDialect) SupportLastInsertId() bool {
func (commonDialect) SupportLastInsertId() bool {
return true
}
func (s *commonDialect) HasTop() bool {
func (commonDialect) HasTop() bool {
return false
}
func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "BOOLEAN"
@ -57,19 +57,19 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool)
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
}
func (s *commonDialect) ReturningStr(tableName, key string) string {
func (commonDialect) ReturningStr(tableName, key string) string {
return ""
}
func (s *commonDialect) SelectFromDummyTable() string {
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (s *commonDialect) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (s *commonDialect) databaseName(scope *Scope) string {
func (commonDialect) databaseName(scope *Scope) string {
from := strings.Index(scope.db.parent.source, "/") + 1
to := strings.Index(scope.db.parent.source, "?")
if to == -1 {
@ -78,24 +78,24 @@ func (s *commonDialect) databaseName(scope *Scope) string {
return scope.db.parent.source[from:to]
}
func (s *commonDialect) HasTable(scope *Scope, tableName string) bool {
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (s *commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
func (s *commonDialect) RemoveIndex(scope *Scope, indexName string) {
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
}

View File

@ -7,21 +7,15 @@ import (
"time"
)
type mssql struct{}
func (s *mssql) BinVar(i int) string {
return "$$" // ?
type mssql struct {
commonDialect
}
func (s *mssql) SupportLastInsertId() bool {
func (mssql) HasTop() bool {
return true
}
func (s *mssql) HasTop() bool {
return true
}
func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "bit"
@ -57,19 +51,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
}
func (s *mssql) ReturningStr(tableName, key string) string {
return ""
}
func (s *mssql) SelectFromDummyTable() string {
return ""
}
func (s *mssql) Quote(key string) string {
return fmt.Sprintf(" \"%s\"", key)
}
func (s *mssql) databaseName(scope *Scope) string {
func (mssql) databaseName(scope *Scope) string {
dbStr := strings.Split(scope.db.parent.source, ";")
for _, value := range dbStr {
s := strings.Split(value, "=")
@ -80,24 +62,20 @@ func (s *mssql) databaseName(scope *Scope) string {
return ""
}
func (s *mssql) HasTable(scope *Scope, tableName string) bool {
func (s mssql) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (s *mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count)
return count > 0
}
func (s *mssql) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
}

View File

@ -3,25 +3,14 @@ package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type mysql struct{}
func (s *mysql) BinVar(i int) string {
return "$$" // ?
type mysql struct {
commonDialect
}
func (s *mysql) SupportLastInsertId() bool {
return true
}
func (s *mysql) HasTop() bool {
return false
}
func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
@ -57,45 +46,10 @@ func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
}
func (s *mysql) ReturningStr(tableName, key string) string {
return ""
}
func (s *mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s *mysql) Quote(key string) string {
func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
func (s *mysql) databaseName(scope *Scope) string {
from := strings.Index(scope.db.parent.source, "/") + 1
to := strings.Index(scope.db.parent.source, "?")
if to == -1 {
to = len(scope.db.parent.source)
}
return scope.db.parent.source[from:to]
}
func (s *mysql) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (s *mysql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
func (s *mysql) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}

View File

@ -11,21 +11,18 @@ import (
)
type postgres struct {
commonDialect
}
func (s *postgres) BinVar(i int) string {
func (postgres) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (s *postgres) SupportLastInsertId() bool {
func (postgres) SupportLastInsertId() bool {
return false
}
func (s *postgres) HasTop() bool {
return false
}
func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
@ -62,35 +59,27 @@ func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) stri
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
}
func (s *postgres) ReturningStr(tableName, key string) string {
func (s postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key)
}
func (s *postgres) SelectFromDummyTable() string {
return ""
}
func (s *postgres) Quote(key string) string {
return fmt.Sprintf("\"%s\"", key)
}
func (s *postgres) HasTable(scope *Scope, tableName string) bool {
func (postgres) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count)
return count > 0
}
func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
return count > 0
}
func (s *postgres) RemoveIndex(scope *Scope, indexName string) {
func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
}
func (s *postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count)
return count > 0

View File

@ -6,21 +6,11 @@ import (
"time"
)
type sqlite3 struct{}
func (s *sqlite3) BinVar(i int) string {
return "$$" // ?
type sqlite3 struct {
commonDialect
}
func (s *sqlite3) SupportLastInsertId() bool {
return true
}
func (s *sqlite3) HasTop() bool {
return false
}
func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "bool"
@ -50,36 +40,24 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) strin
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
}
func (s *sqlite3) ReturningStr(tableName, key string) string {
return ""
}
func (s *sqlite3) SelectFromDummyTable() string {
return ""
}
func (s *sqlite3) Quote(key string) string {
return fmt.Sprintf("\"%s\"", key)
}
func (s *sqlite3) HasTable(scope *Scope, tableName string) bool {
func (sqlite3) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count)
return count > 0
}
func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count)
return count > 0
}
func (s *sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count)
return count > 0
}
func (s *sqlite3) RemoveIndex(scope *Scope, indexName string) {
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
}