diff --git a/common_dialect.go b/common_dialect.go new file mode 100644 index 00000000..7d5d76fa --- /dev/null +++ b/common_dialect.go @@ -0,0 +1,100 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" +) + +type commonDialect struct{} + +func (s *commonDialect) BinVar(i int) string { + return "?" +} + +func (s *commonDialect) SupportLastInsertId() bool { + return true +} + +func (d *commonDialect) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: + return "BOOLEAN" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "INTEGER" + case reflect.Int64, reflect.Uint64: + return "BIGINT" + case reflect.Float32, reflect.Float64: + return "FLOAT" + case reflect.String: + if size > 0 && size < 65532 { + return fmt.Sprintf("VARCHAR(%d)", size) + } else { + return "VARCHAR(65532)" + } + case reflect.Struct: + if value.Type() == timeType { + return "TIMESTAMP" + } + default: + if _, ok := value.Interface().([]byte); ok { + if size > 0 && size < 65532 { + return fmt.Sprintf("BINARY(%d)", size) + } else { + return "BINARY(65532)" + } + } + } + panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) +} + +func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string { + suffix_str := " NOT NULL PRIMARY KEY" + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "INTEGER" + suffix_str + case reflect.Int64, reflect.Uint64: + return "BIGINT" + suffix_str + default: + panic("Invalid primary key type") + } +} + +func (s *commonDialect) ReturningStr(key string) string { + return "" +} + +func (s *commonDialect) Quote(key string) string { + return fmt.Sprintf("`%s`", key) +} + +func (s *commonDialect) databaseName(scope *Scope) string { + from := strings.Index(scope.db.parent.source, "/") + 1 + to := strings.Index(scope.db.parent.source, "?") + if to == -1 { + to = len(scope.db.parent.source) + } + return scope.db.parent.source[from:to] +} + +func (s *commonDialect) HasTable(scope *Scope, tableName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v", + newScope.AddToVars(tableName), + newScope.AddToVars(s.databaseName(scope)))) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 +} + +func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE table_schema = %v AND table_name = %v AND column_name = %v", + newScope.AddToVars(s.databaseName(scope)), + newScope.AddToVars(tableName), + newScope.AddToVars(columnName), + )) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 +} diff --git a/dialect.go b/dialect.go index d1176039..20ba32d4 100644 --- a/dialect.go +++ b/dialect.go @@ -1,6 +1,7 @@ package gorm import ( + "fmt" "reflect" "time" ) @@ -27,6 +28,9 @@ func NewDialect(driver string) Dialect { d = &mysql{} case "sqlite3": d = &sqlite3{} + default: + fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver) + d = &commonDialect{} } return d } diff --git a/main_test.go b/main_test.go index 39aac884..6616b992 100644 --- a/main_test.go +++ b/main_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" + testdb "github.com/erikstmartin/go-testdb" _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" _ "github.com/lib/pq" @@ -2098,3 +2099,22 @@ func TestCreate(t *testing.T) { t.Error("Should not be able to create record with predefined duplicate Id") } } + +func TestCompatibilityMode(t *testing.T) { + db, _ := gorm.Open("testdb", "") + testdb.SetQueryFunc(func(query string) (driver.Rows, error) { + columns := []string{"id", "name", "age"} + result := ` + 1,Tim,20 + 2,Joe,25 + 3,Bob,30 + ` + return testdb.RowsFromCSVString(columns, result), nil + }) + + var users []User + db.Find(&users) + if (users[0].Name != "Tim") || len(users) != 3 { + t.Errorf("Unexcepted result returned") + } +}