Keep refactoring on Preload

This commit is contained in:
Jinzhu 2016-01-15 20:37:41 +08:00
parent 3326a4e69d
commit 551c1e0c20
2 changed files with 108 additions and 88 deletions

View File

@ -95,10 +95,10 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
for i := 0; i < resultsValue.Len(); i++ { for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i) result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice { if indirectScopeValue.Kind() == reflect.Slice {
value := getValueFromFields(result, relation.ForeignFieldNames) foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ { for j := 0; j < indirectScopeValue.Len(); j++ {
if equalAsString(getValueFromFields(indirectScopeValue.Index(j), relation.AssociationForeignFieldNames), value) { if indirectValue := reflect.Indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
reflect.Indirect(indirectScopeValue.Index(j)).FieldByName(field.Name).Set(result) indirectValue.FieldByName(field.Name).Set(result)
break break
} }
} }
@ -110,58 +110,72 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
relation := field.Relationship relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }
// find relations
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
if scope.IndirectValue().Kind() == reflect.Slice { // assign find results
for i := 0; i < resultValues.Len(); i++ { var (
result := resultValues.Index(i) resultsValue = reflect.Indirect(reflect.ValueOf(results))
value := getValueFromFields(result, relation.ForeignFieldNames) indirectScopeValue = scope.IndirectValue()
objects := scope.IndirectValue() )
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j)) if indirectScopeValue.Kind() == reflect.Slice {
if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), value) { for i := 0; i < resultsValue.Len(); i++ {
f := object.FieldByName(field.Name) result := resultsValue.Index(i)
f.Set(reflect.Append(f, result)) foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ {
object := reflect.Indirect(indirectScopeValue.Index(j))
if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) {
objectField := object.FieldByName(field.Name)
objectField.Set(reflect.Append(objectField, result))
break break
} }
} }
} }
} else { } else {
scope.SetColumn(field, resultValues) scope.Err(field.Set(resultsValue))
} }
} }
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames) primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }
// find relations
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ { // assign find results
result := resultValues.Index(i) var (
if scope.IndirectValue().Kind() == reflect.Slice { resultsValue = reflect.Indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
value := getValueFromFields(result, relation.AssociationForeignFieldNames) value := getValueFromFields(result, relation.AssociationForeignFieldNames)
objects := scope.IndirectValue() for j := 0; j < indirectScopeValue.Len(); j++ {
for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(indirectScopeValue.Index(j))
object := reflect.Indirect(objects.Index(j))
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
object.FieldByName(field.Name).Set(result) object.FieldByName(field.Name).Set(result)
} }
} }
} else { } else {
scope.SetColumn(field, result) scope.Err(field.Set(result))
} }
} }
} }
@ -170,24 +184,25 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
var ( var (
relation = field.Relationship relation = field.Relationship
joinTableHandler = relation.JoinTableHandler joinTableHandler = relation.JoinTableHandler
destType = field.StructField.Struct.Type.Elem() fieldType = field.StructField.Struct.Type.Elem()
linkHash = make(map[string][]reflect.Value)
sourceKeys = []string{}
foreignKeyValue interface{} foreignKeyValue interface{}
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
linkHash = map[string][]reflect.Value{}
isPtr bool isPtr bool
) )
if destType.Kind() == reflect.Ptr { if fieldType.Kind() == reflect.Ptr {
isPtr = true isPtr = true
destType = destType.Elem() fieldType = fieldType.Elem()
} }
var sourceKeys = []string{}
for _, key := range joinTableHandler.SourceForeignKeys() { for _, key := range joinTableHandler.SourceForeignKeys() {
sourceKeys = append(sourceKeys, key.DBName) sourceKeys = append(sourceKeys, key.DBName)
} }
preloadJoinDB := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") // generate query with join table
preloadJoinDB := scope.NewDB().Table(scope.New(reflect.New(fieldType).Interface()).TableName()).Select("*")
preloadJoinDB = joinTableHandler.JoinWith(joinTableHandler, preloadJoinDB, scope.Value) preloadJoinDB = joinTableHandler.JoinWith(joinTableHandler, preloadJoinDB, scope.Value)
// preload inline conditions // preload inline conditions
@ -205,7 +220,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
columns, _ := rows.Columns() columns, _ := rows.Columns()
for rows.Next() { for rows.Next() {
var ( var (
elem = reflect.New(destType).Elem() elem = reflect.New(fieldType).Elem()
fields = scope.New(elem.Addr().Interface()).Fields() fields = scope.New(elem.Addr().Interface()).Fields()
) )
@ -235,10 +250,11 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
indirectScopeValue = scope.IndirectValue() indirectScopeValue = scope.IndirectValue()
fieldsSourceMap = map[string]reflect.Value{} fieldsSourceMap = map[string]reflect.Value{}
foreignFieldNames = []string{} foreignFieldNames = []string{}
fields = scope.Fields()
) )
for _, dbName := range relation.ForeignFieldNames { for _, dbName := range relation.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok { if field, ok := fields[dbName]; ok {
foreignFieldNames = append(foreignFieldNames, field.Name) foreignFieldNames = append(foreignFieldNames, field.Name)
} }
} }
@ -256,60 +272,3 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...)) fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...))
} }
} }
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
for i := 0; i < values.Len(); i++ {
var result []interface{}
for _, column := range columns {
result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
}
results = append(results, result)
}
case reflect.Struct:
var result []interface{}
for _, column := range columns {
result = append(result, values.FieldByName(column).Interface())
}
return [][]interface{}{result}
}
return
}
func (scope *Scope) getColumnAsScope(column string) *Scope {
indirectScopeValue := scope.IndirectValue()
switch indirectScopeValue.Kind() {
case reflect.Slice:
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
fieldType := fieldStruct.Type
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
for i := 0; i < indirectScopeValue.Len(); i++ {
result := reflect.Indirect(reflect.Indirect(indirectScopeValue.Index(i)).FieldByName(column))
if result.Kind() == reflect.Slice {
for j := 0; j < result.Len(); j++ {
if elem := result.Index(j); elem.CanAddr() {
results = reflect.Append(results, elem.Addr())
}
}
} else if result.CanAddr() {
results = reflect.Append(results, result.Addr())
}
}
return scope.New(results.Interface())
}
case reflect.Struct:
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
return scope.New(field.Addr().Interface())
}
}
return nil
}

