Add SoftDelete support

This commit is contained in:
Jinzhu 2020-05-29 07:35:45 +08:00
parent dca5244387
commit 55074213bc
13 changed files with 225 additions and 33 deletions

View File

@ -44,11 +44,7 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro
tx = association.DB.Model(out) tx = association.DB.Model(out)
) )
if association.Relationship.JoinTable != nil { if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped {
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
tx.Clauses(queryClause)
}
tx.Clauses(clause.From{Joins: []clause.Join{{ tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table}, Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds}, ON: clause.Where{Exprs: queryConds},
@ -317,8 +313,10 @@ func (association *Association) Count() (count int64) {
) )
if association.Relationship.JoinTable != nil { if association.Relationship.JoinTable != nil {
for _, queryClause := range association.Relationship.JoinTable.QueryClauses { if !tx.Statement.Unscoped {
tx.Clauses(queryClause) for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
tx.Clauses(queryClause)
}
} }
tx.Clauses(clause.From{Joins: []clause.Join{{ tx.Clauses(clause.From{Joins: []clause.Join{{

View File

@ -46,12 +46,21 @@ func Create(config *Config) func(db *gorm.DB) {
return CreateWithReturning return CreateWithReturning
} else { } else {
return func(db *gorm.DB) { return func(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Insert{ if db.Statement.Schema != nil && !db.Statement.Unscoped {
Table: clause.Table{Name: db.Statement.Table}, for _, c := range db.Statement.Schema.CreateClauses {
}) db.Statement.AddClause(c)
db.Statement.AddClause(ConvertToCreateValues(db.Statement)) }
}
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Insert{
Table: clause.Table{Name: db.Statement.Table},
})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
}
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil { if err == nil {
@ -88,12 +97,20 @@ func Create(config *Config) func(db *gorm.DB) {
} }
func CreateWithReturning(db *gorm.DB) { func CreateWithReturning(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Insert{ if db.Statement.Schema != nil && !db.Statement.Unscoped {
Table: clause.Table{Name: db.Statement.Table}, for _, c := range db.Statement.Schema.CreateClauses {
}) db.Statement.AddClause(c)
db.Statement.AddClause(ConvertToCreateValues(db.Statement)) }
}
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Insert{
Table: clause.Table{Name: db.Statement.Table},
})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
}
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
db.Statement.WriteString(" RETURNING ") db.Statement.WriteString(" RETURNING ")

View File

@ -1,6 +1,7 @@
package callbacks package callbacks
import ( import (
"fmt"
"reflect" "reflect"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -34,26 +35,30 @@ func BeforeDelete(db *gorm.DB) {
} }
func Delete(db *gorm.DB) { func Delete(db *gorm.DB) {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
fmt.Println(db.Statement.SQL.String())
}
}
if db.Statement.SQL.String() == "" { if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Delete{}) db.Statement.AddClauseIfNotExists(clause.Delete{})
values := []reflect.Value{db.Statement.ReflectValue}
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
values = append(values, reflect.ValueOf(db.Statement.Model))
}
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
db.Where(clause.IN{Column: column, Values: values}) db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
} else if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { }
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
db.Where(clause.IN{Column: column, Values: values}) db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
} }
} }
} }

View File

@ -12,6 +12,12 @@ import (
) )
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" { if db.Statement.SQL.String() == "" {
clauseSelect := clause.Select{} clauseSelect := clause.Select{}

View File

@ -44,13 +44,21 @@ func BeforeUpdate(db *gorm.DB) {
} }
func Update(db *gorm.DB) { func Update(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Update{}) if db.Statement.Schema != nil && !db.Statement.Unscoped {
if set := ConvertToAssignments(db.Statement); len(set) != 0 { for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(set) db.Statement.AddClause(c)
} else { }
return }
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else {
return
}
db.Statement.Build("UPDATE", "SET", "WHERE")
} }
db.Statement.Build("UPDATE", "SET", "WHERE")
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

View File

@ -238,6 +238,7 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
func (db *DB) Unscoped() (tx *DB) { func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Unscoped = true
return return
} }

View File

@ -11,5 +11,5 @@ type Model struct {
ID uint `gorm:"primarykey"` ID uint `gorm:"primarykey"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"` DeletedAt DeletedAt `gorm:"index"`
} }

View File

