diff --git a/association.go b/association.go index f62e712b..b088c1dd 100644 --- a/association.go +++ b/association.go @@ -58,25 +58,24 @@ func (association *Association) Delete(values ...interface{}) *Association { } } - primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, values...) + primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)) query = query.Where(sql, toQueryValues(primaryKeys)...) if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { - value := association.Field.Field.Index(i) - if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil { - var included = false - for _, primaryKey := range primaryKeys { - if equalAsString(primaryKey, primaryField.Field.Interface()) { - included = true - } - } - if !included { - leftValues = reflect.Append(leftValues, value) + reflectValue := association.Field.Field.Index(i) + primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0] + var included = false + for _, pk := range primaryKeys { + if equalAsString(primaryKey, pk) { + included = true } } + if !included { + leftValues = reflect.Append(leftValues, reflectValue) + } } association.Field.Set(leftValues) } @@ -92,16 +91,16 @@ func (association *Association) Replace(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { field := association.Field.Field - oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, field.Interface()) + oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Append(values...) - newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, field.Interface()) + newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) var addedPrimaryKeys = [][]interface{}{} for _, newKey := range newPrimaryKeys { hasEqual := false for _, oldKey := range oldPrimaryKeys { - if reflect.DeepEqual(newKey, oldKey) { + if equalAsString(newKey, oldKey) { hasEqual = true break } @@ -111,7 +110,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } } - for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignDBNames, values...) { + for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) { addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) } @@ -123,7 +122,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } } - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) } @@ -195,11 +194,10 @@ func (association *Association) getPrimaryKeys(columns []string, values ...inter scope := association.Scope for _, value := range values { - primaryKeys := []interface{}{} - reflectValue := reflect.Indirect(reflect.ValueOf(value)) if reflectValue.Kind() == reflect.Slice { for i := 0; i < reflectValue.Len(); i++ { + primaryKeys := []interface{}{} newScope := scope.New(reflectValue.Index(i).Interface()) for _, column := range columns { if field, ok := newScope.FieldByName(column); ok { @@ -208,9 +206,11 @@ func (association *Association) getPrimaryKeys(columns []string, values ...inter primaryKeys = append(primaryKeys, "") } } + results = append(results, primaryKeys) } } else if reflectValue.Kind() == reflect.Struct { newScope := scope.New(value) + var primaryKeys []interface{} for _, column := range columns { if field, ok := newScope.FieldByName(column); ok { primaryKeys = append(primaryKeys, field.Field.Interface()) @@ -218,9 +218,9 @@ func (association *Association) getPrimaryKeys(columns []string, values ...inter primaryKeys = append(primaryKeys, "") } } - } - results = append(results, primaryKeys) + results = append(results, primaryKeys) + } } return results } diff --git a/association_test.go b/association_test.go index 205a929e..dfda46a5 100644 --- a/association_test.go +++ b/association_test.go @@ -186,6 +186,7 @@ func TestManyToMany(t *testing.T) { var language Language DB.Where("name = ?", "EE").First(&language) DB.Model(&user).Association("Languages").Delete(language, &language) + if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { t.Errorf("Relations should be deleted with Delete") }