mirror of https://github.com/go-gorm/gorm.git
Update updated_at when upserting with Create OnConflict
This commit is contained in:
parent
12bbde89e6
commit
da16a8aac6
|
@ -227,6 +227,8 @@ func AfterCreate(db *gorm.DB) {
|
|||
|
||||
// ConvertToCreateValues convert to create values
|
||||
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
curTime := stmt.DB.NowFunc()
|
||||
|
||||
switch value := stmt.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, value)
|
||||
|
@ -240,7 +242,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||
var (
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
_, updateTrackTime = stmt.Get("gorm:update_track_time")
|
||||
curTime = stmt.DB.NowFunc()
|
||||
isZero bool
|
||||
)
|
||||
stmt.Settings.Delete("gorm:update_track_time")
|
||||
|
@ -352,13 +353,27 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
|
||||
columns = append(columns, column.Name)
|
||||
if field.AutoUpdateTime > 0 {
|
||||
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
|
||||
switch field.AutoUpdateTime {
|
||||
case schema.UnixNanosecond:
|
||||
assignment.Value = curTime.UnixNano()
|
||||
case schema.UnixMillisecond:
|
||||
assignment.Value = curTime.UnixNano() / 1e6
|
||||
case schema.UnixSecond:
|
||||
assignment.Value = curTime.Unix()
|
||||
}
|
||||
|
||||
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
|
||||
} else {
|
||||
columns = append(columns, column.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onConflict.DoUpdates = clause.AssignmentColumns(columns)
|
||||
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
|
||||
|
||||
// use primary fields as default OnConflict columns
|
||||
if len(onConflict.Columns) == 0 {
|
||||
|
|
|
@ -21,9 +21,10 @@ type TimeType int64
|
|||
var TimeReflectType = reflect.TypeOf(time.Time{})
|
||||
|
||||
const (
|
||||
UnixSecond TimeType = 1
|
||||
UnixMillisecond TimeType = 2
|
||||
UnixNanosecond TimeType = 3
|
||||
UnixTime TimeType = 1
|
||||
UnixSecond TimeType = 2
|
||||
UnixMillisecond TimeType = 3
|
||||
UnixNanosecond TimeType = 4
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -251,7 +252,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
}
|
||||
|
||||
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
||||
if strings.ToUpper(v) == "NANO" {
|
||||
if field.DataType == Time {
|
||||
field.AutoCreateTime = UnixTime
|
||||
} else if strings.ToUpper(v) == "NANO" {
|
||||
field.AutoCreateTime = UnixNanosecond
|
||||
} else if strings.ToUpper(v) == "MILLI" {
|
||||
field.AutoCreateTime = UnixMillisecond
|
||||
|
@ -261,7 +264,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
}
|
||||
|
||||
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
||||
if strings.ToUpper(v) == "NANO" {
|
||||
if field.DataType == Time {
|
||||
field.AutoUpdateTime = UnixTime
|
||||
} else if strings.ToUpper(v) == "NANO" {
|
||||
field.AutoUpdateTime = UnixNanosecond
|
||||
} else if strings.ToUpper(v) == "MILLI" {
|
||||
field.AutoUpdateTime = UnixMillisecond
|
||||
|
|
|
@ -66,6 +66,26 @@ func TestUpsert(t *testing.T) {
|
|||
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||
}
|
||||
}
|
||||
|
||||
var user = *GetUser("upsert_on_conflict", Config{})
|
||||
user.Age = 20
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Errorf("failed to create user, got error %v", err)
|
||||
}
|
||||
|
||||
var user2 User
|
||||
DB.First(&user2, user.ID)
|
||||
user2.Age = 30
|
||||
time.Sleep(time.Second)
|
||||
if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&user2).Error; err != nil {
|
||||
t.Fatalf("failed to onconflict create user, got error %v", err)
|
||||
} else {
|
||||
var user3 User
|
||||
DB.First(&user3, user.ID)
|
||||
if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() {
|
||||
t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertSlice(t *testing.T) {
|
||||
|
@ -152,29 +172,29 @@ func TestUpsertWithSave(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"}
|
||||
// if err := DB.Save(&lang).Error; err != nil {
|
||||
// t.Errorf("Failed to create, got error %v", err)
|
||||
// }
|
||||
lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"}
|
||||
if err := DB.Save(&lang).Error; err != nil {
|
||||
t.Errorf("Failed to create, got error %v", err)
|
||||
}
|
||||
|
||||
// var result Language
|
||||
// if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
|
||||
// t.Errorf("Failed to query lang, got error %v", err)
|
||||
// } else {
|
||||
// AssertEqual(t, result, lang)
|
||||
// }
|
||||
var result Language
|
||||
if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
|
||||
t.Errorf("Failed to query lang, got error %v", err)
|
||||
} else {
|
||||
AssertEqual(t, result, lang)
|
||||
}
|
||||
|
||||
// lang.Name += "_new"
|
||||
// if err := DB.Save(&lang).Error; err != nil {
|
||||
// t.Errorf("Failed to create, got error %v", err)
|
||||
// }
|
||||
lang.Name += "_new"
|
||||
if err := DB.Save(&lang).Error; err != nil {
|
||||
t.Errorf("Failed to create, got error %v", err)
|
||||
}
|
||||
|
||||
// var result2 Language
|
||||
// if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil {
|
||||
// t.Errorf("Failed to query lang, got error %v", err)
|
||||
// } else {
|
||||
// AssertEqual(t, result2, lang)
|
||||
// }
|
||||
var result2 Language
|
||||
if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil {
|
||||
t.Errorf("Failed to query lang, got error %v", err)
|
||||
} else {
|
||||
AssertEqual(t, result2, lang)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindOrInitialize(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue