Update updated_at when upserting with Create OnConflict

This commit is contained in:
Jinzhu 2021-09-17 15:29:49 +08:00
parent 12bbde89e6
commit da16a8aac6
3 changed files with 68 additions and 28 deletions

View File

@ -227,6 +227,8 @@ func AfterCreate(db *gorm.DB) {
// ConvertToCreateValues convert to create values // ConvertToCreateValues convert to create values
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
curTime := stmt.DB.NowFunc()
switch value := stmt.Dest.(type) { switch value := stmt.Dest.(type) {
case map[string]interface{}: case map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, value) values = ConvertMapToValuesForCreate(stmt, value)
@ -240,7 +242,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
var ( var (
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
_, updateTrackTime = stmt.Get("gorm:update_track_time") _, updateTrackTime = stmt.Get("gorm:update_track_time")
curTime = stmt.DB.NowFunc()
isZero bool isZero bool
) )
stmt.Settings.Delete("gorm:update_track_time") stmt.Settings.Delete("gorm:update_track_time")
@ -352,13 +353,27 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
if field := stmt.Schema.LookUpField(column.Name); field != nil { if field := stmt.Schema.LookUpField(column.Name); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
if field.AutoUpdateTime > 0 {
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
switch field.AutoUpdateTime {
case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond:
assignment.Value = curTime.UnixNano() / 1e6
case schema.UnixSecond:
assignment.Value = curTime.Unix()
}
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
} else {
columns = append(columns, column.Name) columns = append(columns, column.Name)
} }
} }
} }
} }
}
onConflict.DoUpdates = clause.AssignmentColumns(columns) onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
// use primary fields as default OnConflict columns // use primary fields as default OnConflict columns
if len(onConflict.Columns) == 0 { if len(onConflict.Columns) == 0 {

View File

@ -21,9 +21,10 @@ type TimeType int64
var TimeReflectType = reflect.TypeOf(time.Time{}) var TimeReflectType = reflect.TypeOf(time.Time{})
const ( const (
UnixSecond TimeType = 1 UnixTime TimeType = 1
UnixMillisecond TimeType = 2 UnixSecond TimeType = 2
UnixNanosecond TimeType = 3 UnixMillisecond TimeType = 3
UnixNanosecond TimeType = 4
) )
const ( const (
@ -251,7 +252,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if strings.ToUpper(v) == "NANO" { if field.DataType == Time {
field.AutoCreateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond field.AutoCreateTime = UnixNanosecond
} else if strings.ToUpper(v) == "MILLI" { } else if strings.ToUpper(v) == "MILLI" {
field.AutoCreateTime = UnixMillisecond field.AutoCreateTime = UnixMillisecond
@ -261,7 +264,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if strings.ToUpper(v) == "NANO" { if field.DataType == Time {
field.AutoUpdateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
field.AutoUpdateTime = UnixNanosecond field.AutoUpdateTime = UnixNanosecond
} else if strings.ToUpper(v) == "MILLI" { } else if strings.ToUpper(v) == "MILLI" {
field.AutoUpdateTime = UnixMillisecond field.AutoUpdateTime = UnixMillisecond

View File

@ -66,6 +66,26 @@ func TestUpsert(t *testing.T) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
} }
} }
var user = *GetUser("upsert_on_conflict", Config{})
user.Age = 20
if err := DB.Create(&user).Error; err != nil {
t.Errorf("failed to create user, got error %v", err)
}
var user2 User
DB.First(&user2, user.ID)
user2.Age = 30
time.Sleep(time.Second)
if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&user2).Error; err != nil {
t.Fatalf("failed to onconflict create user, got error %v", err)
} else {
var user3 User
DB.First(&user3, user.ID)
if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() {
t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt)
}
}
} }
func TestUpsertSlice(t *testing.T) { func TestUpsertSlice(t *testing.T) {
@ -152,29 +172,29 @@ func TestUpsertWithSave(t *testing.T) {
} }
} }
// lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"}
// if err := DB.Save(&lang).Error; err != nil { if err := DB.Save(&lang).Error; err != nil {
// t.Errorf("Failed to create, got error %v", err) t.Errorf("Failed to create, got error %v", err)
// } }
// var result Language var result Language
// if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
// t.Errorf("Failed to query lang, got error %v", err) t.Errorf("Failed to query lang, got error %v", err)
// } else { } else {
// AssertEqual(t, result, lang) AssertEqual(t, result, lang)
// } }
// lang.Name += "_new" lang.Name += "_new"
// if err := DB.Save(&lang).Error; err != nil { if err := DB.Save(&lang).Error; err != nil {
// t.Errorf("Failed to create, got error %v", err) t.Errorf("Failed to create, got error %v", err)
// } }
// var result2 Language var result2 Language
// if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil {
// t.Errorf("Failed to query lang, got error %v", err) t.Errorf("Failed to query lang, got error %v", err)
// } else { } else {
// AssertEqual(t, result2, lang) AssertEqual(t, result2, lang)
// } }
} }
func TestFindOrInitialize(t *testing.T) { func TestFindOrInitialize(t *testing.T) {