Update updated_at when upserting with Create OnConflict

This commit is contained in:
Jinzhu 2021-09-17 15:29:49 +08:00
parent 12bbde89e6
commit da16a8aac6
3 changed files with 68 additions and 28 deletions

View File

@ -227,6 +227,8 @@ func AfterCreate(db *gorm.DB) {
// ConvertToCreateValues convert to create values
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
curTime := stmt.DB.NowFunc()
switch value := stmt.Dest.(type) {
case map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, value)
@ -240,7 +242,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
_, updateTrackTime = stmt.Get("gorm:update_track_time")
curTime = stmt.DB.NowFunc()
isZero bool
)
stmt.Settings.Delete("gorm:update_track_time")
@ -352,13 +353,27 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
if field.AutoUpdateTime > 0 {
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
switch field.AutoUpdateTime {
case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond:
assignment.Value = curTime.UnixNano() / 1e6
case schema.UnixSecond:
assignment.Value = curTime.Unix()
}
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
} else {
columns = append(columns, column.Name)
}
}
}
}
}
onConflict.DoUpdates = clause.AssignmentColumns(columns)
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
// use primary fields as default OnConflict columns
if len(onConflict.Columns) == 0 {

View File

@ -21,9 +21,10 @@ type TimeType int64
var TimeReflectType = reflect.TypeOf(time.Time{})
const (
UnixSecond TimeType = 1
UnixMillisecond TimeType = 2
UnixNanosecond TimeType = 3
UnixTime TimeType = 1
UnixSecond TimeType = 2
UnixMillisecond TimeType = 3
UnixNanosecond TimeType = 4
)
const (
@ -251,7 +252,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if strings.ToUpper(v) == "NANO" {
if field.DataType == Time {
field.AutoCreateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond
} else if strings.ToUpper(v) == "MILLI" {
field.AutoCreateTime = UnixMillisecond
@ -261,7 +264,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if strings.ToUpper(v) == "NANO" {
if field.DataType == Time {
field.AutoUpdateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
field.AutoUpdateTime = UnixNanosecond
} else if strings.ToUpper(v) == "MILLI" {
field.AutoUpdateTime = UnixMillisecond

View File

@ -66,6 +66,26 @@ func TestUpsert(t *testing.T) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
}
var user = *GetUser("upsert_on_conflict", Config{})
user.Age = 20
if err := DB.Create(&user).Error; err != nil {
t.Errorf("failed to create user, got error %v", err)
}
var user2 User
DB.First(&user2, user.ID)
user2.Age = 30
time.Sleep(time.Second)
if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&user2).Error; err != nil {
t.Fatalf("failed to onconflict create user, got error %v", err)
} else {
var user3 User
DB.First(&user3, user.ID)
if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() {
t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt)
}
}
}
func TestUpsertSlice(t *testing.T) {
@ -152,29 +172,29 @@ func TestUpsertWithSave(t *testing.T) {
}
}
// lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"}
// if err := DB.Save(&lang).Error; err != nil {
// t.Errorf("Failed to create, got error %v", err)
// }
lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"}
if err := DB.Save(&lang).Error; err != nil {
t.Errorf("Failed to create, got error %v", err)
}
// var result Language
// if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
// t.Errorf("Failed to query lang, got error %v", err)
// } else {
// AssertEqual(t, result, lang)
// }
var result Language
if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
t.Errorf("Failed to query lang, got error %v", err)
} else {
AssertEqual(t, result, lang)
}
// lang.Name += "_new"
// if err := DB.Save(&lang).Error; err != nil {
// t.Errorf("Failed to create, got error %v", err)
// }
lang.Name += "_new"
if err := DB.Save(&lang).Error; err != nil {
t.Errorf("Failed to create, got error %v", err)
}
// var result2 Language
// if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil {
// t.Errorf("Failed to query lang, got error %v", err)
// } else {
// AssertEqual(t, result2, lang)
// }
var result2 Language
if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil {
t.Errorf("Failed to query lang, got error %v", err)
} else {
AssertEqual(t, result2, lang)
}
}
func TestFindOrInitialize(t *testing.T) {