diff --git a/callbacks/create.go b/callbacks/create.go index 091f1774..22adca24 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -185,19 +185,19 @@ func AfterCreate(db *gorm.DB) { } // ConvertToCreateValues convert to create values -func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { +func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch value := stmt.Dest.(type) { case map[string]interface{}: - return ConvertMapToValuesForCreate(stmt, value) + values = ConvertMapToValuesForCreate(stmt, value) case []map[string]interface{}: - return ConvertSliceOfMapToValuesForCreate(stmt, value) + values = ConvertSliceOfMapToValuesForCreate(stmt, value) default: var ( - values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) curTime = stmt.DB.NowFunc() isZero bool ) + values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} for _, db := range stmt.Schema.DBNames { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { @@ -274,7 +274,30 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { } } } - - return values } + + if stmt.UpdatingColumn { + if stmt.Schema != nil { + columns := make([]string, 0, len(stmt.Schema.DBNames)-1) + for _, name := range stmt.Schema.DBNames { + if field := stmt.Schema.LookUpField(name); field != nil { + if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { + columns = append(columns, name) + } + } + } + + onConflict := clause.OnConflict{ + Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), + DoUpdates: clause.AssignmentColumns(columns), + } + + for idx, field := range stmt.Schema.PrimaryFields { + onConflict.Columns[idx] = clause.Column{Name: field.DBName} + } + stmt.AddClause(onConflict) + } + } + + return values } diff --git a/finisher_api.go b/finisher_api.go index d45c6c4f..afefd9fd 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -22,13 +22,14 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value - if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { - where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} - reflectValue := reflect.Indirect(reflect.ValueOf(value)) - switch reflectValue.Kind() { - case reflect.Slice, reflect.Array: - tx.AddError(ErrPtrStructSupported) - case reflect.Struct: + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + tx.Statement.UpdatingColumn = true + tx.callbacks.Create().Execute(tx) + case reflect.Struct: + if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { + where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} for idx, pf := range tx.Statement.Schema.PrimaryFields { if pv, isZero := pf.ValueOf(reflectValue); isZero { tx.callbacks.Create().Execute(tx) @@ -40,12 +41,16 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.AddClause(where) } + + fallthrough + default: + if len(tx.Statement.Selects) == 0 { + tx.Statement.Selects = append(tx.Statement.Selects, "*") + } + + tx.callbacks.Update().Execute(tx) } - if len(tx.Statement.Selects) == 0 { - tx.Statement.Selects = append(tx.Statement.Selects, "*") - } - tx.callbacks.Update().Execute(tx) return } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index a1307e32..5826b4fc 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -90,6 +90,23 @@ func TestUpsertSlice(t *testing.T) { } } +func TestUpsertWithSave(t *testing.T) { + langs := []Language{ + {Code: "upsert-save-1", Name: "Upsert-save-1"}, + {Code: "upsert-save-2", Name: "Upsert-save-2"}, + } + if err := DB.Save(&langs).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } + + for _, lang := range langs { + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } + } +} + func TestFindOrInitialize(t *testing.T) { var user1, user2, user3, user4, user5, user6 User if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {