diff --git a/clause/set.go b/clause/set.go index 1c2a9ef2..6a885711 100644 --- a/clause/set.go +++ b/clause/set.go @@ -50,3 +50,11 @@ func Assignments(values map[string]interface{}) Set { } return assignments } + +func AssignmentColumns(values []string) Set { + assignments := make([]Assignment, len(values)) + for idx, value := range values { + assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}} + } + return assignments +} diff --git a/tests/go.mod b/tests/go.mod index 3c2dfc6c..c184732c 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,10 +4,10 @@ go 1.14 require ( github.com/jinzhu/now v1.1.1 - gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 + gorm.io/driver/mysql v0.0.0-20200609004954-b8310c61c3f2 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/sqlite v1.0.0 - gorm.io/driver/sqlserver v0.0.0-20200605135528-04ae0f7a15bf + gorm.io/driver/sqlserver v0.0.0-20200609005334-d550a0be1cfb gorm.io/gorm v0.0.0-00010101000000-000000000000 ) diff --git a/tests/upsert_test.go b/tests/upsert_test.go index e9ba54e3..a1307e32 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -65,6 +65,29 @@ func TestUpsertSlice(t *testing.T) { } else if len(langs3) != 3 { t.Errorf("should only find only 3 languages, but got %+v", langs3) } + + for idx, lang := range langs { + lang.Name = lang.Name + "_new" + langs[idx] = lang + } + + if err := DB.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.AssignmentColumns([]string{"name"}), + }).Create(&langs).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + for _, lang := range langs { + var results []Language + if err := DB.Find(&results, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(results) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if results[0].Name != lang.Name { + t.Errorf("should update name on conflict, but got name %+v", results[0].Name) + } + } } func TestFindOrInitialize(t *testing.T) {