diff --git a/common_dialect.go b/common_dialect.go index e54df6ff..9acedf78 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -16,6 +16,10 @@ func (s *commonDialect) SupportLastInsertId() bool { return true } +func (s *commonDialect) HasTop() bool { + return false +} + func (d *commonDialect) SqlTag(value reflect.Value, size int) string { switch value.Kind() { case reflect.Bool: diff --git a/dialect.go b/dialect.go index 771a7ac4..64ffc7f0 100644 --- a/dialect.go +++ b/dialect.go @@ -11,6 +11,7 @@ var timeType = reflect.TypeOf(time.Time{}) type Dialect interface { BinVar(i int) string SupportLastInsertId() bool + HasTop() bool SqlTag(value reflect.Value, size int) string PrimaryKeyTag(value reflect.Value, size int) string ReturningStr(key string) string @@ -30,6 +31,8 @@ func NewDialect(driver string) Dialect { d = &mysql{} case "sqlite3": d = &sqlite3{} + case "mssql": + d = &mssql{} default: fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver) d = &commonDialect{} diff --git a/main.go b/main.go index af2c26c4..3a6eb5c1 100644 --- a/main.go +++ b/main.go @@ -51,7 +51,9 @@ func Open(dialect string, drivesources ...string) (DB, error) { source = drivesources[1] } - db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback, source: source, values: map[string]interface{}{}} + db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", + logger: defaultLogger, callback: DefaultCallback, source: source, + values: map[string]interface{}{}} db.db, err = sql.Open(driver, source) db.parent = &db } diff --git a/main_test.go b/main_test.go index 65cacf9e..7be913ed 100644 --- a/main_test.go +++ b/main_test.go @@ -6,6 +6,7 @@ import ( "fmt" "strconv" + _ "github.com/denisenkom/go-mssqldb" testdb "github.com/erikstmartin/go-testdb" _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" @@ -35,6 +36,9 @@ func init() { case "postgres": fmt.Println("testing postgres...") DB, err = gorm.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") + case "mssql": + fmt.Println("testing mssql...") + DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") default: fmt.Println("testing sqlite3...") DB, err = gorm.Open("sqlite3", "/tmp/gorm.db") @@ -445,6 +449,11 @@ func TestTimeWithZone(t *testing.T) { for index, vtime := range times { name := "time_with_zone_" + strconv.Itoa(index) user := User{Name: name, Birthday: vtime} + + // TODO mssql does not support time zones + if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { + user.Birthday = vtime.UTC() + } DB.Save(&user) if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { t.Errorf("User's birthday should not be changed after save") diff --git a/mssql.go b/mssql.go new file mode 100644 index 00000000..5fe45f0a --- /dev/null +++ b/mssql.go @@ -0,0 +1,114 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" +) + +type mssql struct{} + +func (s *mssql) BinVar(i int) string { + return "$$" // ? +} + +func (s *mssql) SupportLastInsertId() bool { + return true +} + +func (s *mssql) HasTop() bool { + return true +} + +func (d *mssql) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: + return "bit" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "int" + case reflect.Int64, reflect.Uint64: + return "bigint" + case reflect.Float32, reflect.Float64: + return "float" + case reflect.String: + if size > 0 && size < 65532 { + return fmt.Sprintf("nvarchar(%d)", size) + } else { + return "text" + } + case reflect.Struct: + if value.Type() == timeType { + return "datetime2" + } + default: + if _, ok := value.Interface().([]byte); ok { + if size > 0 && size < 65532 { + return fmt.Sprintf("varchar(%d)", size) + } else { + return "text" + } + } + } + panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) +} + +func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string { + suffix_str := " IDENTITY(1,1) PRIMARY KEY" + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "int" + suffix_str + case reflect.Int64, reflect.Uint64: + return "bigint" + suffix_str + default: + panic("Invalid primary key type") + } +} + +func (s *mssql) ReturningStr(key string) string { + return "" +} + +func (s *mssql) SelectFromDummyTable() string { + return "" +} + +func (s *mssql) Quote(key string) string { + return fmt.Sprintf(" \"%s\"", key) +} + +func (s *mssql) databaseName(scope *Scope) string { + dbStr := strings.Split(scope.db.parent.source, ";") + for _, value := range dbStr { + s := strings.Split(value, "=") + if s[0] == "database" { + return s[1] + } + } + return "" +} + +func (s *mssql) HasTable(scope *Scope, tableName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_catalog = %v", + newScope.AddToVars(tableName), + newScope.AddToVars(s.databaseName(scope)))) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 +} + +func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE TABLE_CATALOG = %v AND table_name = %v AND column_name = %v", + newScope.AddToVars(s.databaseName(scope)), + newScope.AddToVars(tableName), + newScope.AddToVars(columnName), + )) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 +} + +func (s *mssql) RemoveIndex(scope *Scope, indexName string) { + scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Exec() +} diff --git a/mysql.go b/mysql.go index 1332add7..9f4c525b 100644 --- a/mysql.go +++ b/mysql.go @@ -3,7 +3,6 @@ package gorm import ( "fmt" "strings" - "reflect" ) @@ -17,6 +16,10 @@ func (s *mysql) SupportLastInsertId() bool { return true } +func (s *mysql) HasTop() bool { + return false +} + func (d *mysql) SqlTag(value reflect.Value, size int) string { switch value.Kind() { case reflect.Bool: diff --git a/postgres.go b/postgres.go index aab5757a..d279b859 100644 --- a/postgres.go +++ b/postgres.go @@ -5,7 +5,6 @@ import ( "database/sql/driver" "fmt" "reflect" - "github.com/lib/pq/hstore" ) @@ -20,6 +19,10 @@ func (s *postgres) SupportLastInsertId() bool { return false } +func (s *postgres) HasTop() bool { + return false +} + func (d *postgres) SqlTag(value reflect.Value, size int) string { switch value.Kind() { case reflect.Bool: diff --git a/query_test.go b/query_test.go index b8a304e7..b7ab1f5d 100644 --- a/query_test.go +++ b/query_test.go @@ -2,9 +2,10 @@ package gorm_test import ( "fmt" - "github.com/jinzhu/now" "reflect" + "github.com/jinzhu/now" + "testing" "time" ) @@ -244,7 +245,7 @@ func TestOrderAndPluck(t *testing.T) { } var ages1, ages2 []int64 - scopedb.Order("age desc").Pluck("age", &ages1).Order("age").Pluck("age", &ages2) + scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) if !reflect.DeepEqual(ages1, ages2) { t.Errorf("The first order is the primary order") } @@ -258,7 +259,11 @@ func TestOrderAndPluck(t *testing.T) { var names []string var ages5 []int64 scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) - if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { + if names != nil && ages5 != nil { + if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { + t.Errorf("Order with multiple orders") + } + } else { t.Errorf("Order with multiple orders") } diff --git a/scope_private.go b/scope_private.go index 209db5bb..e97dfce0 100644 --- a/scope_private.go +++ b/scope_private.go @@ -196,10 +196,26 @@ func (s *Scope) orderSql() string { } func (s *Scope) limitSql() string { - if len(s.Search.Limit) == 0 { - return "" + if !s.Dialect().HasTop() { + if len(s.Search.Limit) == 0 { + return "" + } else { + return " LIMIT " + s.Search.Limit + } } else { - return " LIMIT " + s.Search.Limit + return "" + } +} + +func (s *Scope) topSql() string { + if s.Dialect().HasTop() && len(s.Search.Offset) == 0 { + if len(s.Search.Limit) == 0 { + return "" + } else { + return " TOP(" + s.Search.Limit + ")" + } + } else { + return "" } } @@ -207,7 +223,15 @@ func (s *Scope) offsetSql() string { if len(s.Search.Offset) == 0 { return "" } else { - return " OFFSET " + s.Search.Offset + if s.Dialect().HasTop() { + sql := " OFFSET " + s.Search.Offset + " ROW " + if len(s.Search.Limit) > 0 { + sql += "FETCH NEXT " + s.Search.Limit + " ROWS ONLY" + } + return sql + } else { + return " OFFSET " + s.Search.Offset + } } } @@ -235,7 +259,7 @@ func (scope *Scope) prepareQuerySql() { if scope.Search.Raw { scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE ")) } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) + scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) } return } diff --git a/sqlite3.go b/sqlite3.go index 088da6b1..e1e5d2d9 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -15,6 +15,10 @@ func (s *sqlite3) SupportLastInsertId() bool { return true } +func (s *sqlite3) HasTop() bool { + return false +} + func (s *sqlite3) SqlTag(value reflect.Value, size int) string { switch value.Kind() { case reflect.Bool: diff --git a/structs_test.go b/structs_test.go index 5ddc3b50..023adcd0 100644 --- a/structs_test.go +++ b/structs_test.go @@ -94,8 +94,13 @@ type Role struct { Name string } + func (role *Role) Scan(value interface{}) error { - role.Name = string(value.([]uint8)) + if b, ok := value.([]uint8); ok { + role.Name = string(b) + } else { + role.Name = value.(string) + } return nil }