diff --git a/association.go b/association.go index 19fef6a0..48b051f8 100644 --- a/association.go +++ b/association.go @@ -136,6 +136,10 @@ func (association *Association) Delete(values ...interface{}) *Association { query := scope.NewDB() relationship := association.Field.Relationship + if len(values) == 0 { + return association + } + // many to many if relationship.Kind == "many_to_many" { // current value's foreign keys @@ -178,8 +182,27 @@ func (association *Association) Delete(values ...interface{}) *Association { primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)) association.setErr(query.Model(scope.Value).Where(sql, toQueryValues(primaryKeys)...).UpdateColumn(foreignKeyMap).Error) + } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // query.Model(association.Field).UpdateColumn(foreignKeyMap) + var foreignKeyMap = map[string]interface{}{} + for _, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + } + + primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, scope.Value) + sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)) + + var primaryFieldNames, primaryFieldDBNames []string + for _, field := range scope.New(values[0]).Fields() { + if field.IsPrimaryKey { + primaryFieldNames = append(primaryFieldNames, field.Name) + primaryFieldDBNames = append(primaryFieldDBNames, field.DBName) + } + } + relationsPrimaryKeys := association.getPrimaryKeys(primaryFieldNames, values...) + sql += fmt.Sprintf(" AND %v IN (%v)", toQueryCondition(scope, primaryFieldDBNames), toQueryMarks(relationsPrimaryKeys)) + + query.Model(association.Field.Field.Interface()).Where(sql, append(toQueryValues(primaryKeys), toQueryValues(relationsPrimaryKeys)...)...).UpdateColumn(foreignKeyMap) } } return association diff --git a/association_test.go b/association_test.go index 3a755230..6cd76634 100644 --- a/association_test.go +++ b/association_test.go @@ -2,6 +2,8 @@ package gorm_test import ( "fmt" + "reflect" + "sort" "testing" ) @@ -125,9 +127,58 @@ func TestHasMany(t *testing.T) { t.Errorf("Got errors when save post", err.Error()) } + for _, comment := range post.Comments { + if comment.PostId == 0 { + t.Errorf("comment's PostID should be updated") + } + } + + var compareComments = func(comments []Comment, contents []string) bool { + var commentContents []string + for _, comment := range comments { + commentContents = append(commentContents, comment.Content) + } + sort.Strings(commentContents) + sort.Strings(contents) + return reflect.DeepEqual(commentContents, contents) + } + // Query + if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { + t.Errorf("Comment 1 should be saved") + } + + var comments1 []Comment + DB.Model(&post).Association("Comments").Find(&comments1) + if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) { + t.Errorf("Query has many relations with Association") + } + + var comments11 []Comment + DB.Model(&post).Related(&comments11) + if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) { + t.Errorf("Query has many relations with Related") + } + // Append + DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"}) + + var comments2 []Comment + DB.Model(&post).Related(&comments2) + if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) { + t.Errorf("Append new record to has many relations") + } + // Delete + DB.Model(&post).Association("Comments").Delete(comments11) + + var comments3 []Comment + DB.Model(&post).Related(&comments3) + if !compareComments(comments3, []string{"Comment 3"}) { + fmt.Println(comments3) + t.Errorf("Delete an existing resource for has many relations") + } + // Replace // Clear }