gorm/preload.go

275 lines
8.2 KiB
Go
Raw Normal View History

package gorm
import (
"errors"
"fmt"
"reflect"
2015-04-21 10:00:36 +03:00
"strings"
)
2016-01-15 10:53:53 +03:00
// Preload preload relations callback
func Preload(scope *Scope) {
2016-01-03 09:21:21 +03:00
if scope.Search.preload == nil || scope.HasError() {
2015-04-21 11:51:52 +03:00
return
}
2016-01-15 10:53:53 +03:00
var (
preloadedMap = map[string]bool{}
fields = scope.Fields()
)
2015-04-21 11:51:52 +03:00
for _, preload := range scope.Search.preload {
2016-01-15 10:53:53 +03:00
var (
preloadFields = strings.Split(preload.schema, ".")
currentScope = scope
currentFields = fields
)
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
for idx, preloadField := range preloadFields {
var currentPreloadConditions []interface{}
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
// if not preloaded
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
// assign search conditions to last preload
if idx == len(preloadFields)-1 {
currentPreloadConditions = preload.conditions
2015-04-21 10:00:36 +03:00
}
2016-01-15 10:53:53 +03:00
for _, field := range currentFields {
if field.Name != preloadField || field.Relationship == nil {
continue
}
switch field.Relationship.Kind {
case "has_one":
currentScope.handleHasOnePreload(field, currentPreloadConditions)
case "has_many":
currentScope.handleHasManyPreload(field, currentPreloadConditions)
case "belongs_to":
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
default:
scope.Err(errors.New("unsupported relation"))
}
preloadedMap[preloadKey] = true
break
}
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
if !preloadedMap[preloadKey] {
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
return
2015-04-21 11:51:52 +03:00
}
}
2016-01-15 10:53:53 +03:00
// preload next level
if idx < len(preloadFields)-1 {
currentScope = currentScope.getColumnAsScope(preloadField)
2015-04-21 11:51:52 +03:00
currentFields = currentScope.Fields()
}
}
}
}
2015-04-21 11:51:52 +03:00
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
2015-07-30 13:19:49 +03:00
relation := field.Relationship
2015-07-30 17:36:04 +03:00
2016-01-15 10:53:53 +03:00
// get relations's primary keys
2015-07-30 13:19:49 +03:00
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
2015-04-21 11:51:52 +03:00
if len(primaryKeys) == 0 {
return
}
2016-01-15 10:53:53 +03:00
// find relations
2015-04-21 11:51:52 +03:00
results := makeSlice(field.Struct.Type)
2015-07-30 13:19:49 +03:00
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
// assign find results
var (
resultsValue = reflect.Indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
2016-01-15 15:37:41 +03:00
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
2016-01-15 10:53:53 +03:00
for j := 0; j < indirectScopeValue.Len(); j++ {
2016-01-15 15:37:41 +03:00
if indirectValue := reflect.Indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
indirectValue.FieldByName(field.Name).Set(result)
2015-04-21 11:51:52 +03:00
break
}
}
} else {
2016-01-15 10:53:53 +03:00
scope.Err(field.Set(result))
2015-04-21 11:51:52 +03:00
}
}
}
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
2015-07-30 13:19:49 +03:00
relation := field.Relationship
2016-01-15 15:37:41 +03:00
// get relations's primary keys
2015-07-30 13:19:49 +03:00
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
2015-04-21 11:51:52 +03:00
if len(primaryKeys) == 0 {
return
}
2016-01-15 15:37:41 +03:00
// find relations
2015-04-21 11:51:52 +03:00
results := makeSlice(field.Struct.Type)
2015-07-30 13:19:49 +03:00
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
2016-01-15 15:37:41 +03:00
// assign find results
var (
resultsValue = reflect.Indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
if indirectScopeValue.Kind() == reflect.Slice {
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ {
object := reflect.Indirect(indirectScopeValue.Index(j))
if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) {
objectField := object.FieldByName(field.Name)
objectField.Set(reflect.Append(objectField, result))
2015-04-21 11:51:52 +03:00
break
}
}
}
} else {
2016-01-15 15:37:41 +03:00
scope.Err(field.Set(resultsValue))
2015-04-21 11:51:52 +03:00
}
}
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
2016-01-15 15:37:41 +03:00
// get relations's primary keys
2015-07-30 13:19:49 +03:00
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
2015-04-21 11:51:52 +03:00
if len(primaryKeys) == 0 {
return
}
2016-01-15 15:37:41 +03:00
// find relations
2015-04-21 11:51:52 +03:00
results := makeSlice(field.Struct.Type)
2015-07-30 17:36:04 +03:00
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
2015-04-21 11:51:52 +03:00
2016-01-15 15:37:41 +03:00
// assign find results
var (
resultsValue = reflect.Indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
2016-01-15 10:53:53 +03:00
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
2016-01-15 15:37:41 +03:00
for j := 0; j < indirectScopeValue.Len(); j++ {
object := reflect.Indirect(indirectScopeValue.Index(j))
2016-01-15 10:53:53 +03:00
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
2015-04-21 11:51:52 +03:00
object.FieldByName(field.Name).Set(result)
}
}
} else {
2016-01-15 15:37:41 +03:00
scope.Err(field.Set(result))
2015-04-21 11:51:52 +03:00
}
}
}
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
2016-01-13 05:11:31 +03:00
var (
relation = field.Relationship
joinTableHandler = relation.JoinTableHandler
2016-01-15 15:37:41 +03:00
fieldType = field.StructField.Struct.Type.Elem()
2016-01-15 05:08:22 +03:00
foreignKeyValue interface{}
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
2016-01-15 15:37:41 +03:00
linkHash = map[string][]reflect.Value{}
2016-01-13 05:11:31 +03:00
isPtr bool
)
2016-01-15 15:37:41 +03:00
if fieldType.Kind() == reflect.Ptr {
isPtr = true
2016-01-15 15:37:41 +03:00
fieldType = fieldType.Elem()
}
2015-08-16 10:10:11 +03:00
2016-01-15 15:37:41 +03:00
var sourceKeys = []string{}
2015-08-16 10:10:11 +03:00
for _, key := range joinTableHandler.SourceForeignKeys() {
sourceKeys = append(sourceKeys, key.DBName)
}
2016-01-15 15:37:41 +03:00
// generate query with join table
preloadJoinDB := scope.NewDB().Table(scope.New(reflect.New(fieldType).Interface()).TableName()).Select("*")
preloadJoinDB = joinTableHandler.JoinWith(joinTableHandler, preloadJoinDB, scope.Value)
2016-01-13 05:11:31 +03:00
// preload inline conditions
if len(conditions) > 0 {
preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
}
2016-01-13 05:11:31 +03:00
rows, err := preloadJoinDB.Rows()
2015-08-16 10:10:11 +03:00
if scope.Err(err) != nil {
return
}
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
2016-01-15 05:08:22 +03:00
var (
2016-01-15 15:37:41 +03:00
elem = reflect.New(fieldType).Elem()
2016-01-15 05:08:22 +03:00
fields = scope.New(elem.Addr().Interface()).Fields()
)
// register foreign keys in join tables
for _, sourceKey := range sourceKeys {
fields[sourceKey] = &Field{Field: reflect.New(foreignKeyType).Elem()}
2015-08-16 10:10:11 +03:00
}
2016-01-15 05:08:22 +03:00
scope.scan(rows, columns, fields)
2015-08-16 10:10:11 +03:00
2016-01-15 05:08:22 +03:00
// generate hashed forkey keys in join table
var foreignKeys = make([]interface{}, len(sourceKeys))
for idx, sourceKey := range sourceKeys {
foreignKeys[idx] = fields[sourceKey].Field.Elem().Interface()
2015-08-16 10:10:11 +03:00
}
2016-01-15 05:08:22 +03:00
hashedSourceKeys := toString(foreignKeys)
2015-08-16 10:10:11 +03:00
2016-01-15 05:08:22 +03:00
if isPtr {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
} else {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
}
2015-08-16 10:10:11 +03:00
}
// assign find results
var (
indirectScopeValue = scope.IndirectValue()
fieldsSourceMap = map[string]reflect.Value{}
foreignFieldNames = []string{}
2016-01-15 15:37:41 +03:00
fields = scope.Fields()
)
for _, dbName := range relation.ForeignFieldNames {
2016-01-15 15:37:41 +03:00
if field, ok := fields[dbName]; ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
object := reflect.Indirect(indirectScopeValue.Index(j))
2016-01-15 10:53:53 +03:00
fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name)
2015-08-16 12:25:25 +03:00
}
} else if indirectScopeValue.IsValid() {
2016-01-15 10:53:53 +03:00
fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name)
}
for source, link := range linkHash {
fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...))
2015-08-16 10:10:11 +03:00
}
}