61
scope_utils.go Normal file
View File

@ -0,0 +1,61 @@
package gorm
import "reflect"
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
indirectScopeValue := scope.IndirectValue()
switch indirectScopeValue.Kind() {
case reflect.Slice:
for i := 0; i < indirectScopeValue.Len(); i++ {
var result []interface{}
var object = reflect.Indirect(indirectScopeValue.Index(i))
for _, column := range columns {
result = append(result, object.FieldByName(column).Interface())
}
results = append(results, result)
}
case reflect.Struct:
var result []interface{}
for _, column := range columns {
result = append(result, indirectScopeValue.FieldByName(column).Interface())
}
return [][]interface{}{result}
}
return
}
func (scope *Scope) getColumnAsScope(column string) *Scope {
indirectScopeValue := scope.IndirectValue()
switch indirectScopeValue.Kind() {
case reflect.Slice:
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
fieldType := fieldStruct.Type
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
for i := 0; i < indirectScopeValue.Len(); i++ {
result := reflect.Indirect(reflect.Indirect(indirectScopeValue.Index(i)).FieldByName(column))
if result.Kind() == reflect.Slice {
for j := 0; j < result.Len(); j++ {
if elem := result.Index(j); elem.CanAddr() {
results = reflect.Append(results, elem.Addr())
}
}
} else if result.CanAddr() {
results = reflect.Append(results, result.Addr())
}
}
return scope.New(results.Interface())
}
case reflect.Struct:
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
return scope.New(field.Addr().Interface())
}
}
return nil
}