From 822e895d4d0707402b04a90ff7185846ce53ebac Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 15 Jan 2016 22:53:09 +0800 Subject: [PATCH] Refactor getColumnAsArray --- association.go | 6 +++--- join_table_handler.go | 4 ++-- preload.go | 6 +++--- scope_utils.go | 32 +++++++++++++++++++------------- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/association.go b/association.go index f828aea4..ecf6eb49 100644 --- a/association.go +++ b/association.go @@ -117,7 +117,7 @@ func (association *Association) Delete(values ...interface{}) *Association { } } - deletingPrimaryKeys := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, values...) + deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) if relationship.Kind == "many_to_many" { // source value's foreign keys @@ -141,7 +141,7 @@ func (association *Association) Delete(values ...interface{}) *Association { if relationship.Kind == "belongs_to" { // find with deleting relation's foreign keys - primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) + primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) newDB = newDB.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., @@ -158,7 +158,7 @@ func (association *Association) Delete(values ...interface{}) *Association { } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { // find all relations - primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) newDB = newDB.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., diff --git a/join_table_handler.go b/join_table_handler.go index 6e7f9045..9e6c027a 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -154,7 +154,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name) } - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames) + foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) var condString string if len(foreignFieldValues) > 0 { @@ -165,7 +165,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - keys := scope.getColumnAsArray(foreignFieldNames) + keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) values = append(values, toQueryValues(keys)) } else { condString = fmt.Sprintf("1 <> 1") diff --git a/preload.go b/preload.go index 2333cade..20c9aacf 100644 --- a/preload.go +++ b/preload.go @@ -77,7 +77,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) relation := field.Relationship // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } @@ -112,7 +112,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) relation := field.Relationship // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } @@ -149,7 +149,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ relation := field.Relationship // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames) + primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } diff --git a/scope_utils.go b/scope_utils.go index 99957adc..ffaa99b4 100644 --- a/scope_utils.go +++ b/scope_utils.go @@ -2,24 +2,30 @@ package gorm import "reflect" -func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { - indirectScopeValue := scope.IndirectValue() - switch indirectScopeValue.Kind() { - case reflect.Slice: - for i := 0; i < indirectScopeValue.Len(); i++ { +func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { + for _, value := range values { + indirectValue := reflect.ValueOf(value) + for indirectValue.Kind() == reflect.Ptr { + indirectValue = indirectValue.Elem() + } + + switch indirectValue.Kind() { + case reflect.Slice: + for i := 0; i < indirectValue.Len(); i++ { + var result []interface{} + var object = reflect.Indirect(indirectValue.Index(i)) + for _, column := range columns { + result = append(result, object.FieldByName(column).Interface()) + } + results = append(results, result) + } + case reflect.Struct: var result []interface{} - var object = reflect.Indirect(indirectScopeValue.Index(i)) for _, column := range columns { - result = append(result, object.FieldByName(column).Interface()) + result = append(result, indirectValue.FieldByName(column).Interface()) } results = append(results, result) } - case reflect.Struct: - var result []interface{} - for _, column := range columns { - result = append(result, indirectScopeValue.FieldByName(column).Interface()) - } - return [][]interface{}{result} } return }