diff --git a/association.go b/association.go index 0ce2904f..a172e52d 100644 --- a/association.go +++ b/association.go @@ -57,9 +57,10 @@ func (association *Association) Append(values ...interface{}) *Association { return association.err(scope.db.Error) } -func (association *Association) Delete(values ...interface{}) *Association { +func (association *Association) getPrimaryKeys(values ...interface{}) []interface{} { primaryKeys := []interface{}{} scope := association.Scope + for _, value := range values { reflectValue := reflect.ValueOf(value) if reflectValue.Kind() == reflect.Ptr { @@ -81,6 +82,11 @@ func (association *Association) Delete(values ...interface{}) *Association { } } } + return primaryKeys +} + +func (association *Association) Delete(values ...interface{}) *Association { + primaryKeys := association.getPrimaryKeys(values...) if len(primaryKeys) == 0 { association.err(errors.New("no primary key found")) @@ -88,8 +94,8 @@ func (association *Association) Delete(values ...interface{}) *Association { relationship := association.Field.Relationship // many to many if relationship.kind == "many_to_many" { - whereSql := fmt.Sprintf("%v.%v IN (?)", relationship.joinTable, scope.Quote(ToSnake(relationship.associationForeignKey))) - scope.db.Table(relationship.joinTable).Where(whereSql, primaryKeys).Delete("") + whereSql := fmt.Sprintf("%v.%v IN (?)", relationship.joinTable, association.Scope.Quote(ToSnake(relationship.associationForeignKey))) + association.Scope.db.Model("").Table(relationship.joinTable).Where(whereSql, primaryKeys).Delete("") } else { association.err(errors.New("delete only support many to many")) } @@ -97,7 +103,38 @@ func (association *Association) Delete(values ...interface{}) *Association { return association } -func (association *Association) Replace(values interface{}) *Association { +func (association *Association) Replace(values ...interface{}) *Association { + relationship := association.Field.Relationship + scope := association.Scope + if relationship.kind == "many_to_many" { + field := scope.IndirectValue().FieldByName(association.Column) + + oldPrimaryKeys := association.getPrimaryKeys(field.Interface()) + association.Append(values...) + newPrimaryKeys := association.getPrimaryKeys(field.Interface()) + + var addedPrimaryKeys = []interface{}{} + for _, new := range newPrimaryKeys { + hasEqual := false + for _, old := range oldPrimaryKeys { + if reflect.DeepEqual(new, old) { + hasEqual = true + break + } + } + if !hasEqual { + addedPrimaryKeys = append(addedPrimaryKeys, new) + } + } + for _, primaryKey := range association.getPrimaryKeys(values...) { + addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) + } + + whereSql := fmt.Sprintf("%v.%v NOT IN (?)", relationship.joinTable, scope.Quote(ToSnake(relationship.associationForeignKey))) + scope.db.Model("").Table(relationship.joinTable).Where(whereSql, addedPrimaryKeys).Delete("") + } else { + association.err(errors.New("replace only support many to many")) + } return association } @@ -122,17 +159,17 @@ func (association *Association) Count() int { relationship.joinTable, relationship.joinTable, scope.Quote(ToSnake(relationship.foreignKey))) - scope.db.Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).NewScope("").count(&count) + scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) } else if relationship.kind == "has_many" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.foreignKey))) - scope.db.Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).NewScope("").count(&count) + scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) } else if relationship.kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), relationship.foreignKey) - scope.db.Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).NewScope("").count(&count) + scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) } else if relationship.kind == "belongs_to" { if v, ok := scope.FieldByName(association.Column); ok { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), relationship.foreignKey) - scope.db.Table(newScope.QuotedTableName()).Where(whereSql, v).NewScope("").count(&count) + scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count) } } diff --git a/association_test.go b/association_test.go index 561725a5..ff96a47b 100644 --- a/association_test.go +++ b/association_test.go @@ -187,6 +187,19 @@ func TestManyToMany(t *testing.T) { t.Errorf("Relations should be deleted with Delete") } + // Replace + var languageB Language + db.Where("name = ?", "BB").First(&languageB) + db.Debug().Model(&user).Association("Languages").Replace(languageB) + if db.Model(&user).Association("Languages").Count() != 1 { + t.Errorf("Relations should be deleted with Delete") + } + + db.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}}) + if db.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) { + t.Errorf("Relations should be deleted with Delete") + } + // db.Model(&User{}).Many2Many("Languages").Replace(&[]Language{}) // db.Model(&User{}).Related(&[]Language{}, "Languages") // SELECT `languages`.* FROM `languages` INNER JOIN `user_languages` ON `languages`.`id` = `user_languages`.`language_id` WHERE `user_languages`.`user_id` = 111