From dcc06e22f7b54f4c38e347b25ba827845f92ea08 Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Tue, 24 Mar 2015 10:33:51 -0700 Subject: [PATCH] FoundationDB dialect layer and compatibility updates. --- README.md | 1 + dialect.go | 2 ++ foundation.go | 78 +++++++++++++++++++++++++++++++++++++++++++ join_table_handler.go | 2 +- main.go | 3 ++ main_test.go | 31 ++++++++++++----- test_all.sh | 2 +- 7 files changed, 109 insertions(+), 10 deletions(-) create mode 100644 foundation.go diff --git a/README.md b/README.md index a6dd1865..50c4c5fe 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,7 @@ import ( ) db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") +// db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB. // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") // db, err := gorm.Open("sqlite3", "/tmp/gorm.db") diff --git a/dialect.go b/dialect.go index 2e64cca5..f3221075 100644 --- a/dialect.go +++ b/dialect.go @@ -24,6 +24,8 @@ func NewDialect(driver string) Dialect { switch driver { case "postgres": d = &postgres{} + case "foundation": + d = &foundation{} case "mysql": d = &mysql{} case "sqlite3": diff --git a/foundation.go b/foundation.go new file mode 100644 index 00000000..a9c8f500 --- /dev/null +++ b/foundation.go @@ -0,0 +1,78 @@ +package gorm + +import ( + "fmt" + "reflect" + "time" +) + +type foundation struct { + commonDialect +} + +func (foundation) BinVar(i int) string { + return fmt.Sprintf("$%v", i) +} + +func (foundation) SupportLastInsertId() bool { + return false +} + +func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string { + switch value.Kind() { + case reflect.Bool: + return "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "serial" + } + return "int" + case reflect.Int64, reflect.Uint64: + if autoIncrease { + return "bigserial" + } + return "bigint" + case reflect.Float32, reflect.Float64: + return "double" + case reflect.String: + if size > 0 && size < 65532 { + return fmt.Sprintf("varchar(%d)", size) + } + return "clob" + case reflect.Struct: + if _, ok := value.Interface().(time.Time); ok { + return "datetime" + } + default: + if _, ok := value.Interface().([]byte); ok { + return "blob" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String())) +} + +func (f foundation) ReturningStr(tableName, key string) string { + return fmt.Sprintf("RETURNING %v.%v", f.Quote(tableName), key) +} + +func (foundation) HasTable(scope *Scope, tableName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName).Row().Scan(&count) + return count > 0 +} + +func (foundation) HasColumn(scope *Scope, tableName string, columnName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) + return count > 0 +} + +func (f foundation) RemoveIndex(scope *Scope, indexName string) { + scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", f.Quote(indexName))) +} + +func (foundation) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) + return count > 0 +} diff --git a/join_table_handler.go b/join_table_handler.go index b4299f5a..9f705564 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -104,7 +104,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) quotedTable := s.Table(db) sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);", + "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", quotedTable, strings.Join(assignColumns, ","), strings.Join(binVars, ","), diff --git a/main.go b/main.go index e197a99c..82567971 100644 --- a/main.go +++ b/main.go @@ -55,6 +55,9 @@ func Open(dialect string, args ...interface{}) (DB, error) { driver = value source = args[1].(string) } + if driver == "foundation" { + driver = "postgres" // FoundationDB speaks a postgres-compatible protocol. + } dbSql, err = sql.Open(driver, source) case sqlCommon: source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() diff --git a/main_test.go b/main_test.go index 3864fcbd..b547534c 100644 --- a/main_test.go +++ b/main_test.go @@ -36,6 +36,9 @@ func init() { case "postgres": fmt.Println("testing postgres...") DB, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") + case "foundation": + fmt.Println("testing foundation...") + DB, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") case "mssql": fmt.Println("testing mssql...") DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") @@ -445,6 +448,14 @@ func TestHaving(t *testing.T) { } } +func DialectHasTzSupport() bool { + // NB: mssql and FoundationDB do not support time zones. + if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { + return false + } + return true +} + func TestTimeWithZone(t *testing.T) { var format = "2006-01-02 15:04:05 -0700" var times []time.Time @@ -456,26 +467,30 @@ func TestTimeWithZone(t *testing.T) { name := "time_with_zone_" + strconv.Itoa(index) user := User{Name: name, Birthday: vtime} - // TODO mssql does not support time zones - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { + if !DialectHasTzSupport() { + // If our driver dialect doesn't support TZ's, just use UTC for everything here. user.Birthday = vtime.UTC() } + DB.Save(&user) - if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { - t.Errorf("User's birthday should not be changed after save") + expectedBirthday := "2013-02-18 17:51:49 +0000" + foundBirthday := user.Birthday.UTC().Format(format) + if foundBirthday != expectedBirthday { + t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) } var findUser, findUser2, findUser3 User DB.First(&findUser, "name = ?", name) - if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { - t.Errorf("User's birthday should not be changed after find") + foundBirthday = findUser.Birthday.UTC().Format(format) + if foundBirthday != expectedBirthday { + t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday) } - if DB.Where("id = ? AND birthday >= ?", findUser.Id, vtime.Add(-time.Minute)).First(&findUser2).RecordNotFound() { + if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { t.Errorf("User should be found") } - if !DB.Where("id = ? AND birthday >= ?", findUser.Id, vtime.Add(time.Minute)).First(&findUser3).RecordNotFound() { + if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() { t.Errorf("User should not be found") } } diff --git a/test_all.sh b/test_all.sh index 6c5593b3..bd28294d 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "mysql" "sqlite") +dialects=("postgres" "foundation" "mysql" "sqlite") for dialect in "${dialects[@]}" ; do GORM_DIALECT=${dialect} go test