From f13dcd8bc0faa2dac5ac94cf55ea15f1dc83ee9c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 25 Dec 2015 20:43:51 +0800 Subject: [PATCH] Refactor association Replace --- association.go | 74 ++++++++++++++++++++++++++++++++------------- association_test.go | 4 +-- 2 files changed, 55 insertions(+), 23 deletions(-) diff --git a/association.go b/association.go index 7f56c2d8..4bc7269f 100644 --- a/association.go +++ b/association.go @@ -86,10 +86,18 @@ func (association *Association) Replace(values ...interface{}) *Association { scope := association.Scope field := association.Field.Field + // get old primary keys oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) + + // append new values association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Append(values...) - newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) + + // get new primary keys + var newPrimaryKeys [][]interface{} + if len(values) > 0 { + newPrimaryKeys = association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) + } var addedPrimaryKeys = [][]interface{}{} for _, newKey := range newPrimaryKeys { @@ -111,25 +119,49 @@ func (association *Association) Replace(values ...interface{}) *Association { query := scope.NewDB() var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + + if relationship.Kind == "belongs_to" { + for idx, foreignKey := range relationship.AssociationForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } } - } - if len(addedPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) - query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) - } + for _, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + } - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) - } else if relationship.Kind == "belongs_to" { - association.setErr(query.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - fieldValue := reflect.New(association.Field.Field.Type()).Interface() - association.setErr(query.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + if len(addedPrimaryKeys) > 0 { + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(addedPrimaryKeys)) + query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) + } + + modelValue := scope.Value + // if replacing with a new value, don't reset current foreign key + // if clearing foreign value, then reset the foreign key to null + if len(values) > 0 { + modelValue = reflect.New(scope.GetModelStruct().ModelType).Interface() + } + association.setErr(query.Model(modelValue).UpdateColumn(foreignKeyMap).Error) + } else { + for idx, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + if len(addedPrimaryKeys) > 0 { + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) + query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) + } + + if relationship.Kind == "many_to_many" { + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) + } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + fieldValue := reflect.New(association.Field.Field.Type()).Interface() + association.setErr(query.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + } } return association } @@ -184,8 +216,8 @@ func (association *Association) Delete(values ...interface{}) *Association { } primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)) - association.setErr(query.Model(scope.Value).Where(sql, toQueryValues(primaryKeys)...).UpdateColumn(foreignKeyMap).Error) + association.setErr(query.Model(scope.Value).Where(sql, toQueryValues(primaryKeys)...).UpdateColumn(foreignKeyMap).Error) } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { var foreignKeyMap = map[string]interface{}{} for _, foreignKey := range relationship.ForeignDBNames { @@ -250,8 +282,7 @@ func (association *Association) Count() int { return count } -func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) [][]interface{} { - results := [][]interface{}{} +func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) { scope := association.Scope for _, value := range values { @@ -283,7 +314,8 @@ func (association *Association) getPrimaryKeys(columns []string, values ...inter results = append(results, primaryKeys) } } - return results + + return } func toQueryMarks(primaryValues [][]interface{}) string { diff --git a/association_test.go b/association_test.go index c094b81e..92854d89 100644 --- a/association_test.go +++ b/association_test.go @@ -106,7 +106,7 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Should find category after append") } - DB.Model(&post).Debug().Association("Category").Clear() + DB.Model(&post).Association("Category").Clear() if !DB.Model(&post).Related(&Category{}).RecordNotFound() { t.Errorf("Should not find any category after Clear") @@ -179,7 +179,7 @@ func TestHasMany(t *testing.T) { } // Replace - DB.Model(&Post{Id: 999}).Debug().Association("Comments").Replace() + DB.Model(&Post{Id: 999}).Association("Comments").Replace() var comments4 []Comment DB.Model(&post).Related(&comments4)