From eb9a704fda14b74a49d9b9d4d965706c848415dd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 26 Feb 2021 17:11:25 +0800 Subject: [PATCH] Fix update UpdatedAt when full saving associations, close #4115 --- callbacks/associations.go | 5 +++++ callbacks/create.go | 5 +++++ tests/update_has_one_test.go | 12 +++++++++++- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index dc84e137..2deb8ede 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -361,6 +361,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, } tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{ + FullSaveAssociations: db.FullSaveAssociations, SkipHooks: db.Statement.SkipHooks, DisableNestedTransaction: true, }) @@ -370,6 +371,10 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, return true }) + if tx.Statement.FullSaveAssociations { + tx = tx.InstanceSet("gorm:update_track_time", true) + } + if len(selects) > 0 { tx = tx.Select(selects) } else if len(selectColumns) > 0 && len(omits) == 0 { diff --git a/callbacks/create.go b/callbacks/create.go index 5656b861..10da731f 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -320,6 +320,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { field.Set(stmt.ReflectValue, curTime) values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) } + } else if field.AutoUpdateTime > 0 { + if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { + field.Set(stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + } } } diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 54568546..a61629f8 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -31,7 +32,10 @@ func TestUpdateHasOne(t *testing.T) { var user3 User DB.Preload("Account").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + var lastUpdatedAt = user2.Account.UpdatedAt + time.Sleep(time.Second) if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) @@ -39,7 +43,13 @@ func TestUpdateHasOne(t *testing.T) { var user4 User DB.Preload("Account").Find(&user4, "id = ?", user.ID) - CheckUser(t, user4, user) + + if lastUpdatedAt.Format(time.RFC3339) == user4.Account.UpdatedAt.Format(time.RFC3339) { + t.Fatalf("updated at should be updated, but not, old: %v, new %v", lastUpdatedAt.Format(time.RFC3339), user3.Account.UpdatedAt.Format(time.RFC3339)) + } else { + user.Account.UpdatedAt = user4.Account.UpdatedAt + CheckUser(t, user4, user) + } t.Run("Polymorphic", func(t *testing.T) { var pet = Pet{Name: "create"}