Refactor getColumnAsArray

This commit is contained in:
Jinzhu 2016-01-15 22:53:09 +08:00
parent 41870191b0
commit 822e895d4d
4 changed files with 27 additions and 21 deletions

View File

@ -117,7 +117,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
} }
} }
deletingPrimaryKeys := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, values...) deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
// source value's foreign keys // source value's foreign keys
@ -141,7 +141,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
if relationship.Kind == "belongs_to" { if relationship.Kind == "belongs_to" {
// find with deleting relation's foreign keys // find with deleting relation's foreign keys
primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
newDB = newDB.Where( newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)..., toQueryValues(primaryKeys)...,
@ -158,7 +158,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
} }
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// find all relations // find all relations
primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, scope.Value) primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
newDB = newDB.Where( newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)..., toQueryValues(primaryKeys)...,

View File

@ -154,7 +154,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name) foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
} }
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames) foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
var condString string var condString string
if len(foreignFieldValues) > 0 { if len(foreignFieldValues) > 0 {
@ -165,7 +165,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
keys := scope.getColumnAsArray(foreignFieldNames) keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
values = append(values, toQueryValues(keys)) values = append(values, toQueryValues(keys))
} else { } else {
condString = fmt.Sprintf("1 <> 1") condString = fmt.Sprintf("1 <> 1")

View File

@ -77,7 +77,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
relation := field.Relationship relation := field.Relationship
// get relations's primary keys // get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }
@ -112,7 +112,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
relation := field.Relationship relation := field.Relationship
// get relations's primary keys // get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }
@ -149,7 +149,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
relation := field.Relationship relation := field.Relationship
// get relations's primary keys // get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames) primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }

View File

@ -2,13 +2,18 @@ package gorm
import "reflect" import "reflect"
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
indirectScopeValue := scope.IndirectValue() for _, value := range values {
switch indirectScopeValue.Kind() { indirectValue := reflect.ValueOf(value)
for indirectValue.Kind() == reflect.Ptr {
indirectValue = indirectValue.Elem()
}
switch indirectValue.Kind() {
case reflect.Slice: case reflect.Slice:
for i := 0; i < indirectScopeValue.Len(); i++ { for i := 0; i < indirectValue.Len(); i++ {
var result []interface{} var result []interface{}
var object = reflect.Indirect(indirectScopeValue.Index(i)) var object = reflect.Indirect(indirectValue.Index(i))
for _, column := range columns { for _, column := range columns {
result = append(result, object.FieldByName(column).Interface()) result = append(result, object.FieldByName(column).Interface())
} }
@ -17,9 +22,10 @@ func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{})
case reflect.Struct: case reflect.Struct:
var result []interface{} var result []interface{}
for _, column := range columns { for _, column := range columns {
result = append(result, indirectScopeValue.FieldByName(column).Interface()) result = append(result, indirectValue.FieldByName(column).Interface())
}
results = append(results, result)
} }
return [][]interface{}{result}
} }
return return
} }