diff --git a/create_test.go b/create_test.go index d67d34fc..7aa181ce 100644 --- a/create_test.go +++ b/create_test.go @@ -58,12 +58,20 @@ func TestCreate(t *testing.T) { } } +type AutoIncrementUser struct { + User + Sequence uint `gorm:"AUTO_INCREMENT"` +} + func TestCreateWithAutoIncrement(t *testing.T) { if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column") } - user1 := User{} - user2 := User{} + + DB.AutoMigrate(&AutoIncrementUser{}) + + user1 := AutoIncrementUser{} + user2 := AutoIncrementUser{} DB.Create(&user1) DB.Create(&user2) @@ -126,7 +134,7 @@ func TestAnonymousScanner(t *testing.T) { t.Errorf("Should be able to get anonymous scanner") } - if !user2.IsAdmin() { + if !user2.Role.IsAdmin() { t.Errorf("Should be able to get anonymous scanner") } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 7541b222..f9087495 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -11,7 +11,28 @@ import ( "github.com/jinzhu/gorm" ) +func setIdentityInsert(scope *gorm.Scope) { + if scope.Dialect().GetName() == "mssql" { + for _, field := range scope.PrimaryFields() { + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { + scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) + scope.InstanceSet("mssql:identity_insert_on", true) + } + } + } +} + +func turnOffIdentityInsert(scope *gorm.Scope) { + if scope.Dialect().GetName() == "mssql" { + if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { + scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) + } + } +} + func init() { + gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) + gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert) gorm.RegisterDialect("mssql", &mssql{}) } @@ -45,12 +66,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { sqlType = "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int IDENTITY(1,1)" } else { sqlType = "int" } case reflect.Int64, reflect.Uint64: if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint IDENTITY(1,1)" } else { sqlType = "bigint" diff --git a/join_table_test.go b/join_table_test.go index 1a83a9c8..f083ab02 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -2,6 +2,7 @@ package gorm_test import ( "fmt" + "strconv" "testing" "time" @@ -23,14 +24,23 @@ type PersonAddress struct { } func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { - return db.Where(map[string]interface{}{ - "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), - "address_id": db.NewScope(associationValue).PrimaryKeyValue(), - }).Assign(map[string]interface{}{ - "person_id": foreignValue, - "address_id": associationValue, + foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue())) + associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue())) + if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{ + "person_id": foreignPrimaryKey, + "address_id": associationPrimaryKey, + }).Update(map[string]interface{}{ + "person_id": foreignPrimaryKey, + "address_id": associationPrimaryKey, "deleted_at": gorm.Expr("NULL"), - }).FirstOrCreate(&PersonAddress{}).Error + }).RowsAffected; result == 0 { + return db.Create(&PersonAddress{ + PersonID: foreignPrimaryKey, + AddressID: associationPrimaryKey, + }).Error + } + + return nil } func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { diff --git a/main_test.go b/main_test.go index 32e8c0c9..3b1433cf 100644 --- a/main_test.go +++ b/main_test.go @@ -821,11 +821,11 @@ func BenchmarkGorm(b *testing.B) { for x := 0; x < b.N; x++ { e := strconv.Itoa(x) + "benchmark@example.org" now := time.Now() - email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: &now} + email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now} // Insert DB.Save(&email) // Query - DB.First(&BigEmail{}, "email = ?", e) + DB.First(&EmailWithIdx{}, "email = ?", e) // Update DB.Model(&email).UpdateColumn("email", "new-"+e) // Delete @@ -846,7 +846,7 @@ func BenchmarkRawSql(b *testing.B) { var id int64 e := strconv.Itoa(x) + "benchmark@example.org" now := time.Now() - email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: &now} + email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now} // Insert DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) // Query @@ -860,6 +860,6 @@ func BenchmarkRawSql(b *testing.B) { } func parseTime(str string) *time.Time { - t := now.MustParse(str) + t := now.New(time.Now().UTC()).MustParse(str) return &t } diff --git a/migration_test.go b/migration_test.go index 95c2c571..9fc14fa0 100644 --- a/migration_test.go +++ b/migration_test.go @@ -31,9 +31,8 @@ type User struct { Languages []Language `gorm:"many2many:user_languages;"` CompanyID *int Company Company - Role + Role Role PasswordHash []byte - Sequence uint `gorm:"AUTO_INCREMENT"` IgnoreMe int64 `sql:"-"` IgnoreStringSlice []string `sql:"-"` Ignored struct{ Name string } `sql:"-"` @@ -333,7 +332,7 @@ func TestIndexes(t *testing.T) { } } -type BigEmail struct { +type EmailWithIdx struct { Id int64 UserId int64 Email string `sql:"index:idx_email_agent"` @@ -343,29 +342,26 @@ type BigEmail struct { UpdatedAt time.Time } -func (b BigEmail) TableName() string { - return "emails" -} - func TestAutoMigration(t *testing.T) { DB.AutoMigrate(&Address{}) - if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil { + DB.DropTable(&EmailWithIdx{}) + if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { t.Errorf("Auto Migrate should not raise any error") } now := time.Now() - DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) + DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) - scope := DB.NewScope(&BigEmail{}) + scope := DB.NewScope(&EmailWithIdx{}) if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") { + if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { t.Errorf("Failed to create index") } - var bigemail BigEmail + var bigemail EmailWithIdx DB.First(&bigemail, "user_agent = ?", "pc") if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { t.Error("Big Emails should be saved and fetched correctly") @@ -386,7 +382,7 @@ func TestMultipleIndexes(t *testing.T) { } DB.AutoMigrate(&MultipleIndexes{}) - if err := DB.AutoMigrate(&BigEmail{}).Error; err != nil { + if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { t.Errorf("Auto Migrate should not raise any error") } diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go index 8b275d18..32a14772 100644 --- a/multi_primary_keys_test.go +++ b/multi_primary_keys_test.go @@ -35,7 +35,7 @@ func compareTags(tags []Tag, contents []string) bool { } func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { DB.DropTable(&Blog{}, &Tag{}) DB.DropTable("blog_tags") DB.CreateTable(&Blog{}, &Tag{}) @@ -119,7 +119,7 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { } func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { DB.DropTable(&Blog{}, &Tag{}) DB.DropTable("shared_blog_tags") DB.CreateTable(&Blog{}, &Tag{}) @@ -236,7 +236,7 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { DB.DropTable(&Blog{}, &Tag{}) DB.DropTable("locale_blog_tags") DB.CreateTable(&Blog{}, &Tag{}) diff --git a/preload_test.go b/preload_test.go index 8b8b39b8..c830025c 100644 --- a/preload_test.go +++ b/preload_test.go @@ -798,7 +798,7 @@ func TestNestedPreload12(t *testing.T) { } func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { + if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" || dialect == "mssql" { return } diff --git a/query_test.go b/query_test.go index d6b23ddf..866d81d2 100644 --- a/query_test.go +++ b/query_test.go @@ -326,7 +326,7 @@ func TestOrderAndPluck(t *testing.T) { scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") var user User - scopedb.Order(gorm.Expr("name = ? DESC", "OrderPluckUser2")).First(&user) + scopedb.Order(gorm.Expr("case when name = ? then 0 else 1 end", "OrderPluckUser2")).First(&user) if user.Name != "OrderPluckUser2" { t.Errorf("Order with sql expression") } diff --git a/scaner_test.go b/scaner_test.go index fae9d3e1..9e251dd6 100644 --- a/scaner_test.go +++ b/scaner_test.go @@ -50,7 +50,8 @@ type RecordWithSlice struct { type ExampleStringSlice []string func (l ExampleStringSlice) Value() (driver.Value, error) { - return json.Marshal(l) + bytes, err := json.Marshal(l) + return string(bytes), err } func (l *ExampleStringSlice) Scan(input interface{}) error { @@ -72,7 +73,8 @@ type ExampleStruct struct { type ExampleStructSlice []ExampleStruct func (l ExampleStructSlice) Value() (driver.Value, error) { - return json.Marshal(l) + bytes, err := json.Marshal(l) + return string(bytes), err } func (l *ExampleStructSlice) Scan(input interface{}) error {