Fix tests with mssql

This commit is contained in:
Jinzhu 2017-03-22 22:57:13 +08:00
parent 403487d5dd
commit 5730b92954
9 changed files with 73 additions and 34 deletions

View File

@ -58,12 +58,20 @@ func TestCreate(t *testing.T) {
} }
} }
type AutoIncrementUser struct {
User
Sequence uint `gorm:"AUTO_INCREMENT"`
}
func TestCreateWithAutoIncrement(t *testing.T) { func TestCreateWithAutoIncrement(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column") t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column")
} }
user1 := User{}
user2 := User{} DB.AutoMigrate(&AutoIncrementUser{})
user1 := AutoIncrementUser{}
user2 := AutoIncrementUser{}
DB.Create(&user1) DB.Create(&user1)
DB.Create(&user2) DB.Create(&user2)
@ -126,7 +134,7 @@ func TestAnonymousScanner(t *testing.T) {
t.Errorf("Should be able to get anonymous scanner") t.Errorf("Should be able to get anonymous scanner")
} }
if !user2.IsAdmin() { if !user2.Role.IsAdmin() {
t.Errorf("Should be able to get anonymous scanner") t.Errorf("Should be able to get anonymous scanner")
} }
} }

View File

@ -11,7 +11,28 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
func setIdentityInsert(scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" {
for _, field := range scope.PrimaryFields() {
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
scope.InstanceSet("mssql:identity_insert_on", true)
}
}
}
}
func turnOffIdentityInsert(scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" {
if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
}
}
}
func init() { func init() {
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert)
gorm.RegisterDialect("mssql", &mssql{}) gorm.RegisterDialect("mssql", &mssql{})
} }
@ -45,12 +66,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
sqlType = "bit" sqlType = "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int IDENTITY(1,1)" sqlType = "int IDENTITY(1,1)"
} else { } else {
sqlType = "int" sqlType = "int"
} }
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint IDENTITY(1,1)" sqlType = "bigint IDENTITY(1,1)"
} else { } else {
sqlType = "bigint" sqlType = "bigint"

View File

@ -2,6 +2,7 @@ package gorm_test
import ( import (
"fmt" "fmt"
"strconv"
"testing" "testing"
"time" "time"
@ -23,14 +24,23 @@ type PersonAddress struct {
} }
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
return db.Where(map[string]interface{}{ foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue()))
"person_id": db.NewScope(foreignValue).PrimaryKeyValue(), associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue()))
"address_id": db.NewScope(associationValue).PrimaryKeyValue(), if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{
}).Assign(map[string]interface{}{ "person_id": foreignPrimaryKey,
"person_id": foreignValue, "address_id": associationPrimaryKey,
"address_id": associationValue, }).Update(map[string]interface{}{
"person_id": foreignPrimaryKey,
"address_id": associationPrimaryKey,
"deleted_at": gorm.Expr("NULL"), "deleted_at": gorm.Expr("NULL"),
}).FirstOrCreate(&PersonAddress{}).Error }).RowsAffected; result == 0 {
return db.Create(&PersonAddress{
PersonID: foreignPrimaryKey,
AddressID: associationPrimaryKey,
}).Error
}
return nil
} }
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {

View File

@ -821,11 +821,11 @@ func BenchmarkGorm(b *testing.B) {
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {
e := strconv.Itoa(x) + "benchmark@example.org" e := strconv.Itoa(x) + "benchmark@example.org"
now := time.Now() now := time.Now()
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: &now} email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now}
// Insert // Insert
DB.Save(&email) DB.Save(&email)
// Query // Query
DB.First(&BigEmail{}, "email = ?", e) DB.First(&EmailWithIdx{}, "email = ?", e)
// Update // Update
DB.Model(&email).UpdateColumn("email", "new-"+e) DB.Model(&email).UpdateColumn("email", "new-"+e)
// Delete // Delete
@ -846,7 +846,7 @@ func BenchmarkRawSql(b *testing.B) {
var id int64 var id int64
e := strconv.Itoa(x) + "benchmark@example.org" e := strconv.Itoa(x) + "benchmark@example.org"
now := time.Now() now := time.Now()
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: &now} email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now}
// Insert // Insert
DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
// Query // Query
@ -860,6 +860,6 @@ func BenchmarkRawSql(b *testing.B) {
} }
func parseTime(str string) *time.Time { func parseTime(str string) *time.Time {
t := now.MustParse(str) t := now.New(time.Now().UTC()).MustParse(str)
return &t return &t
} }

View File