@ -86,6 +86,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
fieldValue := reflect.New(field.IndirectFieldType) fieldValue := reflect.New(field.IndirectFieldType)
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...)
}
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...)
}
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...)
}
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...)
}
// if field is valuer, used its value or first fields as data type // if field is valuer, used its value or first fields as data type
if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf { if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf {
var overrideFieldValue bool var overrideFieldValue bool
@ -283,6 +300,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.TagSettings[k] = v ef.TagSettings[k] = v
} }
} }
field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...)
field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...)
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...)
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...)
} }
return field return field

View File

@ -42,6 +42,22 @@ type Schema struct {
cacheStore *sync.Map cacheStore *sync.Map
} }
type CreateClausesInterface interface {
CreateClauses() []clause.Interface
}
type QueryClausesInterface interface {
QueryClauses() []clause.Interface
}
type UpdateClausesInterface interface {
UpdateClauses() []clause.Interface
}
type DeleteClausesInterface interface {
DeleteClauses() []clause.Interface
}
func (schema Schema) String() string { func (schema Schema) String() string {
if schema.ModelType.Name() == "" { if schema.ModelType.Name() == "" {
return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) return fmt.Sprintf("%v(%v)", schema.Name, schema.Table)

86
soft_delete.go Normal file
View File

@ -0,0 +1,86 @@
package gorm
import (
"database/sql"
"database/sql/driver"
"reflect"
"time"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
)
type DeletedAt sql.NullTime
// Scan implements the Scanner interface.
func (n *DeletedAt) Scan(value interface{}) error {
return (*sql.NullTime)(n).Scan(value)
}
// Value implements the driver Valuer interface.
func (n DeletedAt) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Time, nil
}
func (DeletedAt) QueryClauses() []clause.Interface {
return []clause.Interface{
clause.Where{Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"},
Value: nil,
},
}},
}
}
func (DeletedAt) DeleteClauses() []clause.Interface {
return []clause.Interface{SoftDeleteClause{}}
}
type SoftDeleteClause struct {
}
func (SoftDeleteClause) Name() string {
return ""
}
func (SoftDeleteClause) Build(clause.Builder) {
}
func (SoftDeleteClause) MergeClause(*clause.Clause) {
}
func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.String() == "" {
stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: time.Now()}})
if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
if stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
}
}
if _, ok := stmt.Clauses["WHERE"]; !ok {
stmt.DB.AddError(ErrMissingWhereClause)
return
}
stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build("UPDATE", "SET", "WHERE")
}
}

View File

@ -19,6 +19,7 @@ type Statement struct {
*DB *DB
Table string Table string
Model interface{} Model interface{}
Unscoped bool
Dest interface{} Dest interface{}
ReflectValue reflect.Value ReflectValue reflect.Value
Clauses map[string]clause.Clause Clauses map[string]clause.Clause

28
tests/soft_delete_test.go Normal file
View File

@ -0,0 +1,28 @@
package tests_test
import (
"testing"
. "github.com/jinzhu/gorm/tests"
)
func TestSoftDelete(t *testing.T) {
user := *GetUser("SoftDelete", Config{})
DB.Save(&user)
if err := DB.Delete(&user).Error; err != nil {
t.Fatalf("No error should happen when soft delete user, but got %v", err)
}
if DB.First(&User{}, "name = ?", user.Name).Error == nil {
t.Errorf("Can't find a soft deleted record")
}
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err)
}
DB.Unscoped().Delete(&user)
if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() {
t.Errorf("Can't find permanently deleted record")
}
}

View File

@ -12,6 +12,7 @@ func TestFindOrInitialize(t *testing.T) {
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {
t.Errorf("no error should happen when FirstOrInit, but got %v", err) t.Errorf("no error should happen when FirstOrInit, but got %v", err)
} }
if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 { if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 {
t.Errorf("user should be initialized with search value") t.Errorf("user should be initialized with search value")
} }
@ -55,7 +56,10 @@ func TestFindOrInitialize(t *testing.T) {
func TestFindOrCreate(t *testing.T) { func TestFindOrCreate(t *testing.T) {
var user1, user2, user3, user4, user5, user6, user7, user8 User var user1, user2, user3, user4, user5, user6, user7, user8 User
DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) if err := DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1).Error; err != nil {
t.Errorf("no error should happen when FirstOrInit, but got %v", err)
}
if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 { if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 {
t.Errorf("user should be created with search value") t.Errorf("user should be created with search value")
} }