From 67874f923242e53af66e16619357ea7d0bf63415 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 16 Jan 2016 06:00:18 +0800 Subject: [PATCH] Keep Refactoring Association Mode --- association.go | 104 ++++++++++++++++++++++++++----------------- association_utils.go | 36 --------------- 2 files changed, 64 insertions(+), 76 deletions(-) diff --git a/association.go b/association.go index ecf6eb49..862276d5 100644 --- a/association.go +++ b/association.go @@ -57,16 +57,23 @@ func (association *Association) Replace(values ...interface{}) *Association { newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } - // Relations except new created + // Delete Relations except new created if len(values) > 0 { var associationForeignFieldNames []string if relationship.Kind == "many_to_many" { - associationForeignFieldNames = relationship.AssociationForeignFieldNames + // if many to many relations, get association fields name from association foreign keys + associationFields := scope.New(reflect.New(field.Type()).Interface()).Fields() + for _, dbName := range relationship.AssociationForeignFieldNames { + associationForeignFieldNames = append(associationForeignFieldNames, associationFields[dbName].Name) + } } else { - associationForeignFieldNames = relationship.AssociationForeignDBNames + // If other relations, use primary keys + for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { + associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + } } - newPrimaryKeys := association.getPrimaryKeys(associationForeignFieldNames, field.Interface()) + newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) if len(newPrimaryKeys) > 0 { sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) @@ -75,12 +82,25 @@ func (association *Association) Replace(values ...interface{}) *Association { } if relationship.Kind == "many_to_many" { - if sourcePrimaryKeys := association.getPrimaryKeys(relationship.ForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { + // if many to many relations, delete related relations from join table + + // get source fields name from source foreign keys + var ( + sourceFields = scope.Fields() + sourceForeignFieldNames []string + ) + + for _, dbName := range relationship.ForeignFieldNames { + sourceForeignFieldNames = append(sourceForeignFieldNames, sourceFields[dbName].Name) + } + + if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) var foreignKeyMap = map[string]interface{}{} for idx, foreignKey := range relationship.ForeignDBNames { foreignKeyMap[foreignKey] = nil @@ -110,11 +130,9 @@ func (association *Association) Delete(values ...interface{}) *Association { } var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() { - if field.IsPrimaryKey { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } + for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { + deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) + deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) } deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) @@ -127,8 +145,15 @@ func (association *Association) Delete(values ...interface{}) *Association { } } + // get association's foreign fields name + var associationFields = scope.New(reflect.New(field.Type()).Interface()).Fields() + var associationForeignFieldNames []string + for _, associationDBName := range relationship.AssociationForeignFieldNames { + associationForeignFieldNames = append(associationForeignFieldNames, associationFields[associationDBName].Name) + } + // association value's foreign keys - deletingPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) + deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) @@ -147,7 +172,7 @@ func (association *Association) Delete(values ...interface{}) *Association { toQueryValues(primaryKeys)..., ) - // set foreign key to be null + // set foreign key to be null if there are some records affected modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { if results.RowsAffected > 0 { @@ -176,28 +201,29 @@ func (association *Association) Delete(values ...interface{}) *Association { } } - // Remove deleted records from field + // Remove deleted records from source's field if association.Error == nil { if association.Field.Field.Kind() == reflect.Slice { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { reflectValue := association.Field.Field.Index(i) - primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var included = false + primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] + var isDeleted = false for _, pk := range deletingPrimaryKeys { if equalAsString(primaryKey, pk) { - included = true + isDeleted = true + break } } - if !included { + if !isDeleted { leftValues = reflect.Append(leftValues, reflectValue) } } association.Field.Set(leftValues) } else if association.Field.Field.Kind() == reflect.Struct { - primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0] + primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0] for _, pk := range deletingPrimaryKeys { if equalAsString(primaryKey, pk) { association.Field.Set(reflect.Zero(association.Field.Field.Type())) @@ -222,34 +248,32 @@ func (association *Association) Count() int { relationship = association.Field.Relationship scope = association.Scope fieldValue = association.Field.Field.Interface() - newScope = scope.New(fieldValue) + query = scope.DB() ) if relationship.Kind == "many_to_many" { - relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count) + query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - query := scope.DB() - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), - field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) - } - query.Model(fieldValue).Count(&count) + primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) + query = query.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), + toQueryValues(primaryKeys)..., + ) } else if relationship.Kind == "belongs_to" { - query := scope.DB() - for idx, primaryKey := range relationship.AssociationForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)), - field.Field.Interface()) - } - } - query.Model(fieldValue).Count(&count) + primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) + query = query.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), + toQueryValues(primaryKeys)..., + ) } + if relationship.PolymorphicType != "" { + query = query.Where( + fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), + scope.TableName(), + ) + } + + query.Model(fieldValue).Count(&count) return count } diff --git a/association_utils.go b/association_utils.go index 7ec2ab7f..912c9ca4 100644 --- a/association_utils.go +++ b/association_utils.go @@ -82,42 +82,6 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa return association } -func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) { - scope := association.Scope - - for _, value := range values { - reflectValue := reflect.Indirect(reflect.ValueOf(value)) - if reflectValue.Kind() == reflect.Slice { - for i := 0; i < reflectValue.Len(); i++ { - primaryKeys := []interface{}{} - newScope := scope.New(reflectValue.Index(i).Interface()) - for _, column := range columns { - if field, ok := newScope.FieldByName(column); ok { - primaryKeys = append(primaryKeys, field.Field.Interface()) - } else { - primaryKeys = append(primaryKeys, "") - } - } - results = append(results, primaryKeys) - } - } else if reflectValue.Kind() == reflect.Struct { - newScope := scope.New(value) - var primaryKeys []interface{} - for _, column := range columns { - if field, ok := newScope.FieldByName(column); ok { - primaryKeys = append(primaryKeys, field.Field.Interface()) - } else { - primaryKeys = append(primaryKeys, "") - } - } - - results = append(results, primaryKeys) - } - } - - return -} - func toQueryMarks(primaryValues [][]interface{}) string { var results []string