diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index 9183ba76..b17a2227 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -6,6 +6,7 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) func Create(db *gorm.DB) { @@ -85,6 +86,7 @@ func Create(db *gorm.DB) { values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() if rows.Next() { + db.RowsAffected++ err = rows.Scan(values) } } @@ -95,6 +97,16 @@ func Create(db *gorm.DB) { func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { values := callbacks.ConvertToCreateValues(db.Statement) + setIdentityInsert := false + + if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil { + if field.DataType == schema.Int || field.DataType == schema.Uint { + setIdentityInsert = true + db.Statement.WriteString("SET IDENTITY_INSERT ") + db.Statement.WriteQuoted(db.Statement.Table) + db.Statement.WriteString("ON;") + } + } db.Statement.WriteString("MERGE INTO ") db.Statement.WriteQuoted(db.Statement.Table) @@ -156,6 +168,12 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { db.Statement.WriteString(")") outputInserted(db) db.Statement.WriteString(";") + + if setIdentityInsert { + db.Statement.WriteString("SET IDENTITY_INSERT ") + db.Statement.WriteQuoted(db.Statement.Table) + db.Statement.WriteString("OFF;") + } } func outputInserted(db *gorm.DB) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 615ead95..6f67f603 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -4,9 +4,49 @@ import ( "testing" "time" + "github.com/jinzhu/gorm/clause" . "github.com/jinzhu/gorm/tests" ) +func TestUpsert(t *testing.T) { + lang := Language{Code: "upsert", Name: "Upsert"} + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang) + + lang2 := Language{Code: "upsert", Name: "Upsert"} + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2) + + var langs []Language + if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } +} + +func TestUpsertSlice(t *testing.T) { + langs := []Language{ + {Code: "upsert-slice1", Name: "Upsert-slice1"}, + {Code: "upsert-slice2", Name: "Upsert-slice2"}, + {Code: "upsert-slice3", Name: "Upsert-slice3"}, + } + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) + + var langs2 []Language + if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs2) != 3 { + t.Errorf("should only find only 3 languages, but got %+v", langs2) + } + + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) + var langs3 []Language + if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs3) != 3 { + t.Errorf("should only find only 3 languages, but got %+v", langs3) + } +} + func TestFindOrInitialize(t *testing.T) { var user1, user2, user3, user4, user5, user6 User if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {