forked from mirror/gorm
Test Upsert
This commit is contained in:
parent
cc07ee0444
commit
05e1af3bfb
|
@ -6,6 +6,7 @@ import (
|
|||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/callbacks"
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
||||
func Create(db *gorm.DB) {
|
||||
|
@ -85,6 +86,7 @@ func Create(db *gorm.DB) {
|
|||
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
err = rows.Scan(values)
|
||||
}
|
||||
}
|
||||
|
@ -95,6 +97,16 @@ func Create(db *gorm.DB) {
|
|||
|
||||
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
||||
values := callbacks.ConvertToCreateValues(db.Statement)
|
||||
setIdentityInsert := false
|
||||
|
||||
if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil {
|
||||
if field.DataType == schema.Int || field.DataType == schema.Uint {
|
||||
setIdentityInsert = true
|
||||
db.Statement.WriteString("SET IDENTITY_INSERT ")
|
||||
db.Statement.WriteQuoted(db.Statement.Table)
|
||||
db.Statement.WriteString("ON;")
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.WriteString("MERGE INTO ")
|
||||
db.Statement.WriteQuoted(db.Statement.Table)
|
||||
|
@ -156,6 +168,12 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
|||
db.Statement.WriteString(")")
|
||||
outputInserted(db)
|
||||
db.Statement.WriteString(";")
|
||||
|
||||
if setIdentityInsert {
|
||||
db.Statement.WriteString("SET IDENTITY_INSERT ")
|
||||
db.Statement.WriteQuoted(db.Statement.Table)
|
||||
db.Statement.WriteString("OFF;")
|
||||
}
|
||||
}
|
||||
|
||||
func outputInserted(db *gorm.DB) {
|
||||
|
|
|
@ -4,9 +4,49 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
. "github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
func TestUpsert(t *testing.T) {
|
||||
lang := Language{Code: "upsert", Name: "Upsert"}
|
||||
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang)
|
||||
|
||||
lang2 := Language{Code: "upsert", Name: "Upsert"}
|
||||
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2)
|
||||
|
||||
var langs []Language
|
||||
if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil {
|
||||
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||
} else if len(langs) != 1 {
|
||||
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertSlice(t *testing.T) {
|
||||
langs := []Language{
|
||||
{Code: "upsert-slice1", Name: "Upsert-slice1"},
|
||||
{Code: "upsert-slice2", Name: "Upsert-slice2"},
|
||||
{Code: "upsert-slice3", Name: "Upsert-slice3"},
|
||||
}
|
||||
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs)
|
||||
|
||||
var langs2 []Language
|
||||
if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil {
|
||||
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||
} else if len(langs2) != 3 {
|
||||
t.Errorf("should only find only 3 languages, but got %+v", langs2)
|
||||
}
|
||||
|
||||
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs)
|
||||
var langs3 []Language
|
||||
if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil {
|
||||
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||
} else if len(langs3) != 3 {
|
||||
t.Errorf("should only find only 3 languages, but got %+v", langs3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindOrInitialize(t *testing.T) {
|
||||
var user1, user2, user3, user4, user5, user6 User
|
||||
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {
|
||||
|
|
Loading…
Reference in New Issue