diff --git a/common_dialect.go b/common_dialect.go index ba894942..9acedf78 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -17,7 +17,7 @@ func (s *commonDialect) SupportLastInsertId() bool { } func (s *commonDialect) HasTop() bool { - return false + return false } func (d *commonDialect) SqlTag(value reflect.Value, size int) string { diff --git a/main.go b/main.go index 6b897164..bc934c9d 100644 --- a/main.go +++ b/main.go @@ -51,8 +51,8 @@ func Open(dialect string, drivesources ...string) (DB, error) { source = drivesources[1] } - db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", - logger: defaultLogger, callback: DefaultCallback, source: source, + db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", + logger: defaultLogger, callback: DefaultCallback, source: source, values: map[string]interface{}{}} db.db, err = sql.Open(driver, source) db.parent = &db @@ -119,6 +119,7 @@ func (s *DB) Limit(value interface{}) *DB { func (s *DB) Offset(value interface{}) *DB { return s.clone().search.offset(value).db + return s.clone().search.offset(value).db } func (s *DB) Order(value string, reorder ...bool) *DB { diff --git a/main_test.go b/main_test.go index b0d86450..4346452b 100644 --- a/main_test.go +++ b/main_test.go @@ -6,6 +6,7 @@ import ( "fmt" "strconv" + _ "github.com/denisenkom/go-mssqldb" testdb "github.com/erikstmartin/go-testdb" _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" @@ -26,21 +27,21 @@ var ( func init() { var err error switch os.Getenv("GORM_DIALECT") { - case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; - fmt.Println("testing mysql...") - DB, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") - case "postgres": - fmt.Println("testing postgres...") - DB, err = gorm.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") - case "mssql": - fmt.Println("testing mssql...") - DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=DB_HERE;user id=USER_HERE;password=PW_HERE;port=1433") - default: - fmt.Println("testing sqlite3...") - DB, err = gorm.Open("sqlite3", "/tmp/gorm.db") + case "mysql": + // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; + // CREATE DATABASE gorm; + // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; + fmt.Println("testing mysql...") + DB, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") + case "postgres": + fmt.Println("testing postgres...") + DB, err = gorm.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") + case "mssql": + fmt.Println("testing mssql...") + DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=DB_HERE;user id=USER_HERE;password=PW_HERE;port=1433") + default: + fmt.Println("testing sqlite3...") + DB, err = gorm.Open("sqlite3", "/tmp/gorm.db") } // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) @@ -310,11 +311,13 @@ func TestRows(t *testing.T) { } count := 0 - for rows.Next() { - var name string - var age int64 - rows.Scan(&name, &age) - count++ + if rows != nil { + for rows.Next() { + var name string + var age int64 + rows.Scan(&name, &age) + count++ + } } if count != 2 { t.Errorf("Should found two records with name 3") diff --git a/mssql.go b/mssql.go new file mode 100644 index 00000000..5fe45f0a --- /dev/null +++ b/mssql.go @@ -0,0 +1,114 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" +) + +type mssql struct{} + +func (s *mssql) BinVar(i int) string { + return "$$" // ? +} + +func (s *mssql) SupportLastInsertId() bool { + return true +} + +func (s *mssql) HasTop() bool { + return true +} + +func (d *mssql) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: + return "bit" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "int" + case reflect.Int64, reflect.Uint64: + return "bigint" + case reflect.Float32, reflect.Float64: + return "float" + case reflect.String: + if size > 0 && size < 65532 { + return fmt.Sprintf("nvarchar(%d)", size) + } else { + return "text" + } + case reflect.Struct: + if value.Type() == timeType { + return "datetime2" + } + default: + if _, ok := value.Interface().([]byte); ok { + if size > 0 && size < 65532 { + return fmt.Sprintf("varchar(%d)", size) + } else { + return "text" + } + } + } + panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) +} + +func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string { + suffix_str := " IDENTITY(1,1) PRIMARY KEY" + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "int" + suffix_str + case reflect.Int64, reflect.Uint64: + return "bigint" + suffix_str + default: + panic("Invalid primary key type") + } +} + +func (s *mssql) ReturningStr(key string) string { + return "" +} + +func (s *mssql) SelectFromDummyTable() string { + return "" +} + +func (s *mssql) Quote(key string) string { + return fmt.Sprintf(" \"%s\"", key) +} + +func (s *mssql) databaseName(scope *Scope) string { + dbStr := strings.Split(scope.db.parent.source, ";") + for _, value := range dbStr { + s := strings.Split(value, "=") + if s[0] == "database" { + return s[1] + } + } + return "" +} + +func (s *mssql) HasTable(scope *Scope, tableName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_catalog = %v", + newScope.AddToVars(tableName), + newScope.AddToVars(s.databaseName(scope)))) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 +} + +func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE TABLE_CATALOG = %v AND table_name = %v AND column_name = %v", + newScope.AddToVars(s.databaseName(scope)), + newScope.AddToVars(tableName), + newScope.AddToVars(columnName), + )) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 +} + +func (s *mssql) RemoveIndex(scope *Scope, indexName string) { + scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Exec() +} diff --git a/mysql.go b/mysql.go index d0e9b26d..3263180a 100644 --- a/mysql.go +++ b/mysql.go @@ -3,8 +3,8 @@ package gorm import ( "fmt" "strings" - "reflect" + "strings" ) type mysql struct{} @@ -18,7 +18,7 @@ func (s *mysql) SupportLastInsertId() bool { } func (s *mysql) HasTop() bool { - return false + return false } func (d *mysql) SqlTag(value reflect.Value, size int) string { diff --git a/postgres.go b/postgres.go index 5ec82ae8..78161da5 100644 --- a/postgres.go +++ b/postgres.go @@ -5,8 +5,8 @@ import ( "database/sql/driver" "fmt" "reflect" - "github.com/lib/pq/hstore" + "reflect" ) type postgres struct { @@ -21,7 +21,7 @@ func (s *postgres) SupportLastInsertId() bool { } func (s *postgres) HasTop() bool { - return false + return false } func (d *postgres) SqlTag(value reflect.Value, size int) string { diff --git a/query_test.go b/query_test.go index dcf1d48c..5d54ccac 100644 --- a/query_test.go +++ b/query_test.go @@ -239,14 +239,19 @@ func TestOrderAndPluck(t *testing.T) { var ages []int64 scopedb.Order("age desc").Pluck("age", &ages) - if ages[0] != 20 { + if ages != nil { + if ages[0] != 20 { + t.Errorf("The first age should be 20 when order with age desc") + } + } else { t.Errorf("The first age should be 20 when order with age desc") } var ages1, ages2 []int64 scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) if !reflect.DeepEqual(ages1, ages2) { - t.Errorf("The first order is the primary order") } + t.Errorf("The first order is the primary order") + } var ages3, ages4 []int64 scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) @@ -257,7 +262,11 @@ func TestOrderAndPluck(t *testing.T) { var names []string var ages5 []int64 scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) - if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { + if names != nil && ages5 != nil { + if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { + t.Errorf("Order with multiple orders") + } + } else { t.Errorf("Order with multiple orders") } diff --git a/scope_private.go b/scope_private.go index ffed1c81..adea5f26 100644 --- a/scope_private.go +++ b/scope_private.go @@ -202,19 +202,19 @@ func (s *Scope) limitSql() string { } else { return " LIMIT " + s.Search.Limit } - } else{ + } else { return "" } } -func (s *Scope) topSql() string{ +func (s *Scope) topSql() string { if s.Dialect().HasTop() && len(s.Search.Offset) == 0 { if len(s.Search.Limit) == 0 { return "" - } else{ + } else { return " TOP(" + s.Search.Limit + ")" } - } else{ + } else { return "" } else { return " LIMIT " + s.Search.Limit @@ -225,13 +225,13 @@ func (s *Scope) offsetSql() string { if len(s.Search.Offset) == 0 { return "" } else { - if s.Dialect().HasTop(){ - sql := " OFFSET " + s.Search.Offset + " ROW " - if len(s.Search.Limit) > 0{ + if s.Dialect().HasTop() { + sql := " OFFSET " + s.Search.Offset + " ROW " + if len(s.Search.Limit) > 0 { sql += "FETCH NEXT " + s.Search.Limit + " ROWS ONLY" } return sql - }else{ + } else { return " OFFSET " + s.Search.Offset } } diff --git a/sqlite3.go b/sqlite3.go index dbd8db6b..e1e5d2d9 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -16,7 +16,7 @@ func (s *sqlite3) SupportLastInsertId() bool { } func (s *sqlite3) HasTop() bool { - return false + return false } func (s *sqlite3) SqlTag(value reflect.Value, size int) string { diff --git a/structs_test.go b/structs_test.go index 5ddc3b50..023adcd0 100644 --- a/structs_test.go +++ b/structs_test.go @@ -94,8 +94,13 @@ type Role struct { Name string } + func (role *Role) Scan(value interface{}) error { - role.Name = string(value.([]uint8)) + if b, ok := value.([]uint8); ok { + role.Name = string(b) + } else { + role.Name = value.(string) + } return nil }