mirror of https://github.com/go-gorm/gorm.git
Test Upsert
This commit is contained in:
parent
cc07ee0444
commit
05e1af3bfb
|
@ -6,6 +6,7 @@ import (
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/callbacks"
|
"github.com/jinzhu/gorm/callbacks"
|
||||||
"github.com/jinzhu/gorm/clause"
|
"github.com/jinzhu/gorm/clause"
|
||||||
|
"github.com/jinzhu/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Create(db *gorm.DB) {
|
func Create(db *gorm.DB) {
|
||||||
|
@ -85,6 +86,7 @@ func Create(db *gorm.DB) {
|
||||||
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||||
|
|
||||||
if rows.Next() {
|
if rows.Next() {
|
||||||
|
db.RowsAffected++
|
||||||
err = rows.Scan(values)
|
err = rows.Scan(values)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -95,6 +97,16 @@ func Create(db *gorm.DB) {
|
||||||
|
|
||||||
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
||||||
values := callbacks.ConvertToCreateValues(db.Statement)
|
values := callbacks.ConvertToCreateValues(db.Statement)
|
||||||
|
setIdentityInsert := false
|
||||||
|
|
||||||
|
if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil {
|
||||||
|
if field.DataType == schema.Int || field.DataType == schema.Uint {
|
||||||
|
setIdentityInsert = true
|
||||||
|
db.Statement.WriteString("SET IDENTITY_INSERT ")
|
||||||
|
db.Statement.WriteQuoted(db.Statement.Table)
|
||||||
|
db.Statement.WriteString("ON;")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
db.Statement.WriteString("MERGE INTO ")
|
db.Statement.WriteString("MERGE INTO ")
|
||||||
db.Statement.WriteQuoted(db.Statement.Table)
|
db.Statement.WriteQuoted(db.Statement.Table)
|
||||||
|
@ -156,6 +168,12 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
||||||
db.Statement.WriteString(")")
|
db.Statement.WriteString(")")
|
||||||
outputInserted(db)
|
outputInserted(db)
|
||||||
db.Statement.WriteString(";")
|
db.Statement.WriteString(";")
|
||||||
|
|
||||||
|
if setIdentityInsert {
|
||||||
|
db.Statement.WriteString("SET IDENTITY_INSERT ")
|
||||||
|
db.Statement.WriteQuoted(db.Statement.Table)
|
||||||
|
db.Statement.WriteString("OFF;")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func outputInserted(db *gorm.DB) {
|
func outputInserted(db *gorm.DB) {
|
||||||
|
|
|
@ -4,9 +4,49 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm/clause"
|
||||||
. "github.com/jinzhu/gorm/tests"
|
. "github.com/jinzhu/gorm/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestUpsert(t *testing.T) {
|
||||||
|
lang := Language{Code: "upsert", Name: "Upsert"}
|
||||||
|
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang)
|
||||||
|
|
||||||
|
lang2 := Language{Code: "upsert", Name: "Upsert"}
|
||||||
|
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2)
|
||||||
|
|
||||||
|
var langs []Language
|
||||||
|
if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||||
|
} else if len(langs) != 1 {
|
||||||
|
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpsertSlice(t *testing.T) {
|
||||||
|
langs := []Language{
|
||||||
|
{Code: "upsert-slice1", Name: "Upsert-slice1"},
|
||||||
|
{Code: "upsert-slice2", Name: "Upsert-slice2"},
|
||||||
|
{Code: "upsert-slice3", Name: "Upsert-slice3"},
|
||||||
|
}
|
||||||
|
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs)
|
||||||
|
|
||||||
|
var langs2 []Language
|
||||||
|
if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||||
|
} else if len(langs2) != 3 {
|
||||||
|
t.Errorf("should only find only 3 languages, but got %+v", langs2)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs)
|
||||||
|
var langs3 []Language
|
||||||
|
if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||||
|
} else if len(langs3) != 3 {
|
||||||
|
t.Errorf("should only find only 3 languages, but got %+v", langs3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFindOrInitialize(t *testing.T) {
|
func TestFindOrInitialize(t *testing.T) {
|
||||||
var user1, user2, user3, user4, user5, user6 User
|
var user1, user2, user3, user4, user5, user6 User
|
||||||
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {
|
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {
|
||||||
|
|
Loading…
Reference in New Issue