Register dialects

This commit is contained in:
Jinzhu 2016-02-14 18:06:42 +08:00
parent 421979cfc2
commit f4456e139e
7 changed files with 34 additions and 19 deletions

View File

@ -34,22 +34,19 @@ type Dialect interface {
LastInsertIdReturningSuffix(tableName, columnName string) string LastInsertIdReturningSuffix(tableName, columnName string) string
} }
func NewDialect(driver string) Dialect { var dialectsMap = map[string]Dialect{}
var d Dialect
switch driver { func newDialect(name string) Dialect {
case "postgres": if dialect, ok := dialectsMap[name]; ok {
d = &postgres{} return dialect
case "mysql":
d = &mysql{}
case "sqlite3":
d = &sqlite3{}
case "mssql":
d = &mssql{}
default:
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
d = &commonDialect{}
} }
return d fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
return &commonDialect{}
}
// RegisterDialect register new dialect
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
} }
// ParseFieldStructForDialect parse field struct for dialect // ParseFieldStructForDialect parse field struct for dialect

View File

@ -9,6 +9,10 @@ import (
type commonDialect struct{} type commonDialect struct{}
func init() {
RegisterDialect("common", &commonDialect{})
}
func (commonDialect) BindVar(i int) string { func (commonDialect) BindVar(i int) string {
return "$$" // ? return "$$" // ?
} }

View File

@ -11,6 +11,10 @@ type mssql struct {
commonDialect commonDialect
} }
func init() {
RegisterDialect("mssql", &mssql{})
}
func (mssql) DataTypeOf(field *StructField) string { func (mssql) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)

View File

@ -11,6 +11,10 @@ type mysql struct {
commonDialect commonDialect
} }
func init() {
RegisterDialect("mysql", &mysql{})
}
func (mysql) Quote(key string) string { func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key) return fmt.Sprintf("`%s`", key)
} }

View File

@ -15,6 +15,10 @@ type postgres struct {
commonDialect commonDialect
} }
func init() {
RegisterDialect("postgres", &postgres{})
}
func (postgres) BindVar(i int) string { func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i) return fmt.Sprintf("$%v", i)
} }

View File

@ -11,6 +11,11 @@ type sqlite3 struct {
commonDialect commonDialect
} }
func init() {
RegisterDialect("sqlite", &sqlite3{})
RegisterDialect("sqlite3", &sqlite3{})
}
// Get Data Type for Sqlite Dialect // Get Data Type for Sqlite Dialect
func (sqlite3) DataTypeOf(field *StructField) string { func (sqlite3) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)

View File

@ -55,9 +55,6 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
driver = value driver = value
source = args[1].(string) source = args[1].(string)
} }
if driver == "foundation" {
driver = "postgres" // FoundationDB speaks a postgres-compatible protocol.
}
dbSql, err = sql.Open(driver, source) dbSql, err = sql.Open(driver, source)
case sqlCommon: case sqlCommon:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
@ -65,7 +62,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
} }
db = DB{ db = DB{
dialect: NewDialect(dialect), dialect: newDialect(dialect),
logger: defaultLogger, logger: defaultLogger,
callbacks: defaultCallback, callbacks: defaultCallback,
source: source, source: source,