@ -31,9 +31,8 @@ type User struct {
Languages []Language `gorm:"many2many:user_languages;"` Languages []Language `gorm:"many2many:user_languages;"`
CompanyID *int CompanyID *int
Company Company Company Company
Role Role Role
PasswordHash []byte PasswordHash []byte
Sequence uint `gorm:"AUTO_INCREMENT"`
IgnoreMe int64 `sql:"-"` IgnoreMe int64 `sql:"-"`
IgnoreStringSlice []string `sql:"-"` IgnoreStringSlice []string `sql:"-"`
Ignored struct{ Name string } `sql:"-"` Ignored struct{ Name string } `sql:"-"`
@ -333,7 +332,7 @@ func TestIndexes(t *testing.T) {
} }
} }
type BigEmail struct { type EmailWithIdx struct {
Id int64 Id int64
UserId int64 UserId int64
Email string `sql:"index:idx_email_agent"` Email string `sql:"index:idx_email_agent"`
@ -343,29 +342,26 @@ type BigEmail struct {
UpdatedAt time.Time UpdatedAt time.Time
} }
func (b BigEmail) TableName() string {
return "emails"
}
func TestAutoMigration(t *testing.T) { func TestAutoMigration(t *testing.T) {
DB.AutoMigrate(&Address{}) DB.AutoMigrate(&Address{})
if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil { DB.DropTable(&EmailWithIdx{})
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
t.Errorf("Auto Migrate should not raise any error") t.Errorf("Auto Migrate should not raise any error")
} }
now := time.Now() now := time.Now()
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
scope := DB.NewScope(&BigEmail{}) scope := DB.NewScope(&EmailWithIdx{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") { if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
var bigemail BigEmail var bigemail EmailWithIdx
DB.First(&bigemail, "user_agent = ?", "pc") DB.First(&bigemail, "user_agent = ?", "pc")
if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
t.Error("Big Emails should be saved and fetched correctly") t.Error("Big Emails should be saved and fetched correctly")
@ -386,7 +382,7 @@ func TestMultipleIndexes(t *testing.T) {
} }
DB.AutoMigrate(&MultipleIndexes{}) DB.AutoMigrate(&MultipleIndexes{})
if err := DB.AutoMigrate(&BigEmail{}).Error; err != nil { if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
t.Errorf("Auto Migrate should not raise any error") t.Errorf("Auto Migrate should not raise any error")
} }

View File

@ -35,7 +35,7 @@ func compareTags(tags []Tag, contents []string) bool {
} }
func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" {
DB.DropTable(&Blog{}, &Tag{}) DB.DropTable(&Blog{}, &Tag{})
DB.DropTable("blog_tags") DB.DropTable("blog_tags")
DB.CreateTable(&Blog{}, &Tag{}) DB.CreateTable(&Blog{}, &Tag{})
@ -119,7 +119,7 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
} }
func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" {
DB.DropTable(&Blog{}, &Tag{}) DB.DropTable(&Blog{}, &Tag{})
DB.DropTable("shared_blog_tags") DB.DropTable("shared_blog_tags")
DB.CreateTable(&Blog{}, &Tag{}) DB.CreateTable(&Blog{}, &Tag{})
@ -236,7 +236,7 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
} }
func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" {
DB.DropTable(&Blog{}, &Tag{}) DB.DropTable(&Blog{}, &Tag{})
DB.DropTable("locale_blog_tags") DB.DropTable("locale_blog_tags")
DB.CreateTable(&Blog{}, &Tag{}) DB.CreateTable(&Blog{}, &Tag{})

View File

@ -798,7 +798,7 @@ func TestNestedPreload12(t *testing.T) {
} }
func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" || dialect == "mssql" {
return return
} }

View File

@ -326,7 +326,7 @@ func TestOrderAndPluck(t *testing.T) {
scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%")
var user User var user User
scopedb.Order(gorm.Expr("name = ? DESC", "OrderPluckUser2")).First(&user) scopedb.Order(gorm.Expr("case when name = ? then 0 else 1 end", "OrderPluckUser2")).First(&user)
if user.Name != "OrderPluckUser2" { if user.Name != "OrderPluckUser2" {
t.Errorf("Order with sql expression") t.Errorf("Order with sql expression")
} }

View File

@ -50,7 +50,8 @@ type RecordWithSlice struct {
type ExampleStringSlice []string type ExampleStringSlice []string
func (l ExampleStringSlice) Value() (driver.Value, error) { func (l ExampleStringSlice) Value() (driver.Value, error) {
return json.Marshal(l) bytes, err := json.Marshal(l)
return string(bytes), err
} }
func (l *ExampleStringSlice) Scan(input interface{}) error { func (l *ExampleStringSlice) Scan(input interface{}) error {
@ -72,7 +73,8 @@ type ExampleStruct struct {
type ExampleStructSlice []ExampleStruct type ExampleStructSlice []ExampleStruct
func (l ExampleStructSlice) Value() (driver.Value, error) { func (l ExampleStructSlice) Value() (driver.Value, error) {
return json.Marshal(l) bytes, err := json.Marshal(l)
return string(bytes), err
} }
func (l *ExampleStructSlice) Scan(input interface{}) error { func (l *ExampleStructSlice) Scan(input interface{}) error {