From f4456e139e5cdd7288ae60fac8b1dd903630b88b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 14 Feb 2016 18:06:42 +0800 Subject: [PATCH] Register dialects --- dialect.go | 27 ++++++++++++--------------- dialect_common.go | 4 ++++ dialect_mssql.go | 4 ++++ dialect_mysql.go | 4 ++++ dialect_postgres.go | 4 ++++ dialect_sqlite3.go | 5 +++++ main.go | 5 +---- 7 files changed, 34 insertions(+), 19 deletions(-) diff --git a/dialect.go b/dialect.go index 61220a42..1923e66e 100644 --- a/dialect.go +++ b/dialect.go @@ -34,22 +34,19 @@ type Dialect interface { LastInsertIdReturningSuffix(tableName, columnName string) string } -func NewDialect(driver string) Dialect { - var d Dialect - switch driver { - case "postgres": - d = &postgres{} - case "mysql": - d = &mysql{} - case "sqlite3": - d = &sqlite3{} - case "mssql": - d = &mssql{} - default: - fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver) - d = &commonDialect{} +var dialectsMap = map[string]Dialect{} + +func newDialect(name string) Dialect { + if dialect, ok := dialectsMap[name]; ok { + return dialect } - return d + fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) + return &commonDialect{} +} + +// RegisterDialect register new dialect +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect } // ParseFieldStructForDialect parse field struct for dialect diff --git a/dialect_common.go b/dialect_common.go index 9f10a287..d5a81ad6 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -9,6 +9,10 @@ import ( type commonDialect struct{} +func init() { + RegisterDialect("common", &commonDialect{}) +} + func (commonDialect) BindVar(i int) string { return "$$" // ? } diff --git a/dialect_mssql.go b/dialect_mssql.go index 63971e34..a2af49ad 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -11,6 +11,10 @@ type mssql struct { commonDialect } +func init() { + RegisterDialect("mssql", &mssql{}) +} + func (mssql) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) diff --git a/dialect_mysql.go b/dialect_mysql.go index 22f8b88d..51b1926a 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -11,6 +11,10 @@ type mysql struct { commonDialect } +func init() { + RegisterDialect("mysql", &mysql{}) +} + func (mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } diff --git a/dialect_postgres.go b/dialect_postgres.go index b35b918a..e726d233 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -15,6 +15,10 @@ type postgres struct { commonDialect } +func init() { + RegisterDialect("postgres", &postgres{}) +} + func (postgres) BindVar(i int) string { return fmt.Sprintf("$%v", i) } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index d5ffb78d..3abdb92e 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -11,6 +11,11 @@ type sqlite3 struct { commonDialect } +func init() { + RegisterDialect("sqlite", &sqlite3{}) + RegisterDialect("sqlite3", &sqlite3{}) +} + // Get Data Type for Sqlite Dialect func (sqlite3) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) diff --git a/main.go b/main.go index 4cdbca0b..5f4d9dbd 100644 --- a/main.go +++ b/main.go @@ -55,9 +55,6 @@ func Open(dialect string, args ...interface{}) (*DB, error) { driver = value source = args[1].(string) } - if driver == "foundation" { - driver = "postgres" // FoundationDB speaks a postgres-compatible protocol. - } dbSql, err = sql.Open(driver, source) case sqlCommon: source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() @@ -65,7 +62,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) { } db = DB{ - dialect: NewDialect(dialect), + dialect: newDialect(dialect), logger: defaultLogger, callbacks: defaultCallback, source: source,