diff --git a/association.go b/association.go index 017e739c..19fef6a0 100644 --- a/association.go +++ b/association.go @@ -81,60 +81,6 @@ func (association *Association) Append(values ...interface{}) *Association { return association } -func (association *Association) Delete(values ...interface{}) *Association { - scope := association.Scope - query := scope.NewDB() - relationship := association.Field.Relationship - - // many to many - if relationship.Kind == "many_to_many" { - // current value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - // deleting value's foreign keys - primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)) - query = query.Where(sql, toQueryValues(primaryKeys)...) - - if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { - leftValues := reflect.Zero(association.Field.Field.Type()) - for i := 0; i < association.Field.Field.Len(); i++ { - reflectValue := association.Field.Field.Index(i) - primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0] - var included = false - for _, pk := range primaryKeys { - if equalAsString(primaryKey, pk) { - included = true - } - } - if !included { - leftValues = reflect.Append(leftValues, reflectValue) - } - } - association.Field.Set(leftValues) - } - } else { - association.Field.Set(reflect.Zero(association.Field.Field.Type())) - - if relationship.Kind == "belongs_to" { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)) - query.Model(scope.Value).Where(sql, toQueryValues(primaryKeys)...).Update(foreignKeyMap) - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // query.Model(association.Field).UpdateColumn(foreignKeyMap) - } - } - return association -} - func (association *Association) Replace(values ...interface{}) *Association { relationship := association.Field.Relationship scope := association.Scope @@ -185,11 +131,66 @@ func (association *Association) Replace(values ...interface{}) *Association { return association } +func (association *Association) Delete(values ...interface{}) *Association { + scope := association.Scope + query := scope.NewDB() + relationship := association.Field.Relationship + + // many to many + if relationship.Kind == "many_to_many" { + // current value's foreign keys + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + // deleting value's foreign keys + primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) + sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)) + query = query.Where(sql, toQueryValues(primaryKeys)...) + + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { + leftValues := reflect.Zero(association.Field.Field.Type()) + for i := 0; i < association.Field.Field.Len(); i++ { + reflectValue := association.Field.Field.Index(i) + primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0] + var included = false + for _, pk := range primaryKeys { + if equalAsString(primaryKey, pk) { + included = true + } + } + if !included { + leftValues = reflect.Append(leftValues, reflectValue) + } + } + association.Field.Set(leftValues) + } + } else { + association.Field.Set(reflect.Zero(association.Field.Field.Type())) + + if relationship.Kind == "belongs_to" { + var foreignKeyMap = map[string]interface{}{} + for _, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + } + primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) + sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)) + association.setErr(query.Model(scope.Value).Where(sql, toQueryValues(primaryKeys)...).UpdateColumn(foreignKeyMap).Error) + } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + // query.Model(association.Field).UpdateColumn(foreignKeyMap) + } + } + return association +} + func (association *Association) Clear() *Association { relationship := association.Field.Relationship scope := association.Scope + query := scope.NewDB() + if relationship.Kind == "many_to_many" { - query := scope.NewDB() for idx, foreignKey := range relationship.ForeignDBNames { if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) @@ -202,7 +203,16 @@ func (association *Association) Clear() *Association { association.setErr(err) } } else { - association.setErr(errors.New("clear only support many to many")) + association.Field.Set(reflect.Zero(association.Field.Field.Type())) + if relationship.Kind == "belongs_to" { + var foreignKeyMap = map[string]interface{}{} + for _, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + } + query.Model(scope.Value).Update(foreignKeyMap) + } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + // query.Model(association.Field).UpdateColumn(foreignKeyMap) + } } return association } diff --git a/association_test.go b/association_test.go index 66844c34..e538bbe4 100644 --- a/association_test.go +++ b/association_test.go @@ -66,6 +66,12 @@ func TestBelongsTo(t *testing.T) { } // Delete + DB.Model(&post).Association("Category").Delete(&category2) + DB.First(&post, post.Id) + if DB.Model(&post).Related(&Category{}).RecordNotFound() { + t.Errorf("Should not delete any category when Delete a unrelated Category") + } + DB.Model(&post).Association("Category").Delete(&category3) var category41 Category @@ -75,6 +81,19 @@ func TestBelongsTo(t *testing.T) { } // Clear + DB.Model(&post).Association("Category").Append(&Category{ + Name: "Category 2", + }) + + if DB.Model(&post).Related(&Category{}).RecordNotFound() { + t.Errorf("Should find category after append") + } + + DB.Model(&post).Association("Category").Clear() + + if !DB.Model(&post).Related(&Category{}).RecordNotFound() { + t.Errorf("Should not find any category after Clear") + } } func TestHasOneAndHasManyAssociation(t *testing.T) {