Merge branch 'thecaddy-master'

This commit is contained in:
Jinzhu 2014-09-19 21:51:06 +08:00
commit 81874feef9
11 changed files with 188 additions and 12 deletions

View File

@ -16,6 +16,10 @@ func (s *commonDialect) SupportLastInsertId() bool {
return true return true
} }
func (s *commonDialect) HasTop() bool {
return false
}
func (d *commonDialect) SqlTag(value reflect.Value, size int) string { func (d *commonDialect) SqlTag(value reflect.Value, size int) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:

View File

@ -11,6 +11,7 @@ var timeType = reflect.TypeOf(time.Time{})
type Dialect interface { type Dialect interface {
BinVar(i int) string BinVar(i int) string
SupportLastInsertId() bool SupportLastInsertId() bool
HasTop() bool
SqlTag(value reflect.Value, size int) string SqlTag(value reflect.Value, size int) string
PrimaryKeyTag(value reflect.Value, size int) string PrimaryKeyTag(value reflect.Value, size int) string
ReturningStr(key string) string ReturningStr(key string) string
@ -30,6 +31,8 @@ func NewDialect(driver string) Dialect {
d = &mysql{} d = &mysql{}
case "sqlite3": case "sqlite3":
d = &sqlite3{} d = &sqlite3{}
case "mssql":
d = &mssql{}
default: default:
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver) fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
d = &commonDialect{} d = &commonDialect{}

View File

@ -51,7 +51,9 @@ func Open(dialect string, drivesources ...string) (DB, error) {
source = drivesources[1] source = drivesources[1]
} }
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback, source: source, values: map[string]interface{}{}} db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
logger: defaultLogger, callback: DefaultCallback, source: source,
values: map[string]interface{}{}}
db.db, err = sql.Open(driver, source) db.db, err = sql.Open(driver, source)
db.parent = &db db.parent = &db
} }

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
_ "github.com/denisenkom/go-mssqldb"
testdb "github.com/erikstmartin/go-testdb" testdb "github.com/erikstmartin/go-testdb"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -35,6 +36,9 @@ func init() {
case "postgres": case "postgres":
fmt.Println("testing postgres...") fmt.Println("testing postgres...")
DB, err = gorm.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") DB, err = gorm.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable")
case "mssql":
fmt.Println("testing mssql...")
DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433")
default: default:
fmt.Println("testing sqlite3...") fmt.Println("testing sqlite3...")
DB, err = gorm.Open("sqlite3", "/tmp/gorm.db") DB, err = gorm.Open("sqlite3", "/tmp/gorm.db")
@ -445,6 +449,11 @@ func TestTimeWithZone(t *testing.T) {
for index, vtime := range times { for index, vtime := range times {
name := "time_with_zone_" + strconv.Itoa(index) name := "time_with_zone_" + strconv.Itoa(index)
user := User{Name: name, Birthday: vtime} user := User{Name: name, Birthday: vtime}
// TODO mssql does not support time zones
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" {
user.Birthday = vtime.UTC()
}
DB.Save(&user) DB.Save(&user)
if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
t.Errorf("User's birthday should not be changed after save") t.Errorf("User's birthday should not be changed after save")

114
mssql.go Normal file
View File

@ -0,0 +1,114 @@
package gorm
import (
"fmt"
"reflect"
"strings"
)
type mssql struct{}
func (s *mssql) BinVar(i int) string {
return "$$" // ?
}
func (s *mssql) SupportLastInsertId() bool {
return true
}
func (s *mssql) HasTop() bool {
return true
}
func (d *mssql) SqlTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Bool:
return "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int"
case reflect.Int64, reflect.Uint64:
return "bigint"
case reflect.Float32, reflect.Float64:
return "float"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("nvarchar(%d)", size)
} else {
return "text"
}
case reflect.Struct:
if value.Type() == timeType {
return "datetime2"
}
default:
if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
} else {
return "text"
}
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
}
func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string {
suffix_str := " IDENTITY(1,1) PRIMARY KEY"
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" + suffix_str
case reflect.Int64, reflect.Uint64:
return "bigint" + suffix_str
default:
panic("Invalid primary key type")
}
}
func (s *mssql) ReturningStr(key string) string {
return ""
}
func (s *mssql) SelectFromDummyTable() string {
return ""
}
func (s *mssql) Quote(key string) string {
return fmt.Sprintf(" \"%s\"", key)
}
func (s *mssql) databaseName(scope *Scope) string {
dbStr := strings.Split(scope.db.parent.source, ";")
for _, value := range dbStr {
s := strings.Split(value, "=")
if s[0] == "database" {
return s[1]
}
}
return ""
}
func (s *mssql) HasTable(scope *Scope, tableName string) bool {
var count int
newScope := scope.New(nil)
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_catalog = %v",
newScope.AddToVars(tableName),
newScope.AddToVars(s.databaseName(scope))))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}
func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
newScope := scope.New(nil)
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE TABLE_CATALOG = %v AND table_name = %v AND column_name = %v",
newScope.AddToVars(s.databaseName(scope)),
newScope.AddToVars(tableName),
newScope.AddToVars(columnName),
))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}
func (s *mssql) RemoveIndex(scope *Scope, indexName string) {
scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Exec()
}

