From 9dd516a7e8aaccad326778abac631782f24689e1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 17:34:22 +0800 Subject: [PATCH] Test UpdateColumn --- callbacks/update.go | 23 ++++++++++---------- finisher_api.go | 2 ++ statement.go | 1 + tests/update_test.go | 52 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 11 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 7e8c0f3e..623d64fe 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -141,9 +141,6 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { for _, k := range keys { if field := stmt.Schema.LookUpField(k); field != nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if field.AutoUpdateTime > 0 { - value[k] = time.Now() - } set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) assignValue(field, value[k]) } @@ -152,11 +149,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - for _, field := range stmt.Schema.FieldsByDBName { - if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { - now := time.Now() - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) - assignValue(field, now) + if !stmt.DisableUpdateTime { + for _, field := range stmt.Schema.FieldsByDBName { + if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { + now := time.Now() + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + assignValue(field, now) + } } } default: @@ -167,9 +166,11 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(stmt.ReflectValue) - if field.AutoUpdateTime > 0 { - value = time.Now() - isZero = false + if !stmt.DisableUpdateTime { + if field.AutoUpdateTime > 0 { + value = time.Now() + isZero = false + } } if ok || !isZero { diff --git a/finisher_api.go b/finisher_api.go index c47e12af..f14bcfbe 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -207,6 +207,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} + tx.Statement.DisableUpdateTime = true tx.callbacks.Update().Execute(tx) return } @@ -214,6 +215,7 @@ func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values + tx.Statement.DisableUpdateTime = true tx.callbacks.Update().Execute(tx) return } diff --git a/statement.go b/statement.go index f81ae0e5..42df148a 100644 --- a/statement.go +++ b/statement.go @@ -32,6 +32,7 @@ type Statement struct { Schema *schema.Schema Context context.Context RaiseErrorOnNotFound bool + DisableUpdateTime bool SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg diff --git a/tests/update_test.go b/tests/update_test.go index cb61b40e..371a9f78 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -159,3 +159,55 @@ func TestUpdates(t *testing.T) { user3.Age += 100 AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt) } + +func TestUpdateColumn(t *testing.T) { + var users = []*User{ + GetUser("update_column_01", Config{}), + GetUser("update_column_02", Config{}), + } + + DB.Create(&users) + lastUpdatedAt := users[1].UpdatedAt + + // update with map + DB.Model(users[1]).UpdateColumns(map[string]interface{}{"name": "update_column_02_newname", "age": 100}) + if users[1].Name != "update_column_02_newname" || users[1].Age != 100 { + t.Errorf("user 2 should be updated with update column") + } + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + // user2 should not be updated + var user1, user2 User + DB.First(&user1, users[0].ID) + DB.First(&user2, users[1].ID) + CheckUser(t, user1, *users[0]) + CheckUser(t, user2, *users[1]) + + DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew") + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + if users[1].Name != "update_column_02_newnew" { + t.Errorf("user 2's name should be updated, but got %v", users[1].Name) + } + + DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50")) + var user3 User + DB.First(&user3, users[1].ID) + + users[1].Age += 50 + CheckUser(t, user3, *users[1]) + + // update with struct + DB.Model(users[1]).UpdateColumns(User{Name: "update_column_02_newnew2", Age: 200}) + if users[1].Name != "update_column_02_newnew2" || users[1].Age != 200 { + t.Errorf("user 2 should be updated with update column") + } + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + // user2 should not be updated + var user5, user6 User + DB.First(&user5, users[0].ID) + DB.First(&user6, users[1].ID) + CheckUser(t, user5, *users[0]) + CheckUser(t, user6, *users[1]) +}