From 59586dcd313bd067c2b94c118a9d20663ab3c8d0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 29 Aug 2020 23:02:19 +0800 Subject: [PATCH] Fix unnecessary duplicated primary condition when using Save, close #3330 --- finisher_api.go | 9 ++------- tests/update_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 2cde3c31..824f2a2e 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -32,17 +32,12 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.callbacks.Create().Execute(tx) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { - where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} - for idx, pf := range tx.Statement.Schema.PrimaryFields { - if pv, isZero := pf.ValueOf(reflectValue); isZero { + for _, pf := range tx.Statement.Schema.PrimaryFields { + if _, isZero := pf.ValueOf(reflectValue); isZero { tx.callbacks.Create().Execute(tx) return - } else { - where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} } } - - tx.Statement.AddClause(where) } fallthrough diff --git a/tests/update_test.go b/tests/update_test.go index e52dc652..d566c04d 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -2,6 +2,7 @@ package tests_test import ( "errors" + "regexp" "sort" "strings" "testing" @@ -586,3 +587,26 @@ func TestUpdateFromSubQuery(t *testing.T) { t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) } } + +func TestSave(t *testing.T) { + user := *GetUser("save", Config{}) + DB.Create(&user) + + if err := DB.First(&User{}, "name = ?", "save").Error; err != nil { + t.Fatalf("failed to find created user") + } + + user.Name = "save2" + DB.Save(&user) + + var result User + if err := DB.First(&result, "name = ?", "save2").Error; err != nil || result.ID != user.ID { + t.Fatalf("failed to find updated user") + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Save(&user).Statement + if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) + } +}