View File

@ -3,7 +3,6 @@ package gorm
import ( import (
"fmt" "fmt"
"strings" "strings"
"reflect" "reflect"
) )
@ -17,6 +16,10 @@ func (s *mysql) SupportLastInsertId() bool {
return true return true
} }
func (s *mysql) HasTop() bool {
return false
}
func (d *mysql) SqlTag(value reflect.Value, size int) string { func (d *mysql) SqlTag(value reflect.Value, size int) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:

View File

@ -5,7 +5,6 @@ import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"github.com/lib/pq/hstore" "github.com/lib/pq/hstore"
) )
@ -20,6 +19,10 @@ func (s *postgres) SupportLastInsertId() bool {
return false return false
} }
func (s *postgres) HasTop() bool {
return false
}
func (d *postgres) SqlTag(value reflect.Value, size int) string { func (d *postgres) SqlTag(value reflect.Value, size int) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:

View File

@ -2,9 +2,10 @@ package gorm_test
import ( import (
"fmt" "fmt"
"github.com/jinzhu/now"
"reflect" "reflect"
"github.com/jinzhu/now"
"testing" "testing"
"time" "time"
) )
@ -244,7 +245,7 @@ func TestOrderAndPluck(t *testing.T) {
} }
var ages1, ages2 []int64 var ages1, ages2 []int64
scopedb.Order("age desc").Pluck("age", &ages1).Order("age").Pluck("age", &ages2) scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2)
if !reflect.DeepEqual(ages1, ages2) { if !reflect.DeepEqual(ages1, ages2) {
t.Errorf("The first order is the primary order") t.Errorf("The first order is the primary order")
} }
@ -258,7 +259,11 @@ func TestOrderAndPluck(t *testing.T) {
var names []string var names []string
var ages5 []int64 var ages5 []int64
scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names)
if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { if names != nil && ages5 != nil {
if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) {
t.Errorf("Order with multiple orders")
}
} else {
t.Errorf("Order with multiple orders") t.Errorf("Order with multiple orders")
} }

View File

@ -196,10 +196,26 @@ func (s *Scope) orderSql() string {
} }
func (s *Scope) limitSql() string { func (s *Scope) limitSql() string {
if len(s.Search.Limit) == 0 { if !s.Dialect().HasTop() {
return "" if len(s.Search.Limit) == 0 {
return ""
} else {
return " LIMIT " + s.Search.Limit
}
} else { } else {
return " LIMIT " + s.Search.Limit return ""
}
}
func (s *Scope) topSql() string {
if s.Dialect().HasTop() && len(s.Search.Offset) == 0 {
if len(s.Search.Limit) == 0 {
return ""
} else {
return " TOP(" + s.Search.Limit + ")"
}
} else {
return ""
} }
} }
@ -207,7 +223,15 @@ func (s *Scope) offsetSql() string {
if len(s.Search.Offset) == 0 { if len(s.Search.Offset) == 0 {
return "" return ""
} else { } else {
return " OFFSET " + s.Search.Offset if s.Dialect().HasTop() {
sql := " OFFSET " + s.Search.Offset + " ROW "
if len(s.Search.Limit) > 0 {
sql += "FETCH NEXT " + s.Search.Limit + " ROWS ONLY"
}
return sql
} else {
return " OFFSET " + s.Search.Offset
}
} }
} }
@ -235,7 +259,7 @@ func (scope *Scope) prepareQuerySql() {
if scope.Search.Raw { if scope.Search.Raw {
scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE ")) scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE "))
} else { } else {
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
} }
return return
} }

View File

@ -15,6 +15,10 @@ func (s *sqlite3) SupportLastInsertId() bool {
return true return true
} }
func (s *sqlite3) HasTop() bool {
return false
}
func (s *sqlite3) SqlTag(value reflect.Value, size int) string { func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:

View File

@ -94,8 +94,13 @@ type Role struct {
Name string Name string
} }
func (role *Role) Scan(value interface{}) error { func (role *Role) Scan(value interface{}) error {
role.Name = string(value.([]uint8)) if b, ok := value.([]uint8); ok {
role.Name = string(b)
} else {
role.Name = value.(string)
}
return nil return nil
} }