From da16a8aac6c3620532f5ad6d1fedf20fca2c1cf6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Sep 2021 15:29:49 +0800 Subject: [PATCH] Update updated_at when upserting with Create OnConflict --- callbacks/create.go | 21 +++++++++++++--- schema/field.go | 15 +++++++---- tests/upsert_test.go | 60 +++++++++++++++++++++++++++++--------------- 3 files changed, 68 insertions(+), 28 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 8a3c593c..a2944319 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -227,6 +227,8 @@ func AfterCreate(db *gorm.DB) { // ConvertToCreateValues convert to create values func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { + curTime := stmt.DB.NowFunc() + switch value := stmt.Dest.(type) { case map[string]interface{}: values = ConvertMapToValuesForCreate(stmt, value) @@ -240,7 +242,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) _, updateTrackTime = stmt.Get("gorm:update_track_time") - curTime = stmt.DB.NowFunc() isZero bool ) stmt.Settings.Delete("gorm:update_track_time") @@ -352,13 +353,27 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if field := stmt.Schema.LookUpField(column.Name); field != nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { - columns = append(columns, column.Name) + if field.AutoUpdateTime > 0 { + assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} + switch field.AutoUpdateTime { + case schema.UnixNanosecond: + assignment.Value = curTime.UnixNano() + case schema.UnixMillisecond: + assignment.Value = curTime.UnixNano() / 1e6 + case schema.UnixSecond: + assignment.Value = curTime.Unix() + } + + onConflict.DoUpdates = append(onConflict.DoUpdates, assignment) + } else { + columns = append(columns, column.Name) + } } } } } - onConflict.DoUpdates = clause.AssignmentColumns(columns) + onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...) // use primary fields as default OnConflict columns if len(onConflict.Columns) == 0 { diff --git a/schema/field.go b/schema/field.go index ce0e3c13..f3189c7a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -21,9 +21,10 @@ type TimeType int64 var TimeReflectType = reflect.TypeOf(time.Time{}) const ( - UnixSecond TimeType = 1 - UnixMillisecond TimeType = 2 - UnixNanosecond TimeType = 3 + UnixTime TimeType = 1 + UnixSecond TimeType = 2 + UnixMillisecond TimeType = 3 + UnixNanosecond TimeType = 4 ) const ( @@ -251,7 +252,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { - if strings.ToUpper(v) == "NANO" { + if field.DataType == Time { + field.AutoCreateTime = UnixTime + } else if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond } else if strings.ToUpper(v) == "MILLI" { field.AutoCreateTime = UnixMillisecond @@ -261,7 +264,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { - if strings.ToUpper(v) == "NANO" { + if field.DataType == Time { + field.AutoUpdateTime = UnixTime + } else if strings.ToUpper(v) == "NANO" { field.AutoUpdateTime = UnixNanosecond } else if strings.ToUpper(v) == "MILLI" { field.AutoUpdateTime = UnixMillisecond diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 867110d8..0e247caa 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -66,6 +66,26 @@ func TestUpsert(t *testing.T) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } + + var user = *GetUser("upsert_on_conflict", Config{}) + user.Age = 20 + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error %v", err) + } + + var user2 User + DB.First(&user2, user.ID) + user2.Age = 30 + time.Sleep(time.Second) + if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&user2).Error; err != nil { + t.Fatalf("failed to onconflict create user, got error %v", err) + } else { + var user3 User + DB.First(&user3, user.ID) + if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() { + t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt) + } + } } func TestUpsertSlice(t *testing.T) { @@ -152,29 +172,29 @@ func TestUpsertWithSave(t *testing.T) { } } - // lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} - // if err := DB.Save(&lang).Error; err != nil { - // t.Errorf("Failed to create, got error %v", err) - // } + lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} + if err := DB.Save(&lang).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } - // var result Language - // if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { - // t.Errorf("Failed to query lang, got error %v", err) - // } else { - // AssertEqual(t, result, lang) - // } + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) + } - // lang.Name += "_new" - // if err := DB.Save(&lang).Error; err != nil { - // t.Errorf("Failed to create, got error %v", err) - // } + lang.Name += "_new" + if err := DB.Save(&lang).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } - // var result2 Language - // if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { - // t.Errorf("Failed to query lang, got error %v", err) - // } else { - // AssertEqual(t, result2, lang) - // } + var result2 Language + if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result2, lang) + } } func TestFindOrInitialize(t *testing.T) {