Add OnConflict UpdateAll support

This commit is contained in:
Jinzhu 2020-11-16 20:22:08 +08:00
parent a4c0c6b400
commit 62be27d3ca
4 changed files with 32 additions and 16 deletions

View File

@ -329,26 +329,29 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
}
}
if stmt.UpdatingColumn {
if stmt.Schema != nil && len(values.Columns) > 1 {
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 {
columns = append(columns, column.Name)
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
if stmt.Schema != nil && len(values.Columns) > 1 {
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 {
columns = append(columns, column.Name)
}
}
}
}
onConflict := clause.OnConflict{
Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)),
DoUpdates: clause.AssignmentColumns(columns),
}
onConflict := clause.OnConflict{
Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)),
DoUpdates: clause.AssignmentColumns(columns),
}
for idx, field := range stmt.Schema.PrimaryFields {
onConflict.Columns[idx] = clause.Column{Name: field.DBName}
for idx, field := range stmt.Schema.PrimaryFields {
onConflict.Columns[idx] = clause.Column{Name: field.DBName}
}
stmt.AddClause(onConflict)
}
stmt.AddClause(onConflict)
}
}

View File

@ -5,6 +5,7 @@ type OnConflict struct {
Where Where
DoNothing bool
DoUpdates Set
UpdateAll bool
}
func (OnConflict) Name() string {

View File

@ -29,7 +29,9 @@ func (db *DB) Save(value interface{}) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
tx.Statement.UpdatingColumn = true
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
}
tx.callbacks.Create().Execute(tx)
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {

View File

@ -41,6 +41,16 @@ func TestUpsert(t *testing.T) {
} else if langs[0].Name != "upsert-new" {
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
}
lang = Language{Code: "upsert", Name: "Upsert-Newname"}
if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&lang).Error; err != nil {
t.Fatalf("failed to upsert, got %v", err)
}
var result Language
if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name {
t.Fatalf("failed to upsert, got name %v", result.Name)
}
}
func TestUpsertSlice(t *testing.T) {