diff --git a/association.go b/association.go index 82a2274e..027f327e 100644 --- a/association.go +++ b/association.go @@ -2,9 +2,11 @@ package gorm import ( "fmt" + "reflect" "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/utils" ) // Association Mode contains some helper methods to handle relationship things easily. @@ -46,6 +48,90 @@ func (association *Association) Replace(values ...interface{}) error { } func (association *Association) Delete(values ...interface{}) error { + if association.Error == nil { + var ( + tx = association.DB + rel = association.Relationship + reflectValue = tx.Statement.ReflectValue + conds = rel.ToQueryConditions(reflectValue) + relFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if rel.JoinTable == nil || !ref.OwnPrimaryKey { + if ref.OwnPrimaryKey { + relFields = append(relFields, ref.ForeignKey) + } else { + relFields = append(relFields, ref.PrimaryKey) + } + + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateAttrs[ref.ForeignKey.DBName] = nil + } + } + } + + relValuesMap, relQueryValues := schema.GetIdentityFieldValuesMapFromValues(values, relFields) + column, values := schema.ToQueryValues(foreignKeys, relQueryValues) + tx.Where(clause.IN{Column: column, Values: values}) + + switch association.Relationship.Type { + case schema.HasOne, schema.HasMany: + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) + case schema.BelongsTo: + tx.Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) + case schema.Many2Many: + modelValue := reflect.New(rel.JoinTable.ModelType).Interface() + tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) + } + + if tx.Error == nil { + cleanUpDeletedRelations := func(data reflect.Value) { + if _, zero := rel.Field.ValueOf(data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + + fieldValues := make([]reflect.Value, len(relFields)) + switch fieldValue.Kind() { + case reflect.Slice, reflect.Array: + validFieldValues := reflect.Zero(rel.Field.FieldType) + for i := 0; i < fieldValue.Len(); i++ { + for idx, field := range relFields { + fieldValues[idx] = field.ReflectValueOf(fieldValue.Index(i)) + } + + if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; !ok { + validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i)) + } + } + + rel.Field.Set(data, validFieldValues) + case reflect.Struct: + for idx, field := range relFields { + fieldValues[idx] = field.ReflectValueOf(data) + } + if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { + rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType)) + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i))) + } + case reflect.Struct: + cleanUpDeletedRelations(reflectValue) + } + } else { + association.Error = tx.Error + } + } return association.Error } @@ -61,6 +147,10 @@ func (association *Association) Count() (count int) { ) if association.Relationship.JoinTable != nil { + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + tx.Clauses(queryClause) + } + tx.Clauses(clause.From{Joins: []clause.Join{{ Table: clause.Table{Name: association.Relationship.JoinTable.Table}, ON: clause.Where{Exprs: conds}, diff --git a/schema/utils.go b/schema/utils.go index 7a26332d..72bd149c 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -128,6 +128,21 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map return dataResults, results } +// GetIdentityFieldValuesMapFromValues get identity map from fields +func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + resultsMap := map[string][]reflect.Value{} + results := [][]interface{}{} + + for _, v := range values { + rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) + for k, v := range rm { + resultsMap[k] = append(resultsMap[k], v...) + } + results = append(results, rs...) + } + return resultsMap, results +} + // ToQueryValues to query values func ToQueryValues(foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { queryValues := make([]interface{}, len(foreignValues))