mirror of https://github.com/go-gorm/gorm.git
Add SoftDelete support
This commit is contained in:
parent
dca5244387
commit
55074213bc
|
@ -44,11 +44,7 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro
|
||||||
tx = association.DB.Model(out)
|
tx = association.DB.Model(out)
|
||||||
)
|
)
|
||||||
|
|
||||||
if association.Relationship.JoinTable != nil {
|
if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped {
|
||||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
|
||||||
tx.Clauses(queryClause)
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||||
ON: clause.Where{Exprs: queryConds},
|
ON: clause.Where{Exprs: queryConds},
|
||||||
|
@ -317,8 +313,10 @@ func (association *Association) Count() (count int64) {
|
||||||
)
|
)
|
||||||
|
|
||||||
if association.Relationship.JoinTable != nil {
|
if association.Relationship.JoinTable != nil {
|
||||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
if !tx.Statement.Unscoped {
|
||||||
tx.Clauses(queryClause)
|
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||||
|
tx.Clauses(queryClause)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||||
|
|
|
@ -46,12 +46,21 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
return CreateWithReturning
|
return CreateWithReturning
|
||||||
} else {
|
} else {
|
||||||
return func(db *gorm.DB) {
|
return func(db *gorm.DB) {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||||
Table: clause.Table{Name: db.Statement.Table},
|
for _, c := range db.Statement.Schema.CreateClauses {
|
||||||
})
|
db.Statement.AddClause(c)
|
||||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Statement.SQL.String() == "" {
|
||||||
|
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||||
|
Table: clause.Table{Name: db.Statement.Table},
|
||||||
|
})
|
||||||
|
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||||
|
|
||||||
|
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
||||||
|
}
|
||||||
|
|
||||||
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -88,12 +97,20 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateWithReturning(db *gorm.DB) {
|
func CreateWithReturning(db *gorm.DB) {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||||
Table: clause.Table{Name: db.Statement.Table},
|
for _, c := range db.Statement.Schema.CreateClauses {
|
||||||
})
|
db.Statement.AddClause(c)
|
||||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
}
|
||||||
|
}
|
||||||
|
|
||||||
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
if db.Statement.SQL.String() == "" {
|
||||||
|
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||||
|
Table: clause.Table{Name: db.Statement.Table},
|
||||||
|
})
|
||||||
|
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||||
|
|
||||||
|
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
||||||
|
}
|
||||||
|
|
||||||
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
|
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
|
||||||
db.Statement.WriteString(" RETURNING ")
|
db.Statement.WriteString(" RETURNING ")
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package callbacks
|
package callbacks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
@ -34,26 +35,30 @@ func BeforeDelete(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Delete(db *gorm.DB) {
|
func Delete(db *gorm.DB) {
|
||||||
|
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||||
|
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||||
|
db.Statement.AddClause(c)
|
||||||
|
fmt.Println(db.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if db.Statement.SQL.String() == "" {
|
if db.Statement.SQL.String() == "" {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
||||||
|
|
||||||
values := []reflect.Value{db.Statement.ReflectValue}
|
|
||||||
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
|
||||||
values = append(values, reflect.ValueOf(db.Statement.Model))
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement.Schema != nil {
|
if db.Statement.Schema != nil {
|
||||||
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
|
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
|
||||||
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
db.Where(clause.IN{Column: column, Values: values})
|
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||||
} else if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
}
|
||||||
|
|
||||||
|
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
||||||
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
db.Where(clause.IN{Column: column, Values: values})
|
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func Query(db *gorm.DB) {
|
func Query(db *gorm.DB) {
|
||||||
|
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||||
|
for _, c := range db.Statement.Schema.QueryClauses {
|
||||||
|
db.Statement.AddClause(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if db.Statement.SQL.String() == "" {
|
if db.Statement.SQL.String() == "" {
|
||||||
clauseSelect := clause.Select{}
|
clauseSelect := clause.Select{}
|
||||||
|
|
||||||
|
|
|
@ -44,13 +44,21 @@ func BeforeUpdate(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Update(db *gorm.DB) {
|
func Update(db *gorm.DB) {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||||
db.Statement.AddClause(set)
|
db.Statement.AddClause(c)
|
||||||
} else {
|
}
|
||||||
return
|
}
|
||||||
|
|
||||||
|
if db.Statement.SQL.String() == "" {
|
||||||
|
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||||
|
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||||
|
db.Statement.AddClause(set)
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
db.Statement.Build("UPDATE", "SET", "WHERE")
|
||||||
}
|
}
|
||||||
db.Statement.Build("UPDATE", "SET", "WHERE")
|
|
||||||
|
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
|
|
|
@ -238,6 +238,7 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
|
||||||
|
|
||||||
func (db *DB) Unscoped() (tx *DB) {
|
func (db *DB) Unscoped() (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
tx.Statement.Unscoped = true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
2
model.go
2
model.go
|
@ -11,5 +11,5 @@ type Model struct {
|
||||||
ID uint `gorm:"primarykey"`
|
ID uint `gorm:"primarykey"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
DeletedAt *time.Time `gorm:"index"`
|
DeletedAt DeletedAt `gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,6 +86,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldValue := reflect.New(field.IndirectFieldType)
|
fieldValue := reflect.New(field.IndirectFieldType)
|
||||||
|
|
||||||
|
if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok {
|
||||||
|
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok {
|
||||||
|
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok {
|
||||||
|
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok {
|
||||||
|
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...)
|
||||||
|
}
|
||||||
|
|
||||||
// if field is valuer, used its value or first fields as data type
|
// if field is valuer, used its value or first fields as data type
|
||||||
if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf {
|
if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf {
|
||||||
var overrideFieldValue bool
|
var overrideFieldValue bool
|
||||||
|
@ -283,6 +300,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
ef.TagSettings[k] = v
|
ef.TagSettings[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...)
|
||||||
|
field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...)
|
||||||
|
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...)
|
||||||
|
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return field
|
return field
|
||||||
|
|
|
@ -42,6 +42,22 @@ type Schema struct {
|
||||||
cacheStore *sync.Map
|
cacheStore *sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CreateClausesInterface interface {
|
||||||
|
CreateClauses() []clause.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryClausesInterface interface {
|
||||||
|
QueryClauses() []clause.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateClausesInterface interface {
|
||||||
|
UpdateClauses() []clause.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeleteClausesInterface interface {
|
||||||
|
DeleteClauses() []clause.Interface
|
||||||
|
}
|
||||||
|
|
||||||
func (schema Schema) String() string {
|
func (schema Schema) String() string {
|
||||||
if schema.ModelType.Name() == "" {
|
if schema.ModelType.Name() == "" {
|
||||||
return fmt.Sprintf("%v(%v)", schema.Name, schema.Table)
|
return fmt.Sprintf("%v(%v)", schema.Name, schema.Table)
|
||||||
|
|
|
@ -0,0 +1,86 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm/clause"
|
||||||
|
"github.com/jinzhu/gorm/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DeletedAt sql.NullTime
|
||||||
|
|
||||||
|
// Scan implements the Scanner interface.
|
||||||
|
func (n *DeletedAt) Scan(value interface{}) error {
|
||||||
|
return (*sql.NullTime)(n).Scan(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the driver Valuer interface.
|
||||||
|
func (n DeletedAt) Value() (driver.Value, error) {
|
||||||
|
if !n.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return n.Time, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (DeletedAt) QueryClauses() []clause.Interface {
|
||||||
|
return []clause.Interface{
|
||||||
|
clause.Where{Exprs: []clause.Expression{
|
||||||
|
clause.Eq{
|
||||||
|
Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"},
|
||||||
|
Value: nil,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (DeletedAt) DeleteClauses() []clause.Interface {
|
||||||
|
return []clause.Interface{SoftDeleteClause{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type SoftDeleteClause struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (SoftDeleteClause) Name() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (SoftDeleteClause) Build(clause.Builder) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (SoftDeleteClause) MergeClause(*clause.Clause) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
|
||||||
|
if stmt.SQL.String() == "" {
|
||||||
|
stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: time.Now()}})
|
||||||
|
|
||||||
|
if stmt.Schema != nil {
|
||||||
|
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||||
|
column, values := schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
|
if len(values) > 0 {
|
||||||
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||||
|
}
|
||||||
|
|
||||||
|
if stmt.Dest != stmt.Model && stmt.Model != nil {
|
||||||
|
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
||||||
|
column, values = schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
|
if len(values) > 0 {
|
||||||
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := stmt.Clauses["WHERE"]; !ok {
|
||||||
|
stmt.DB.AddError(ErrMissingWhereClause)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt.AddClauseIfNotExists(clause.Update{})
|
||||||
|
stmt.Build("UPDATE", "SET", "WHERE")
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,6 +19,7 @@ type Statement struct {
|
||||||
*DB
|
*DB
|
||||||
Table string
|
Table string
|
||||||
Model interface{}
|
Model interface{}
|
||||||
|
Unscoped bool
|
||||||
Dest interface{}
|
Dest interface{}
|
||||||
ReflectValue reflect.Value
|
ReflectValue reflect.Value
|
||||||
Clauses map[string]clause.Clause
|
Clauses map[string]clause.Clause
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/jinzhu/gorm/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSoftDelete(t *testing.T) {
|
||||||
|
user := *GetUser("SoftDelete", Config{})
|
||||||
|
DB.Save(&user)
|
||||||
|
if err := DB.Delete(&user).Error; err != nil {
|
||||||
|
t.Fatalf("No error should happen when soft delete user, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.First(&User{}, "name = ?", user.Name).Error == nil {
|
||||||
|
t.Errorf("Can't find a soft deleted record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Unscoped().Delete(&user)
|
||||||
|
if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() {
|
||||||
|
t.Errorf("Can't find permanently deleted record")
|
||||||
|
}
|
||||||
|
}
|
|
@ -12,6 +12,7 @@ func TestFindOrInitialize(t *testing.T) {
|
||||||
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {
|
if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil {
|
||||||
t.Errorf("no error should happen when FirstOrInit, but got %v", err)
|
t.Errorf("no error should happen when FirstOrInit, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 {
|
if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 {
|
||||||
t.Errorf("user should be initialized with search value")
|
t.Errorf("user should be initialized with search value")
|
||||||
}
|
}
|
||||||
|
@ -55,7 +56,10 @@ func TestFindOrInitialize(t *testing.T) {
|
||||||
|
|
||||||
func TestFindOrCreate(t *testing.T) {
|
func TestFindOrCreate(t *testing.T) {
|
||||||
var user1, user2, user3, user4, user5, user6, user7, user8 User
|
var user1, user2, user3, user4, user5, user6, user7, user8 User
|
||||||
DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1)
|
if err := DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when FirstOrInit, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 {
|
if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 {
|
||||||
t.Errorf("user should be created with search value")
|
t.Errorf("user should be created with search value")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue