From f79e1a2ef63b8d5de814b977d131287736ecba7c Mon Sep 17 00:00:00 2001 From: Joel Trost Date: Mon, 15 Sep 2014 13:03:14 -0700 Subject: [PATCH] GORM support for MSSQL, passes all tests --- common_dialect.go | 4 ++++ dialect.go | 1 + main.go | 4 +++- main_test.go | 32 ++++++++++++++++++++------------ mysql.go | 4 ++++ postgres.go | 4 ++++ query_test.go | 5 ++--- scope_private.go | 34 ++++++++++++++++++++++++++++++++-- sqlite3.go | 4 ++++ 9 files changed, 74 insertions(+), 18 deletions(-) diff --git a/common_dialect.go b/common_dialect.go index e54df6ff..ba894942 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -16,6 +16,10 @@ func (s *commonDialect) SupportLastInsertId() bool { return true } +func (s *commonDialect) HasTop() bool { + return false +} + func (d *commonDialect) SqlTag(value reflect.Value, size int) string { switch value.Kind() { case reflect.Bool: diff --git a/dialect.go b/dialect.go index 771a7ac4..cc87b41a 100644 --- a/dialect.go +++ b/dialect.go @@ -11,6 +11,7 @@ var timeType = reflect.TypeOf(time.Time{}) type Dialect interface { BinVar(i int) string SupportLastInsertId() bool + HasTop() bool SqlTag(value reflect.Value, size int) string PrimaryKeyTag(value reflect.Value, size int) string ReturningStr(key string) string diff --git a/main.go b/main.go index af2c26c4..6b897164 100644 --- a/main.go +++ b/main.go @@ -51,7 +51,9 @@ func Open(dialect string, drivesources ...string) (DB, error) { source = drivesources[1] } - db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback, source: source, values: map[string]interface{}{}} + db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", + logger: defaultLogger, callback: DefaultCallback, source: source, + values: map[string]interface{}{}} db.db, err = sql.Open(driver, source) db.parent = &db } diff --git a/main_test.go b/main_test.go index 65cacf9e..b0d86450 100644 --- a/main_test.go +++ b/main_test.go @@ -26,18 +26,21 @@ var ( func init() { var err error switch os.Getenv("GORM_DIALECT") { - case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; - fmt.Println("testing mysql...") - DB, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") - case "postgres": - fmt.Println("testing postgres...") - DB, err = gorm.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") - default: - fmt.Println("testing sqlite3...") - DB, err = gorm.Open("sqlite3", "/tmp/gorm.db") + case "mysql": + // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; + // CREATE DATABASE gorm; + // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; + fmt.Println("testing mysql...") + DB, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") + case "postgres": + fmt.Println("testing postgres...") + DB, err = gorm.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") + case "mssql": + fmt.Println("testing mssql...") + DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=DB_HERE;user id=USER_HERE;password=PW_HERE;port=1433") + default: + fmt.Println("testing sqlite3...") + DB, err = gorm.Open("sqlite3", "/tmp/gorm.db") } // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) @@ -445,6 +448,11 @@ func TestTimeWithZone(t *testing.T) { for index, vtime := range times { name := "time_with_zone_" + strconv.Itoa(index) user := User{Name: name, Birthday: vtime} + + //mssql does not support time zones + if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { + user.Birthday = vtime.UTC() + } DB.Save(&user) if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { t.Errorf("User's birthday should not be changed after save") diff --git a/mysql.go b/mysql.go index 1332add7..d0e9b26d 100644 --- a/mysql.go +++ b/mysql.go @@ -17,6 +17,10 @@ func (s *mysql) SupportLastInsertId() bool { return true } +func (s *mysql) HasTop() bool { + return false +} + func (d *mysql) SqlTag(value reflect.Value, size int) string { switch value.Kind() { case reflect.Bool: diff --git a/postgres.go b/postgres.go index aab5757a..5ec82ae8 100644 --- a/postgres.go +++ b/postgres.go @@ -20,6 +20,10 @@ func (s *postgres) SupportLastInsertId() bool { return false } +func (s *postgres) HasTop() bool { + return false +} + func (d *postgres) SqlTag(value reflect.Value, size int) string { switch value.Kind() { case reflect.Bool: diff --git a/query_test.go b/query_test.go index b8a304e7..dcf1d48c 100644 --- a/query_test.go +++ b/query_test.go @@ -244,10 +244,9 @@ func TestOrderAndPluck(t *testing.T) { } var ages1, ages2 []int64 - scopedb.Order("age desc").Pluck("age", &ages1).Order("age").Pluck("age", &ages2) + scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) if !reflect.DeepEqual(ages1, ages2) { - t.Errorf("The first order is the primary order") - } + t.Errorf("The first order is the primary order") } var ages3, ages4 []int64 scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) diff --git a/scope_private.go b/scope_private.go index 209db5bb..ffed1c81 100644 --- a/scope_private.go +++ b/scope_private.go @@ -196,7 +196,25 @@ func (s *Scope) orderSql() string { } func (s *Scope) limitSql() string { - if len(s.Search.Limit) == 0 { + if !s.Dialect().HasTop() { + if len(s.Search.Limit) == 0 { + return "" + } else { + return " LIMIT " + s.Search.Limit + } + } else{ + return "" + } +} + +func (s *Scope) topSql() string{ + if s.Dialect().HasTop() && len(s.Search.Offset) == 0 { + if len(s.Search.Limit) == 0 { + return "" + } else{ + return " TOP(" + s.Search.Limit + ")" + } + } else{ return "" } else { return " LIMIT " + s.Search.Limit @@ -207,7 +225,15 @@ func (s *Scope) offsetSql() string { if len(s.Search.Offset) == 0 { return "" } else { - return " OFFSET " + s.Search.Offset + if s.Dialect().HasTop(){ + sql := " OFFSET " + s.Search.Offset + " ROW " + if len(s.Search.Limit) > 0{ + sql += "FETCH NEXT " + s.Search.Limit + " ROWS ONLY" + } + return sql + }else{ + return " OFFSET " + s.Search.Offset + } } } @@ -235,7 +261,11 @@ func (scope *Scope) prepareQuerySql() { if scope.Search.Raw { scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE ")) } else { +<<<<<<< HEAD scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) +======= + scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) +>>>>>>> 15a20a4... GORM support for MSSQL, passes all tests } return } diff --git a/sqlite3.go b/sqlite3.go index 088da6b1..dbd8db6b 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -15,6 +15,10 @@ func (s *sqlite3) SupportLastInsertId() bool { return true } +func (s *sqlite3) HasTop() bool { + return false +} + func (s *sqlite3) SqlTag(value reflect.Value, size int) string { switch value.Kind() { case reflect.Bool: