diff --git a/finisher_api.go b/finisher_api.go index e6fe4666..d647cf64 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -106,7 +106,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { - return tx.Create(value) + return tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(value) } return updateTx diff --git a/tests/update_test.go b/tests/update_test.go index b2da11c6..36ffa6a0 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -610,6 +610,25 @@ func TestUpdateFromSubQuery(t *testing.T) { } } +func TestIdempotentSave(t *testing.T) { + create := Company{ + Name: "company_idempotent", + } + DB.Create(&create) + + var company Company + if err := DB.Find(&company, "id = ?", create.ID).Error; err != nil { + t.Fatalf("failed to find created company, got err: %v", err) + } + + if err := DB.Save(&company).Error; err != nil || company.ID != create.ID { + t.Errorf("failed to save company, got err: %v", err) + } + if err := DB.Save(&company).Error; err != nil || company.ID != create.ID { + t.Errorf("failed to save company, got err: %v", err) + } +} + func TestSave(t *testing.T) { user := *GetUser("save", Config{}) DB.Create(&user)