From 05794298bd3d87dc8e98de8cde451b19093e2a4a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 6 Sep 2020 12:22:05 +0800 Subject: [PATCH] Fix Save with specified table, close #3396 --- finisher_api.go | 3 ++- tests/update_test.go | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index 1d5ef5fc..6ece0f79 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -51,7 +51,8 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.callbacks.Update().Execute(tx) if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { - if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) { + result := reflect.New(tx.Statement.Schema.ModelType).Interface() + if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) { return tx.Create(value) } } diff --git a/tests/update_test.go b/tests/update_test.go index 1944ed3f..a660647c 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -629,4 +629,26 @@ func TestSaveWithPrimaryValue(t *testing.T) { var result2 Language DB.First(&result2, "code = ?", "save") AssertEqual(t, result2, lang) + + DB.Table("langs").Migrator().DropTable(&Language{}) + DB.Table("langs").AutoMigrate(&Language{}) + + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result3 Language + if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3) + } + + lang.Name += "name2" + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result4 Language + if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) + } }