diff --git a/association.go b/association.go index 5b777465..bed89837 100644 --- a/association.go +++ b/association.go @@ -44,11 +44,7 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro tx = association.DB.Model(out) ) - if association.Relationship.JoinTable != nil { - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - tx.Clauses(queryClause) - } - + if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped { tx.Clauses(clause.From{Joins: []clause.Join{{ Table: clause.Table{Name: association.Relationship.JoinTable.Table}, ON: clause.Where{Exprs: queryConds}, @@ -317,8 +313,10 @@ func (association *Association) Count() (count int64) { ) if association.Relationship.JoinTable != nil { - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - tx.Clauses(queryClause) + if !tx.Statement.Unscoped { + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + tx.Clauses(queryClause) + } } tx.Clauses(clause.From{Joins: []clause.Join{{ diff --git a/callbacks/create.go b/callbacks/create.go index 0b30775a..18f25c9a 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -46,12 +46,21 @@ func Create(config *Config) func(db *gorm.DB) { return CreateWithReturning } else { return func(db *gorm.DB) { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + } - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { @@ -88,12 +97,20 @@ func Create(config *Config) func(db *gorm.DB) { } func CreateWithReturning(db *gorm.DB) { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + } if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { db.Statement.WriteString(" RETURNING ") diff --git a/callbacks/delete.go b/callbacks/delete.go index a88edcf8..1c59afbe 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -1,6 +1,7 @@ package callbacks import ( + "fmt" "reflect" "github.com/jinzhu/gorm" @@ -34,26 +35,30 @@ func BeforeDelete(db *gorm.DB) { } func Delete(db *gorm.DB) { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + fmt.Println(db.Statement.SQL.String()) + } + } + if db.Statement.SQL.String() == "" { db.Statement.AddClauseIfNotExists(clause.Delete{}) - values := []reflect.Value{db.Statement.ReflectValue} - if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { - values = append(values, reflect.ValueOf(db.Statement.Model)) - } - if db.Statement.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { - db.Where(clause.IN{Column: column, Values: values}) - } else if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { - db.Where(clause.IN{Column: column, Values: values}) + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } } } diff --git a/callbacks/query.go b/callbacks/query.go index 84b9ed98..ee3f5c8d 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -12,6 +12,12 @@ import ( ) func Query(db *gorm.DB) { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.QueryClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.String() == "" { clauseSelect := clause.Select{} diff --git a/callbacks/update.go b/callbacks/update.go index f9b20981..f56aa22c 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -44,13 +44,21 @@ func BeforeUpdate(db *gorm.DB) { } func Update(db *gorm.DB) { - db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else { - return + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Update{}) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } + db.Statement.Build("UPDATE", "SET", "WHERE") } - db.Statement.Build("UPDATE", "SET", "WHERE") result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/chainable_api.go b/chainable_api.go index 8336b787..afcdccd2 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -238,6 +238,7 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() + tx.Statement.Unscoped = true return } diff --git a/model.go b/model.go index fdee99dc..dcc3cdc2 100644 --- a/model.go +++ b/model.go @@ -11,5 +11,5 @@ type Model struct { ID uint `gorm:"primarykey"` CreatedAt time.Time UpdatedAt time.Time - DeletedAt *time.Time `gorm:"index"` + DeletedAt DeletedAt `gorm:"index"` } diff --git a/schema/field.go b/schema/field.go index 8b8b190d..75ff71f6 100644 --- a/schema/field.go +++ b/schema/field.go @@ -86,6 +86,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) + + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...) + } + + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...) + } + + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...) + } + + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...) + } + // if field is valuer, used its value or first fields as data type if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf { var overrideFieldValue bool @@ -283,6 +300,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } + + field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) + field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) } return field diff --git a/schema/schema.go b/schema/schema.go index e66084a3..77b9832c 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -42,6 +42,22 @@ type Schema struct { cacheStore *sync.Map } +type CreateClausesInterface interface { + CreateClauses() []clause.Interface +} + +type QueryClausesInterface interface { + QueryClauses() []clause.Interface +} + +type UpdateClausesInterface interface { + UpdateClauses() []clause.Interface +} + +type DeleteClausesInterface interface { + DeleteClauses() []clause.Interface +} + func (schema Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) diff --git a/soft_delete.go b/soft_delete.go new file mode 100644 index 00000000..138c9c63 --- /dev/null +++ b/soft_delete.go @@ -0,0 +1,86 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "reflect" + "time" + + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" +) + +type DeletedAt sql.NullTime + +// Scan implements the Scanner interface. +func (n *DeletedAt) Scan(value interface{}) error { + return (*sql.NullTime)(n).Scan(value) +} + +// Value implements the driver Valuer interface. +func (n DeletedAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time, nil +} + +func (DeletedAt) QueryClauses() []clause.Interface { + return []clause.Interface{ + clause.Where{Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"}, + Value: nil, + }, + }}, + } +} + +func (DeletedAt) DeleteClauses() []clause.Interface { + return []clause.Interface{SoftDeleteClause{}} +} + +type SoftDeleteClause struct { +} + +func (SoftDeleteClause) Name() string { + return "" +} + +func (SoftDeleteClause) Build(clause.Builder) { +} + +func (SoftDeleteClause) MergeClause(*clause.Clause) { +} + +func (SoftDeleteClause) ModifyStatement(stmt *Statement) { + if stmt.SQL.String() == "" { + stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: time.Now()}}) + + if stmt.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if stmt.Dest != stmt.Model && stmt.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + column, values = schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + } + + if _, ok := stmt.Clauses["WHERE"]; !ok { + stmt.DB.AddError(ErrMissingWhereClause) + return + } + + stmt.AddClauseIfNotExists(clause.Update{}) + stmt.Build("UPDATE", "SET", "WHERE") + } +} diff --git a/statement.go b/statement.go index b110ac1b..626ca689 100644 --- a/statement.go +++ b/statement.go @@ -19,6 +19,7 @@ type Statement struct { *DB Table string Model interface{} + Unscoped bool Dest interface{} ReflectValue reflect.Value Clauses map[string]clause.Clause diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go new file mode 100644 index 00000000..f91052c1 --- /dev/null +++ b/tests/soft_delete_test.go @@ -0,0 +1,28 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestSoftDelete(t *testing.T) { + user := *GetUser("SoftDelete", Config{}) + DB.Save(&user) + if err := DB.Delete(&user).Error; err != nil { + t.Fatalf("No error should happen when soft delete user, but got %v", err) + } + + if DB.First(&User{}, "name = ?", user.Name).Error == nil { + t.Errorf("Can't find a soft deleted record") + } + + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) + } + + DB.Unscoped().Delete(&user) + if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { + t.Errorf("Can't find permanently deleted record") + } +} diff --git a/tests/upsert_test.go b/tests/upsert_test.go index bd540620..615ead95 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -12,6 +12,7 @@ func TestFindOrInitialize(t *testing.T) { if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { t.Errorf("no error should happen when FirstOrInit, but got %v", err) } + if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 { t.Errorf("user should be initialized with search value") } @@ -55,7 +56,10 @@ func TestFindOrInitialize(t *testing.T) { func TestFindOrCreate(t *testing.T) { var user1, user2, user3, user4, user5, user6, user7, user8 User - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) + if err := DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1).Error; err != nil { + t.Errorf("no error should happen when FirstOrInit, but got %v", err) + } + if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 { t.Errorf("user should be created with search value") }