mirror of https://github.com/go-gorm/gorm.git
Refactor Preload
This commit is contained in:
parent
41620f3d6c
commit
3326a4e69d
198
preload.go
198
preload.go
|
@ -1,139 +1,109 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
|
// Preload preload relations callback
|
||||||
// If value is a nil pointer, Indirect returns a zero Value!
|
|
||||||
// Therefor we need to check for a zero value,
|
|
||||||
// as FieldByName could panic
|
|
||||||
if pointedValue := reflect.Indirect(value); pointedValue.IsValid() {
|
|
||||||
for _, column := range columns {
|
|
||||||
if pointedValue.FieldByName(column).IsValid() {
|
|
||||||
result := pointedValue.FieldByName(column).Interface()
|
|
||||||
if r, ok := result.(driver.Valuer); ok {
|
|
||||||
result, _ = r.Value()
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func equalAsString(a interface{}, b interface{}) bool {
|
|
||||||
return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Preload(scope *Scope) {
|
func Preload(scope *Scope) {
|
||||||
if scope.Search.preload == nil || scope.HasError() {
|
if scope.Search.preload == nil || scope.HasError() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
preloadMap := map[string]bool{}
|
var (
|
||||||
fields := scope.Fields()
|
preloadedMap = map[string]bool{}
|
||||||
for _, preload := range scope.Search.preload {
|
fields = scope.Fields()
|
||||||
schema, conditions := preload.schema, preload.conditions
|
)
|
||||||
keys := strings.Split(schema, ".")
|
|
||||||
currentScope := scope
|
|
||||||
currentFields := fields
|
|
||||||
originalConditions := conditions
|
|
||||||
conditions = []interface{}{}
|
|
||||||
for i, key := range keys {
|
|
||||||
var found bool
|
|
||||||
if preloadMap[strings.Join(keys[:i+1], ".")] {
|
|
||||||
goto nextLoop
|
|
||||||
}
|
|
||||||
|
|
||||||
if i == len(keys)-1 {
|
for _, preload := range scope.Search.preload {
|
||||||
conditions = originalConditions
|
var (
|
||||||
|
preloadFields = strings.Split(preload.schema, ".")
|
||||||
|
currentScope = scope
|
||||||
|
currentFields = fields
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx, preloadField := range preloadFields {
|
||||||
|
var currentPreloadConditions []interface{}
|
||||||
|
|
||||||
|
// if not preloaded
|
||||||
|
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
|
||||||
|
|
||||||
|
// assign search conditions to last preload
|
||||||
|
if idx == len(preloadFields)-1 {
|
||||||
|
currentPreloadConditions = preload.conditions
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range currentFields {
|
for _, field := range currentFields {
|
||||||
if field.Name != key || field.Relationship == nil {
|
if field.Name != preloadField || field.Relationship == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
found = true
|
|
||||||
switch field.Relationship.Kind {
|
switch field.Relationship.Kind {
|
||||||
case "has_one":
|
case "has_one":
|
||||||
currentScope.handleHasOnePreload(field, conditions)
|
currentScope.handleHasOnePreload(field, currentPreloadConditions)
|
||||||
case "has_many":
|
case "has_many":
|
||||||
currentScope.handleHasManyPreload(field, conditions)
|
currentScope.handleHasManyPreload(field, currentPreloadConditions)
|
||||||
case "belongs_to":
|
case "belongs_to":
|
||||||
currentScope.handleBelongsToPreload(field, conditions)
|
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
||||||
case "many_to_many":
|
case "many_to_many":
|
||||||
currentScope.handleManyToManyPreload(field, conditions)
|
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
||||||
default:
|
default:
|
||||||
currentScope.Err(errors.New("not supported relation"))
|
scope.Err(errors.New("unsupported relation"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
preloadedMap[preloadKey] = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if !found {
|
if !preloadedMap[preloadKey] {
|
||||||
value := reflect.ValueOf(currentScope.Value)
|
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
|
||||||
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
|
|
||||||
value = value.Index(0).Elem()
|
|
||||||
}
|
|
||||||
scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
preloadMap[strings.Join(keys[:i+1], ".")] = true
|
// preload next level
|
||||||
|
if idx < len(preloadFields)-1 {
|
||||||
nextLoop:
|
currentScope = currentScope.getColumnAsScope(preloadField)
|
||||||
if i < len(keys)-1 {
|
|
||||||
currentScope = currentScope.getColumnsAsScope(key)
|
|
||||||
currentFields = currentScope.Fields()
|
currentFields = currentScope.Fields()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeSlice(typ reflect.Type) interface{} {
|
|
||||||
if typ.Kind() == reflect.Slice {
|
|
||||||
typ = typ.Elem()
|
|
||||||
}
|
|
||||||
sliceType := reflect.SliceOf(typ)
|
|
||||||
slice := reflect.New(sliceType)
|
|
||||||
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
|
||||||
return slice.Interface()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
||||||
relation := field.Relationship
|
relation := field.Relationship
|
||||||
|
|
||||||
|
// get relations's primary keys
|
||||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
|
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
|
||||||
if len(primaryKeys) == 0 {
|
if len(primaryKeys) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// find relations
|
||||||
results := makeSlice(field.Struct.Type)
|
results := makeSlice(field.Struct.Type)
|
||||||
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
||||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
|
||||||
|
|
||||||
for i := 0; i < resultValues.Len(); i++ {
|
// assign find results
|
||||||
result := resultValues.Index(i)
|
var (
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
resultsValue = reflect.Indirect(reflect.ValueOf(results))
|
||||||
value := getRealValue(result, relation.ForeignFieldNames)
|
indirectScopeValue = scope.IndirectValue()
|
||||||
objects := scope.IndirectValue()
|
)
|
||||||
for j := 0; j < objects.Len(); j++ {
|
|
||||||
if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) {
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
|
result := resultsValue.Index(i)
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
value := getValueFromFields(result, relation.ForeignFieldNames)
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
if equalAsString(getValueFromFields(indirectScopeValue.Index(j), relation.AssociationForeignFieldNames), value) {
|
||||||
|
reflect.Indirect(indirectScopeValue.Index(j)).FieldByName(field.Name).Set(result)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := scope.SetColumn(field, result); err != nil {
|
scope.Err(field.Set(result))
|
||||||
scope.Err(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -152,11 +122,11 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||||
for i := 0; i < resultValues.Len(); i++ {
|
for i := 0; i < resultValues.Len(); i++ {
|
||||||
result := resultValues.Index(i)
|
result := resultValues.Index(i)
|
||||||
value := getRealValue(result, relation.ForeignFieldNames)
|
value := getValueFromFields(result, relation.ForeignFieldNames)
|
||||||
objects := scope.IndirectValue()
|
objects := scope.IndirectValue()
|
||||||
for j := 0; j < objects.Len(); j++ {
|
for j := 0; j < objects.Len(); j++ {
|
||||||
object := reflect.Indirect(objects.Index(j))
|
object := reflect.Indirect(objects.Index(j))
|
||||||
if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) {
|
if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), value) {
|
||||||
f := object.FieldByName(field.Name)
|
f := object.FieldByName(field.Name)
|
||||||
f.Set(reflect.Append(f, result))
|
f.Set(reflect.Append(f, result))
|
||||||
break
|
break
|
||||||
|
@ -182,11 +152,11 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
||||||
for i := 0; i < resultValues.Len(); i++ {
|
for i := 0; i < resultValues.Len(); i++ {
|
||||||
result := resultValues.Index(i)
|
result := resultValues.Index(i)
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||||
value := getRealValue(result, relation.AssociationForeignFieldNames)
|
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
||||||
objects := scope.IndirectValue()
|
objects := scope.IndirectValue()
|
||||||
for j := 0; j < objects.Len(); j++ {
|
for j := 0; j < objects.Len(); j++ {
|
||||||
object := reflect.Indirect(objects.Index(j))
|
object := reflect.Indirect(objects.Index(j))
|
||||||
if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) {
|
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
||||||
object.FieldByName(field.Name).Set(result)
|
object.FieldByName(field.Name).Set(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -276,10 +246,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
object := reflect.Indirect(indirectScopeValue.Index(j))
|
object := reflect.Indirect(indirectScopeValue.Index(j))
|
||||||
fieldsSourceMap[toString(getRealValue(object, foreignFieldNames))] = object.FieldByName(field.Name)
|
fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name)
|
||||||
}
|
}
|
||||||
} else if indirectScopeValue.IsValid() {
|
} else if indirectScopeValue.IsValid() {
|
||||||
fieldsSourceMap[toString(getRealValue(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name)
|
fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
for source, link := range linkHash {
|
for source, link := range linkHash {
|
||||||
|
@ -308,46 +278,38 @@ func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) getColumnsAsScope(column string) *Scope {
|
func (scope *Scope) getColumnAsScope(column string) *Scope {
|
||||||
values := scope.IndirectValue()
|
indirectScopeValue := scope.IndirectValue()
|
||||||
switch values.Kind() {
|
|
||||||
|
switch indirectScopeValue.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
modelType := values.Type().Elem()
|
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
|
||||||
if modelType.Kind() == reflect.Ptr {
|
fieldType := fieldStruct.Type
|
||||||
modelType = modelType.Elem()
|
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
}
|
}
|
||||||
fieldStruct, _ := modelType.FieldByName(column)
|
|
||||||
var columns reflect.Value
|
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
|
||||||
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
|
|
||||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
|
for i := 0; i < indirectScopeValue.Len(); i++ {
|
||||||
} else {
|
result := reflect.Indirect(reflect.Indirect(indirectScopeValue.Index(i)).FieldByName(column))
|
||||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
|
|
||||||
}
|
if result.Kind() == reflect.Slice {
|
||||||
for i := 0; i < values.Len(); i++ {
|
for j := 0; j < result.Len(); j++ {
|
||||||
column := reflect.Indirect(values.Index(i)).FieldByName(column)
|
if elem := result.Index(j); elem.CanAddr() {
|
||||||
if column.Kind() == reflect.Ptr {
|
results = reflect.Append(results, elem.Addr())
|
||||||
column = column.Elem()
|
|
||||||
}
|
|
||||||
if column.Kind() == reflect.Slice {
|
|
||||||
for i := 0; i < column.Len(); i++ {
|
|
||||||
elem := column.Index(i)
|
|
||||||
if elem.CanAddr() {
|
|
||||||
columns = reflect.Append(columns, elem.Addr())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if result.CanAddr() {
|
||||||
if column.CanAddr() {
|
results = reflect.Append(results, result.Addr())
|
||||||
columns = reflect.Append(columns, column.Addr())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return scope.New(results.Interface())
|
||||||
}
|
}
|
||||||
return scope.New(columns.Interface())
|
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
field := values.FieldByName(column)
|
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
|
||||||
if !field.CanAddr() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return scope.New(field.Addr().Interface())
|
return scope.New(field.Addr().Interface())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -73,6 +74,10 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
||||||
return attrs
|
return attrs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func equalAsString(a interface{}, b interface{}) bool {
|
||||||
|
return toString(a) == toString(b)
|
||||||
|
}
|
||||||
|
|
||||||
func toString(str interface{}) string {
|
func toString(str interface{}) string {
|
||||||
if values, ok := str.([]interface{}); ok {
|
if values, ok := str.([]interface{}); ok {
|
||||||
var results []string
|
var results []string
|
||||||
|
@ -87,6 +92,16 @@ func toString(str interface{}) string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeSlice(elemType reflect.Type) interface{} {
|
||||||
|
if elemType.Kind() == reflect.Slice {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
sliceType := reflect.SliceOf(elemType)
|
||||||
|
slice := reflect.New(sliceType)
|
||||||
|
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
||||||
|
return slice.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
func strInSlice(a string, list []string) bool {
|
func strInSlice(a string, list []string) bool {
|
||||||
for _, b := range list {
|
for _, b := range list {
|
||||||
if b == a {
|
if b == a {
|
||||||
|
@ -95,3 +110,22 @@ func strInSlice(a string, list []string) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getValueFromFields return given fields's value
|
||||||
|
func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
|
||||||
|
// If value is a nil pointer, Indirect returns a zero Value!
|
||||||
|
// Therefor we need to check for a zero value,
|
||||||
|
// as FieldByName could panic
|
||||||
|
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
|
||||||
|
for _, fieldName := range fieldNames {
|
||||||
|
if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
|
||||||
|
result := fieldValue.Interface()
|
||||||
|
if r, ok := result.(driver.Valuer); ok {
|
||||||
|
result, _ = r.Value()
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue