Test UpdateColumn

This commit is contained in:
Jinzhu 2020-05-30 17:34:22 +08:00
parent 028c9d6e17
commit 9dd516a7e8
4 changed files with 67 additions and 11 deletions

View File

@ -141,9 +141,6 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
for _, k := range keys { for _, k := range keys {
if field := stmt.Schema.LookUpField(k); field != nil { if field := stmt.Schema.LookUpField(k); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if field.AutoUpdateTime > 0 {
value[k] = time.Now()
}
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
assignValue(field, value[k]) assignValue(field, value[k])
} }
@ -152,11 +149,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
for _, field := range stmt.Schema.FieldsByDBName { if !stmt.DisableUpdateTime {
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { for _, field := range stmt.Schema.FieldsByDBName {
now := time.Now() if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) now := time.Now()
assignValue(field, now) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
assignValue(field, now)
}
} }
} }
default: default:
@ -167,9 +166,11 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) { if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
value, isZero := field.ValueOf(stmt.ReflectValue) value, isZero := field.ValueOf(stmt.ReflectValue)
if field.AutoUpdateTime > 0 { if !stmt.DisableUpdateTime {
value = time.Now() if field.AutoUpdateTime > 0 {
isZero = false value = time.Now()
isZero = false
}
} }
if ok || !isZero { if ok || !isZero {

View File

@ -207,6 +207,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) {
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
tx.Statement.DisableUpdateTime = true
tx.callbacks.Update().Execute(tx) tx.callbacks.Update().Execute(tx)
return return
} }
@ -214,6 +215,7 @@ func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
func (db *DB) UpdateColumns(values interface{}) (tx *DB) { func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
tx.Statement.DisableUpdateTime = true
tx.callbacks.Update().Execute(tx) tx.callbacks.Update().Execute(tx)
return return
} }

View File

@ -32,6 +32,7 @@ type Statement struct {
Schema *schema.Schema Schema *schema.Schema
Context context.Context Context context.Context
RaiseErrorOnNotFound bool RaiseErrorOnNotFound bool
DisableUpdateTime bool
SQL strings.Builder SQL strings.Builder
Vars []interface{} Vars []interface{}
NamedVars []sql.NamedArg NamedVars []sql.NamedArg

View File

@ -159,3 +159,55 @@ func TestUpdates(t *testing.T) {
user3.Age += 100 user3.Age += 100
AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt) AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt)
} }
func TestUpdateColumn(t *testing.T) {
var users = []*User{
GetUser("update_column_01", Config{}),
GetUser("update_column_02", Config{}),
}
DB.Create(&users)
lastUpdatedAt := users[1].UpdatedAt
// update with map
DB.Model(users[1]).UpdateColumns(map[string]interface{}{"name": "update_column_02_newname", "age": 100})
if users[1].Name != "update_column_02_newname" || users[1].Age != 100 {
t.Errorf("user 2 should be updated with update column")
}
AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano())
// user2 should not be updated
var user1, user2 User
DB.First(&user1, users[0].ID)
DB.First(&user2, users[1].ID)
CheckUser(t, user1, *users[0])
CheckUser(t, user2, *users[1])
DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew")
AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano())
if users[1].Name != "update_column_02_newnew" {
t.Errorf("user 2's name should be updated, but got %v", users[1].Name)
}
DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50"))
var user3 User
DB.First(&user3, users[1].ID)
users[1].Age += 50
CheckUser(t, user3, *users[1])
// update with struct
DB.Model(users[1]).UpdateColumns(User{Name: "update_column_02_newnew2", Age: 200})
if users[1].Name != "update_column_02_newnew2" || users[1].Age != 200 {
t.Errorf("user 2 should be updated with update column")
}
AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano())
// user2 should not be updated
var user5, user6 User
DB.First(&user5, users[0].ID)
DB.First(&user6, users[1].ID)
CheckUser(t, user5, *users[0])
CheckUser(t, user6, *users[1])
}