diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 6001399f..47f69fc9 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -14,8 +14,14 @@ func (OnConflict) Name() string { // Build build onConflict clause func (onConflict OnConflict) Build(builder Builder) { if len(onConflict.Columns) > 0 { - builder.WriteQuoted(onConflict.Columns) // FIXME columns - builder.WriteByte(' ') + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) } if len(onConflict.Where.Exprs) > 0 { diff --git a/clause/set.go b/clause/set.go index 4adfe68f..7704ca36 100644 --- a/clause/set.go +++ b/clause/set.go @@ -47,7 +47,7 @@ func Assignments(values map[string]interface{}) Set { for _, key := range keys { assignments = append(assignments, Assignment{ - Column: Column{Table: CurrentTable, Name: key}, + Column: Column{Name: key}, Value: values[key], }) } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index f132a7da..311b7136 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -10,10 +10,14 @@ import ( func TestUpsert(t *testing.T) { lang := Language{Code: "upsert", Name: "Upsert"} - DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang) + if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } lang2 := Language{Code: "upsert", Name: "Upsert"} - DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2) + if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } var langs []Language if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { @@ -21,6 +25,22 @@ func TestUpsert(t *testing.T) { } else if len(langs) != 1 { t.Errorf("should only find only 1 languages, but got %+v", langs) } + + lang3 := Language{Code: "upsert", Name: "Upsert"} + if err := DB.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}), + }).Create(&lang3).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if langs[0].Name != "upsert-new" { + t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) + } } func TestUpsertSlice(t *testing.T) {