From b5725940e95cc886403b12e01cba4c941881a7be Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 6 Jul 2020 11:20:43 +0800 Subject: [PATCH] Test Select with Update Struct --- callbacks/update.go | 18 ++++++++++-------- tests/update_test.go | 26 ++++++++++++++++++++++++-- utils/tests/utils.go | 7 ++++++- 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index f84e933c..97a0e893 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -196,15 +196,17 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !stmt.UpdatingColumn && stmt.Schema != nil { for _, field := range stmt.Schema.FieldsByDBName { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { - now := stmt.DB.NowFunc() - assignValue(field, now) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + now := stmt.DB.NowFunc() + assignValue(field, now) - if field.AutoUpdateTime == schema.UnixNanosecond { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) - } else if field.DataType == schema.Time { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) - } else { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + if field.AutoUpdateTime == schema.UnixNanosecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) + } else if field.DataType == schema.Time { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } } } } diff --git a/tests/update_test.go b/tests/update_test.go index d56e3f76..2ff150dd 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -8,6 +8,7 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/utils" . "gorm.io/gorm/utils/tests" ) @@ -267,6 +268,22 @@ func TestSelectWithUpdate(t *testing.T) { }) AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") + + DB.Model(&result).Select("Name", "Age").Updates(User{Name: "update_with_select"}) + if result.Age != 0 || result.Name != "update_with_select" { + t.Fatalf("Failed to update struct with select, got %+v", result) + } + AssertObjEqual(t, result, user, "UpdatedAt") + + var result3 User + DB.First(&result3, result.ID) + AssertObjEqual(t, result, result3, "Name", "Age", "UpdatedAt") + + DB.Model(&result).Select("Name", "Age", "UpdatedAt").Updates(User{Name: "update_with_select"}) + + if utils.AssertEqual(result.UpdatedAt, user.UpdatedAt) { + t.Fatalf("Update struct should update UpdatedAt, was %+v, got %+v", result.UpdatedAt, user.UpdatedAt) + } } func TestSelectWithUpdateWithMap(t *testing.T) { @@ -290,7 +307,7 @@ func TestSelectWithUpdateWithMap(t *testing.T) { "Friends": user2.Friends, } - DB.Model(&result).Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues) + DB.Model(&result).Omit("name", "updated_at").Updates(updateValues) var result2 User DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) @@ -427,11 +444,16 @@ func TestSelectWithUpdateColumn(t *testing.T) { var result User DB.First(&result, user.ID) - DB.Model(&result).Select("Name").UpdateColumns(updateValues) + + time.Sleep(time.Second) + lastUpdatedAt := result.UpdatedAt + DB.Model(&result).Select("Name").Updates(updateValues) var result2 User DB.First(&result2, user.ID) + AssertEqual(t, lastUpdatedAt, result2.UpdatedAt) + if result2.Name == user.Name || result2.Age != user.Age { t.Errorf("Should only update users with name column") } diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 5248e620..a44eb548 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -84,15 +84,20 @@ func AssertEqual(t *testing.T, got, expect interface{}) { if reflect.ValueOf(got).Kind() == reflect.Struct { if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + exported := false for i := 0; i < reflect.ValueOf(got).NumField(); i++ { if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + exported = true field := reflect.ValueOf(got).Field(i) t.Run(fieldStruct.Name, func(t *testing.T) { AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) }) } } - return + + if exported { + return + } } }