gorm/callback_query_preload.go

322 lines
9.9 KiB
Go
Raw Normal View History

package gorm
import (
"errors"
"fmt"
"reflect"
2015-04-21 10:00:36 +03:00
"strings"
)
2016-01-17 15:51:11 +03:00
// preloadCallback used to preload associations
2016-01-17 10:30:42 +03:00
func preloadCallback(scope *Scope) {
2016-01-03 09:21:21 +03:00
if scope.Search.preload == nil || scope.HasError() {
2015-04-21 11:51:52 +03:00
return
}
2016-01-15 10:53:53 +03:00
var (
preloadedMap = map[string]bool{}
fields = scope.Fields()
)
2015-04-21 11:51:52 +03:00
for _, preload := range scope.Search.preload {
2016-01-15 10:53:53 +03:00
var (
preloadFields = strings.Split(preload.schema, ".")
currentScope = scope
currentFields = fields
)
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
for idx, preloadField := range preloadFields {
var currentPreloadConditions []interface{}
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
// if not preloaded
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
// assign search conditions to last preload
if idx == len(preloadFields)-1 {
currentPreloadConditions = preload.conditions
2015-04-21 10:00:36 +03:00
}
2016-01-15 10:53:53 +03:00
for _, field := range currentFields {
if field.Name != preloadField || field.Relationship == nil {
continue
}
switch field.Relationship.Kind {
case "has_one":
currentScope.handleHasOnePreload(field, currentPreloadConditions)
case "has_many":
currentScope.handleHasManyPreload(field, currentPreloadConditions)
case "belongs_to":
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
default:
scope.Err(errors.New("unsupported relation"))
}
preloadedMap[preloadKey] = true
break
}
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
if !preloadedMap[preloadKey] {
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
return
2015-04-21 11:51:52 +03:00
}
}
2016-01-15 10:53:53 +03:00
// preload next level
if idx < len(preloadFields)-1 {
currentScope = currentScope.getColumnAsScope(preloadField)
2015-04-21 11:51:52 +03:00
currentFields = currentScope.Fields()
}
}
}
}
// generatePreloadDBWithConditions separates preload conditions into two groups.
// Conditions of type func(*DB) *DB are applied, in order, to a DB obtained from
// scope.NewDB(); every other condition is collected unchanged. It returns the
// resulting DB together with the collected conditions (nil when there are none).
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
	db := scope.NewDB()

	var inline []interface{}
	for _, condition := range conditions {
		switch apply := condition.(type) {
		case func(*DB) *DB:
			db = apply(db)
		default:
			inline = append(inline, condition)
		}
	}

	return db, inline
}
2016-01-17 15:51:11 +03:00
// handleHasOnePreload used to preload has one associations
2015-04-21 11:51:52 +03:00
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
2015-07-30 13:19:49 +03:00
relation := field.Relationship
2015-07-30 17:36:04 +03:00
2016-01-15 10:53:53 +03:00
// get relations's primary keys
2016-01-15 17:53:09 +03:00
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
2015-04-21 11:51:52 +03:00
if len(primaryKeys) == 0 {
return
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
2016-01-15 10:53:53 +03:00
// find relations
2015-04-21 11:51:52 +03:00
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
// assign find results
var (
2016-01-18 07:20:27 +03:00
resultsValue = indirect(reflect.ValueOf(results))
2016-01-15 10:53:53 +03:00
indirectScopeValue = scope.IndirectValue()
)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
2016-01-15 15:37:41 +03:00
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
2016-01-15 10:53:53 +03:00
for j := 0; j < indirectScopeValue.Len(); j++ {
2016-01-18 07:20:27 +03:00
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
2016-01-15 15:37:41 +03:00
indirectValue.FieldByName(field.Name).Set(result)
2015-04-21 11:51:52 +03:00
break
}
}
} else {
2016-01-15 10:53:53 +03:00
scope.Err(field.Set(result))
2015-04-21 11:51:52 +03:00
}
}
}
2016-01-17 15:51:11 +03:00
// handleHasManyPreload used to preload has many associations
2015-04-21 11:51:52 +03:00
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
2015-07-30 13:19:49 +03:00
relation := field.Relationship
2016-01-15 15:37:41 +03:00
// get relations's primary keys
2016-01-15 17:53:09 +03:00
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
2015-04-21 11:51:52 +03:00
if len(primaryKeys) == 0 {
return
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
2016-01-15 15:37:41 +03:00
// find relations
2015-04-21 11:51:52 +03:00
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
2016-01-15 15:37:41 +03:00
// assign find results
var (
2016-01-18 07:20:27 +03:00
resultsValue = indirect(reflect.ValueOf(results))
2016-01-15 15:37:41 +03:00
indirectScopeValue = scope.IndirectValue()
)
if indirectScopeValue.Kind() == reflect.Slice {
2016-05-09 15:15:35 +03:00
preloadMap := make(map[string][]reflect.Value)
2016-01-15 15:37:41 +03:00
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
2016-05-09 15:15:35 +03:00
preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
}
2016-05-09 17:42:07 +03:00
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
2016-05-09 15:15:35 +03:00
objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
2016-05-09 17:42:07 +03:00
if results, ok := preloadMap[toString(objectRealValue)]; ok {
2016-05-09 15:15:35 +03:00
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, results...))
2015-04-21 11:51:52 +03:00
}
}
} else {
2016-01-15 15:37:41 +03:00
scope.Err(field.Set(resultsValue))
2015-04-21 11:51:52 +03:00
}
}
2016-01-17 15:51:11 +03:00
// handleBelongsToPreload used to preload belongs to associations
2015-04-21 11:51:52 +03:00
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
2016-01-15 15:37:41 +03:00
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
2016-01-15 15:37:41 +03:00
// get relations's primary keys
2016-01-15 17:53:09 +03:00
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
2015-04-21 11:51:52 +03:00
if len(primaryKeys) == 0 {
return
}
2016-01-15 15:37:41 +03:00
// find relations
2015-04-21 11:51:52 +03:00
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
2015-04-21 11:51:52 +03:00
2016-01-15 15:37:41 +03:00
// assign find results
var (
2016-01-18 07:20:27 +03:00
resultsValue = indirect(reflect.ValueOf(results))
2016-01-15 15:37:41 +03:00
indirectScopeValue = scope.IndirectValue()
)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
2016-01-15 10:53:53 +03:00
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
2016-01-15 15:37:41 +03:00
for j := 0; j < indirectScopeValue.Len(); j++ {
2016-01-18 07:20:27 +03:00
object := indirect(indirectScopeValue.Index(j))
2016-01-15 10:53:53 +03:00
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
2015-04-21 11:51:52 +03:00
object.FieldByName(field.Name).Set(result)
}
}
} else {
2016-01-15 15:37:41 +03:00
scope.Err(field.Set(result))
2015-04-21 11:51:52 +03:00
}
}
}
2016-01-17 15:51:11 +03:00
// handleManyToManyPreload used to preload many to many associations
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
2016-01-13 05:11:31 +03:00
var (
relation = field.Relationship
joinTableHandler = relation.JoinTableHandler
fieldType = field.Struct.Type.Elem()
2016-01-15 05:08:22 +03:00
foreignKeyValue interface{}
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
2016-01-15 15:37:41 +03:00
linkHash = map[string][]reflect.Value{}
2016-01-13 05:11:31 +03:00
isPtr bool
)
2016-01-15 15:37:41 +03:00
if fieldType.Kind() == reflect.Ptr {
isPtr = true
2016-01-15 15:37:41 +03:00
fieldType = fieldType.Elem()
}
2015-08-16 10:10:11 +03:00
2016-01-15 15:37:41 +03:00
var sourceKeys = []string{}
2015-08-16 10:10:11 +03:00
for _, key := range joinTableHandler.SourceForeignKeys() {
sourceKeys = append(sourceKeys, key.DBName)
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
2016-01-15 15:37:41 +03:00
// generate query with join table
newScope := scope.New(reflect.New(fieldType).Interface())
preloadDB = preloadDB.Table(newScope.TableName()).Select("*")
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
2016-01-13 05:11:31 +03:00
// preload inline conditions
if len(preloadConditions) > 0 {
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
}
2016-01-13 05:11:31 +03:00
rows, err := preloadDB.Rows()
2015-08-16 10:10:11 +03:00
if scope.Err(err) != nil {
return
}
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
2016-01-15 05:08:22 +03:00
var (
2016-01-15 15:37:41 +03:00
elem = reflect.New(fieldType).Elem()
2016-03-10 12:13:48 +03:00
fields = scope.New(elem.Addr().Interface()).Fields()
2016-01-15 05:08:22 +03:00
)
// register foreign keys in join tables
2016-03-10 12:13:48 +03:00
var joinTableFields []*Field
2016-01-15 05:08:22 +03:00
for _, sourceKey := range sourceKeys {
2016-03-10 12:13:48 +03:00
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
2015-08-16 10:10:11 +03:00
}
2016-03-10 12:13:48 +03:00
scope.scan(rows, columns, append(fields, joinTableFields...))
2015-08-16 10:10:11 +03:00
2016-01-15 05:08:22 +03:00
var foreignKeys = make([]interface{}, len(sourceKeys))
2016-03-10 12:13:48 +03:00
// generate hashed forkey keys in join table
for idx, joinTableField := range joinTableFields {
if !joinTableField.Field.IsNil() {
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
}
2015-08-16 10:10:11 +03:00
}
2016-01-15 05:08:22 +03:00
hashedSourceKeys := toString(foreignKeys)
2015-08-16 10:10:11 +03:00
2016-01-15 05:08:22 +03:00
if isPtr {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
} else {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
}
2015-08-16 10:10:11 +03:00
}
// assign find results
var (
indirectScopeValue = scope.IndirectValue()
2016-05-10 09:43:50 +03:00
fieldsSourceMap = map[string][]reflect.Value{}
foreignFieldNames = []string{}
)
for _, dbName := range relation.ForeignFieldNames {
2016-03-10 12:13:48 +03:00
if field, ok := scope.FieldByName(dbName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
2016-01-18 07:20:27 +03:00
object := indirect(indirectScopeValue.Index(j))
2016-05-10 09:43:50 +03:00
key := toString(getValueFromFields(object, foreignFieldNames))
fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
2015-08-16 12:25:25 +03:00
}
} else if indirectScopeValue.IsValid() {
2016-05-10 09:43:50 +03:00
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
}
for source, link := range linkHash {
2016-05-10 09:43:50 +03:00
for i, field := range fieldsSourceMap[source] {
//If not 0 this means Value is a pointer and we already added preloaded models to it
if fieldsSourceMap[source][i].Len() != 0 {
continue
}
field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
}
2015-08-16 10:10:11 +03:00
}
}