mirror of https://github.com/go-gorm/gorm.git
Update updated_at when upserting with Create OnConflict
This commit is contained in:
parent
12bbde89e6
commit
da16a8aac6
|
@ -227,6 +227,8 @@ func AfterCreate(db *gorm.DB) {
|
||||||
|
|
||||||
// ConvertToCreateValues convert to create values
|
// ConvertToCreateValues convert to create values
|
||||||
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
|
curTime := stmt.DB.NowFunc()
|
||||||
|
|
||||||
switch value := stmt.Dest.(type) {
|
switch value := stmt.Dest.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
values = ConvertMapToValuesForCreate(stmt, value)
|
values = ConvertMapToValuesForCreate(stmt, value)
|
||||||
|
@ -240,7 +242,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
var (
|
var (
|
||||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||||
_, updateTrackTime = stmt.Get("gorm:update_track_time")
|
_, updateTrackTime = stmt.Get("gorm:update_track_time")
|
||||||
curTime = stmt.DB.NowFunc()
|
|
||||||
isZero bool
|
isZero bool
|
||||||
)
|
)
|
||||||
stmt.Settings.Delete("gorm:update_track_time")
|
stmt.Settings.Delete("gorm:update_track_time")
|
||||||
|
@ -352,13 +353,27 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
|
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
|
||||||
|
if field.AutoUpdateTime > 0 {
|
||||||
|
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
|
||||||
|
switch field.AutoUpdateTime {
|
||||||
|
case schema.UnixNanosecond:
|
||||||
|
assignment.Value = curTime.UnixNano()
|
||||||
|
case schema.UnixMillisecond:
|
||||||
|
assignment.Value = curTime.UnixNano() / 1e6
|
||||||
|
case schema.UnixSecond:
|
||||||
|
assignment.Value = curTime.Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
|
||||||
|
} else {
|
||||||
columns = append(columns, column.Name)
|
columns = append(columns, column.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onConflict.DoUpdates = clause.AssignmentColumns(columns)
|
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
|
||||||
|
|
||||||
// use primary fields as default OnConflict columns
|
// use primary fields as default OnConflict columns
|
||||||
if len(onConflict.Columns) == 0 {
|
if len(onConflict.Columns) == 0 {
|
||||||
|
|
|
@ -21,9 +21,10 @@ type TimeType int64
|
||||||
var TimeReflectType = reflect.TypeOf(time.Time{})
|
var TimeReflectType = reflect.TypeOf(time.Time{})
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UnixSecond TimeType = 1
|
UnixTime TimeType = 1
|
||||||
UnixMillisecond TimeType = 2
|
UnixSecond TimeType = 2
|
||||||
UnixNanosecond TimeType = 3
|
UnixMillisecond TimeType = 3
|
||||||
|
UnixNanosecond TimeType = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -251,7 +252,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
||||||
if strings.ToUpper(v) == "NANO" {
|
if field.DataType == Time {
|
||||||
|
field.AutoCreateTime = UnixTime
|
||||||
|
} else if strings.ToUpper(v) == "NANO" {
|
||||||
field.AutoCreateTime = UnixNanosecond
|
field.AutoCreateTime = UnixNanosecond
|
||||||
} else if strings.ToUpper(v) == "MILLI" {
|
} else if strings.ToUpper(v) == "MILLI" {
|
||||||
field.AutoCreateTime = UnixMillisecond
|
field.AutoCreateTime = UnixMillisecond
|
||||||
|
@ -261,7 +264,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
||||||
if strings.ToUpper(v) == "NANO" {
|
if field.DataType == Time {
|
||||||
|
field.AutoUpdateTime = UnixTime
|
||||||
|
} else if strings.ToUpper(v) == "NANO" {
|
||||||
field.AutoUpdateTime = UnixNanosecond
|
field.AutoUpdateTime = UnixNanosecond
|
||||||
} else if strings.ToUpper(v) == "MILLI" {
|
} else if strings.ToUpper(v) == "MILLI" {
|
||||||
field.AutoUpdateTime = UnixMillisecond
|
field.AutoUpdateTime = UnixMillisecond
|
||||||
|
|
|
@ -66,6 +66,26 @@ func TestUpsert(t *testing.T) {
|
||||||
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var user = *GetUser("upsert_on_conflict", Config{})
|
||||||
|
user.Age = 20
|
||||||
|
if err := DB.Create(&user).Error; err != nil {
|
||||||
|
t.Errorf("failed to create user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user2 User
|
||||||
|
DB.First(&user2, user.ID)
|
||||||
|
user2.Age = 30
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&user2).Error; err != nil {
|
||||||
|
t.Fatalf("failed to onconflict create user, got error %v", err)
|
||||||
|
} else {
|
||||||
|
var user3 User
|
||||||
|
DB.First(&user3, user.ID)
|
||||||
|
if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() {
|
||||||
|
t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpsertSlice(t *testing.T) {
|
func TestUpsertSlice(t *testing.T) {
|
||||||
|
@ -152,29 +172,29 @@ func TestUpsertWithSave(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"}
|
lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"}
|
||||||
// if err := DB.Save(&lang).Error; err != nil {
|
if err := DB.Save(&lang).Error; err != nil {
|
||||||
// t.Errorf("Failed to create, got error %v", err)
|
t.Errorf("Failed to create, got error %v", err)
|
||||||
// }
|
}
|
||||||
|
|
||||||
// var result Language
|
var result Language
|
||||||
// if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
|
if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
|
||||||
// t.Errorf("Failed to query lang, got error %v", err)
|
t.Errorf("Failed to query lang, got error %v", err)
|
||||||
// } else {
|
} else {
|
||||||
// AssertEqual(t, result, lang)
|
AssertEqual(t, result, lang)
|
||||||
// }
|
}
|
||||||
|
|
||||||
// lang.Name += "_new"
|
lang.Name += "_new"
|
||||||
// if err := DB.Save(&lang).Error; err != nil {
|
if err := DB.Save(&lang).Error; err != nil {
|
||||||
// t.Errorf("Failed to create, got error %v", err)
|
t.Errorf("Failed to create, got error %v", err)
|
||||||
// }
|
}
|
||||||
|
|
||||||
// var result2 Language
|
var result2 Language
|
||||||
// if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil {
|
if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil {
|
||||||
// t.Errorf("Failed to query lang, got error %v", err)
|
t.Errorf("Failed to query lang, got error %v", err)
|
||||||
// } else {
|
} else {
|
||||||
// AssertEqual(t, result2, lang)
|
AssertEqual(t, result2, lang)
|
||||||
// }
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindOrInitialize(t *testing.T) {
|
func TestFindOrInitialize(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue