diff --git a/association.go b/association.go index 9838f440..82e39950 100644 --- a/association.go +++ b/association.go @@ -124,46 +124,49 @@ func (association *Association) Delete(values ...interface{}) *Association { func (association *Association) Replace(values ...interface{}) *Association { relationship := association.Field.Relationship scope := association.Scope + field := association.Field.Field + + oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) + association.Field.Set(reflect.Zero(association.Field.Field.Type())) + association.Append(values...) + newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) + + var addedPrimaryKeys = [][]interface{}{} + for _, newKey := range newPrimaryKeys { + hasEqual := false + for _, oldKey := range oldPrimaryKeys { + if equalAsString(newKey, oldKey) { + hasEqual = true + break + } + } + if !hasEqual { + addedPrimaryKeys = append(addedPrimaryKeys, newKey) + } + } + + for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) { + addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) + } + + query := scope.NewDB() + var foreignKeyMap = map[string]string{} + for idx, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = "" + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + if len(addedPrimaryKeys) > 0 { + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) + query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) + } + if relationship.Kind == "many_to_many" { - field := association.Field.Field - - oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) - association.Field.Set(reflect.Zero(association.Field.Field.Type())) - association.Append(values...) - newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) - - var addedPrimaryKeys = [][]interface{}{} - for _, newKey := range newPrimaryKeys { - hasEqual := false - for _, oldKey := range oldPrimaryKeys { - if equalAsString(newKey, oldKey) { - hasEqual = true - break - } - } - if !hasEqual { - addedPrimaryKeys = append(addedPrimaryKeys, newKey) - } - } - - for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) { - addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) - } - - query := scope.NewDB() - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - if len(addedPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) - query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) - } association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) - } else { - association.setErr(errors.New("replace only support many to many")) + } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + query.Update(foreignKeyMap) } return association } diff --git a/association_test.go b/association_test.go index a5cc1860..32f1e6b2 100644 --- a/association_test.go +++ b/association_test.go @@ -51,15 +51,20 @@ func TestHasOne(t *testing.T) { } // Replace - // DB.Model(&post).Association("Category").Replace(&Category{ - // Name: "Category 3", - // }) + var category3 = Category{ + Name: "Category 3", + } + DB.Model(&post).Association("Category").Replace(&category3) - // var category3 Category - // DB.Model(&post).Related(&category3) - // if category3.Name != "Category 3" { - // t.Errorf("Category should be updated with Replace") - // } + if category3.Id == 0 { + t.Errorf("Category should has ID when created with Replace") + } + + var category31 Category + DB.Model(&post).Related(&category31) + if category31.Name != "Category 3" { + t.Errorf("Category should be updated with Replace") + } // Delete // Clear