diff --git a/soft_delete.go b/soft_delete.go index cb56035d..bdbf03c2 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -104,7 +104,9 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.String() == "" { - stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}}) + curTime := stmt.DB.NowFunc() + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) + stmt.SetColumn(sd.Field.DBName, curTime, true) if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) diff --git a/statement.go b/statement.go index a0da0c6d..355a5f0b 100644 --- a/statement.go +++ b/statement.go @@ -447,9 +447,15 @@ func (stmt *Statement) clone() *Statement { // Helpers // SetColumn set column's value -func (stmt *Statement) SetColumn(name string, value interface{}) { +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { if v, ok := stmt.Dest.(map[string]interface{}); ok { v[name] = value + } else if v, ok := stmt.Dest.([]map[string]interface{}); ok { + for _, m := range v { + m[name] = value + } } else if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { destValue := reflect.ValueOf(stmt.Dest) @@ -475,7 +481,13 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + if len(fromCallbacks) > 0 { + for i := 0; i < stmt.ReflectValue.Len(); i++ { + field.Set(stmt.ReflectValue.Index(i), value) + } + } else { + field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + } case reflect.Struct: field.Set(stmt.ReflectValue, value) } diff --git a/tests/delete_test.go b/tests/delete_test.go index 954c7097..37e29fbe 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -45,7 +45,7 @@ func TestDelete(t *testing.T) { } } - if err := DB.Delete(users[0]).Error; err != nil { + if err := DB.Delete(&users[0]).Error; err != nil { t.Errorf("errors happened when delete: %v", err) } diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index f1ea8a51..0dfe24d5 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "database/sql" "encoding/json" "errors" "regexp" @@ -29,6 +30,10 @@ func TestSoftDelete(t *testing.T) { t.Fatalf("No error should happen when soft delete user, but got %v", err) } + if sql.NullTime(user.DeletedAt).Time.IsZero() { + t.Fatalf("user's deleted at is zero") + } + sql := DB.Session(&gorm.Session{DryRun: true}).Delete(&user).Statement.SQL.String() if !regexp.MustCompile(`UPDATE .users. SET .deleted_at.=.* WHERE .users.\..id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql)