From 0d58d5a3a7b7b73cf6b3533ef5da6b74ed602051 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 10 Jun 2020 10:48:48 +0800 Subject: [PATCH] Upsert selected columns --- callbacks/create.go | 8 ++++---- tests/upsert_test.go | 45 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 22adca24..684d5530 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -278,11 +278,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if stmt.UpdatingColumn { if stmt.Schema != nil { - columns := make([]string, 0, len(stmt.Schema.DBNames)-1) - for _, name := range stmt.Schema.DBNames { - if field := stmt.Schema.LookUpField(name); field != nil { + columns := make([]string, 0, len(values.Columns)-1) + for _, column := range values.Columns { + if field := stmt.Schema.LookUpField(column.Name); field != nil { if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { - columns = append(columns, name) + columns = append(columns, column.Name) } } } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 5826b4fc..ba7c1a9d 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -95,6 +95,7 @@ func TestUpsertWithSave(t *testing.T) { {Code: "upsert-save-1", Name: "Upsert-save-1"}, {Code: "upsert-save-2", Name: "Upsert-save-2"}, } + if err := DB.Save(&langs).Error; err != nil { t.Errorf("Failed to create, got error %v", err) } @@ -103,8 +104,52 @@ func TestUpsertWithSave(t *testing.T) { var result Language if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) } } + + for idx, lang := range langs { + lang.Name += "_new" + langs[idx] = lang + } + + if err := DB.Save(&langs).Error; err != nil { + t.Errorf("Failed to upsert, got error %v", err) + } + + for _, lang := range langs { + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) + } + } + + // lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} + // if err := DB.Save(&lang).Error; err != nil { + // t.Errorf("Failed to create, got error %v", err) + // } + + // var result Language + // if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + // t.Errorf("Failed to query lang, got error %v", err) + // } else { + // AssertEqual(t, result, lang) + // } + + // lang.Name += "_new" + // if err := DB.Save(&lang).Error; err != nil { + // t.Errorf("Failed to create, got error %v", err) + // } + + // var result2 Language + // if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { + // t.Errorf("Failed to query lang, got error %v", err) + // } else { + // AssertEqual(t, result2, lang) + // } } func TestFindOrInitialize(t *testing.T) {