Test Upsert

This commit is contained in:
Jinzhu 2020-05-30 13:46:33 +08:00
parent cc07ee0444
commit 05e1af3bfb
2 changed files with 58 additions and 0 deletions

View File

@ -6,6 +6,7 @@ import (
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
)
func Create(db *gorm.DB) {
@ -85,6 +86,7 @@ func Create(db *gorm.DB) {
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
if rows.Next() {
db.RowsAffected++
err = rows.Scan(values)
}
}
@ -95,6 +97,16 @@ func Create(db *gorm.DB) {
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
values := callbacks.ConvertToCreateValues(db.Statement)
setIdentityInsert := false
if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil {
if field.DataType == schema.Int || field.DataType == schema.Uint {
setIdentityInsert = true
db.Statement.WriteString("SET IDENTITY_INSERT ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString("ON;")
}
}
db.Statement.WriteString("MERGE INTO ")
db.Statement.WriteQuoted(db.Statement.Table)
@ -156,6 +168,12 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
db.Statement.WriteString(")")
outputInserted(db)
db.Statement.WriteString(";")
if setIdentityInsert {
db.Statement.WriteString("SET IDENTITY_INSERT ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString("OFF;")
}
}
func outputInserted(db *gorm.DB) {

View File

@ -4,9 +4,49 @@ import (
"testing"
"time"
"github.com/jinzhu/gorm/clause"
. "github.com/jinzhu/gorm/tests"
)
func TestUpsert(t *testing.T) {
lang := Language{Code: "upsert", Name: "Upsert"}
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang)
lang2 := Language{Code: "upsert", Name: "Upsert"}
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2)
var langs []Language
if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil {
t.Errorf("no error should happen when find languages with code, but got %v", err)
} else if len(langs) != 1 {
t.Errorf("should only find only 1 languages, but got %+v", langs)
}
}
func TestUpsertSlice(t *testing.T) {
langs := []Language{
{Code: "upsert-slice1", Name: "Upsert-slice1"},
{Code: "upsert-slice2", Name: "Upsert-slice2"},
{Code: "upsert-slice3", Name: "Upsert-slice3"},
}
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs)
var langs2 []Language
if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil {
t.Errorf("no error should happen when find languages with code, but got %v", err)
} else if len(langs2) != 3 {
t.Errorf("should only find only 3 languages, but got %+v", langs2)
}
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs)
var langs3 []Language
if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil {
t.Errorf("no error should happen when find languages with code, but got %v", err)
} else if len(langs3) != 3 {
t.Errorf("should only find only 3 languages, but got %+v", langs3)
}
}
func TestFindOrInitialize(t *testing.T) {
var user1, user2, user3, user4, user5, user6 User
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {