mirror of https://github.com/go-gorm/gorm.git
388 lines
11 KiB
Go
388 lines
11 KiB
Go
package gorm
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// preloadCallback used to preload associations
|
|
func preloadCallback(scope *Scope) {
|
|
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
|
return
|
|
}
|
|
|
|
if _, ok := scope.Get("gorm:auto_preload"); ok {
|
|
autoPreload(scope)
|
|
}
|
|
|
|
if scope.Search.preload == nil || scope.HasError() {
|
|
return
|
|
}
|
|
|
|
var (
|
|
preloadedMap = map[string]bool{}
|
|
fields = scope.Fields()
|
|
)
|
|
|
|
for _, preload := range scope.Search.preload {
|
|
var (
|
|
preloadFields = strings.Split(preload.schema, ".")
|
|
currentScope = scope
|
|
currentFields = fields
|
|
)
|
|
|
|
for idx, preloadField := range preloadFields {
|
|
var currentPreloadConditions []interface{}
|
|
|
|
if currentScope == nil {
|
|
continue
|
|
}
|
|
|
|
// if not preloaded
|
|
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
|
|
|
|
// assign search conditions to last preload
|
|
if idx == len(preloadFields)-1 {
|
|
currentPreloadConditions = preload.conditions
|
|
}
|
|
|
|
for _, field := range currentFields {
|
|
if field.Name != preloadField || field.Relationship == nil {
|
|
continue
|
|
}
|
|
|
|
switch field.Relationship.Kind {
|
|
case "has_one":
|
|
currentScope.handleHasOnePreload(field, currentPreloadConditions)
|
|
case "has_many":
|
|
currentScope.handleHasManyPreload(field, currentPreloadConditions)
|
|
case "belongs_to":
|
|
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
|
case "many_to_many":
|
|
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
|
default:
|
|
scope.Err(errors.New("unsupported relation"))
|
|
}
|
|
|
|
preloadedMap[preloadKey] = true
|
|
break
|
|
}
|
|
|
|
if !preloadedMap[preloadKey] {
|
|
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
|
|
return
|
|
}
|
|
}
|
|
|
|
// preload next level
|
|
if idx < len(preloadFields)-1 {
|
|
currentScope = currentScope.getColumnAsScope(preloadField)
|
|
if currentScope != nil {
|
|
currentFields = currentScope.Fields()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func autoPreload(scope *Scope) {
|
|
for _, field := range scope.Fields() {
|
|
if field.Relationship == nil {
|
|
continue
|
|
}
|
|
|
|
if val, ok := field.TagSettings["PRELOAD"]; ok {
|
|
if preload, err := strconv.ParseBool(val); err != nil {
|
|
scope.Err(errors.New("invalid preload option"))
|
|
return
|
|
} else if !preload {
|
|
continue
|
|
}
|
|
}
|
|
|
|
scope.Search.Preload(field.Name)
|
|
}
|
|
}
|
|
|
|
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
|
|
var (
|
|
preloadDB = scope.NewDB()
|
|
preloadConditions []interface{}
|
|
)
|
|
|
|
for _, condition := range conditions {
|
|
if scopes, ok := condition.(func(*DB) *DB); ok {
|
|
preloadDB = scopes(preloadDB)
|
|
} else {
|
|
preloadConditions = append(preloadConditions, condition)
|
|
}
|
|
}
|
|
|
|
return preloadDB, preloadConditions
|
|
}
|
|
|
|
// handleHasOnePreload used to preload has one associations
|
|
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
|
relation := field.Relationship
|
|
|
|
// get relations's primary keys
|
|
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
|
if len(primaryKeys) == 0 {
|
|
return
|
|
}
|
|
|
|
// preload conditions
|
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
|
|
|
// find relations
|
|
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
|
values := toQueryValues(primaryKeys)
|
|
if relation.PolymorphicType != "" {
|
|
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
|
values = append(values, relation.PolymorphicValue)
|
|
}
|
|
|
|
results := makeSlice(field.Struct.Type)
|
|
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
|
|
|
// assign find results
|
|
var (
|
|
resultsValue = indirect(reflect.ValueOf(results))
|
|
indirectScopeValue = scope.IndirectValue()
|
|
)
|
|
|
|
if indirectScopeValue.Kind() == reflect.Slice {
|
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
|
for i := 0; i < resultsValue.Len(); i++ {
|
|
result := resultsValue.Index(i)
|
|
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
|
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
|
indirectValue.FieldByName(field.Name).Set(result)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
for i := 0; i < resultsValue.Len(); i++ {
|
|
result := resultsValue.Index(i)
|
|
scope.Err(field.Set(result))
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleHasManyPreload used to preload has many associations
|
|
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
|
relation := field.Relationship
|
|
|
|
// get relations's primary keys
|
|
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
|
if len(primaryKeys) == 0 {
|
|
return
|
|
}
|
|
|
|
// preload conditions
|
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
|
|
|
// find relations
|
|
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
|
values := toQueryValues(primaryKeys)
|
|
if relation.PolymorphicType != "" {
|
|
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
|
values = append(values, relation.PolymorphicValue)
|
|
}
|
|
|
|
results := makeSlice(field.Struct.Type)
|
|
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
|
|
|
// assign find results
|
|
var (
|
|
resultsValue = indirect(reflect.ValueOf(results))
|
|
indirectScopeValue = scope.IndirectValue()
|
|
)
|
|
|
|
if indirectScopeValue.Kind() == reflect.Slice {
|
|
preloadMap := make(map[string][]reflect.Value)
|
|
for i := 0; i < resultsValue.Len(); i++ {
|
|
result := resultsValue.Index(i)
|
|
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
|
preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
|
|
}
|
|
|
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
|
object := indirect(indirectScopeValue.Index(j))
|
|
objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
|
|
f := object.FieldByName(field.Name)
|
|
if results, ok := preloadMap[toString(objectRealValue)]; ok {
|
|
f.Set(reflect.Append(f, results...))
|
|
} else {
|
|
f.Set(reflect.MakeSlice(f.Type(), 0, 0))
|
|
}
|
|
}
|
|
} else {
|
|
scope.Err(field.Set(resultsValue))
|
|
}
|
|
}
|
|
|
|
// handleBelongsToPreload used to preload belongs to associations
|
|
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
|
relation := field.Relationship
|
|
|
|
// preload conditions
|
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
|
|
|
// get relations's primary keys
|
|
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
|
|
if len(primaryKeys) == 0 {
|
|
return
|
|
}
|
|
|
|
// find relations
|
|
results := makeSlice(field.Struct.Type)
|
|
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
|
|
|
// assign find results
|
|
var (
|
|
resultsValue = indirect(reflect.ValueOf(results))
|
|
indirectScopeValue = scope.IndirectValue()
|
|
)
|
|
|
|
for i := 0; i < resultsValue.Len(); i++ {
|
|
result := resultsValue.Index(i)
|
|
if indirectScopeValue.Kind() == reflect.Slice {
|
|
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
|
object := indirect(indirectScopeValue.Index(j))
|
|
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
|
object.FieldByName(field.Name).Set(result)
|
|
}
|
|
}
|
|
} else {
|
|
scope.Err(field.Set(result))
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleManyToManyPreload used to preload many to many associations
|
|
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
|
var (
|
|
relation = field.Relationship
|
|
joinTableHandler = relation.JoinTableHandler
|
|
fieldType = field.Struct.Type.Elem()
|
|
foreignKeyValue interface{}
|
|
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
|
|
linkHash = map[string][]reflect.Value{}
|
|
isPtr bool
|
|
)
|
|
|
|
if fieldType.Kind() == reflect.Ptr {
|
|
isPtr = true
|
|
fieldType = fieldType.Elem()
|
|
}
|
|
|
|
var sourceKeys = []string{}
|
|
for _, key := range joinTableHandler.SourceForeignKeys() {
|
|
sourceKeys = append(sourceKeys, key.DBName)
|
|
}
|
|
|
|
// preload conditions
|
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
|
|
|
// generate query with join table
|
|
newScope := scope.New(reflect.New(fieldType).Interface())
|
|
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
|
|
|
|
if len(preloadDB.search.selects) == 0 {
|
|
preloadDB = preloadDB.Select("*")
|
|
}
|
|
|
|
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
|
|
|
|
// preload inline conditions
|
|
if len(preloadConditions) > 0 {
|
|
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
|
|
}
|
|
|
|
rows, err := preloadDB.Rows()
|
|
|
|
if scope.Err(err) != nil {
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
columns, _ := rows.Columns()
|
|
for rows.Next() {
|
|
var (
|
|
elem = reflect.New(fieldType).Elem()
|
|
fields = scope.New(elem.Addr().Interface()).Fields()
|
|
)
|
|
|
|
// register foreign keys in join tables
|
|
var joinTableFields []*Field
|
|
for _, sourceKey := range sourceKeys {
|
|
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
|
|
}
|
|
|
|
scope.scan(rows, columns, append(fields, joinTableFields...))
|
|
|
|
scope.New(elem.Addr().Interface()).
|
|
InstanceSet("gorm:skip_query_callback", true).
|
|
callCallbacks(scope.db.parent.callbacks.queries)
|
|
|
|
var foreignKeys = make([]interface{}, len(sourceKeys))
|
|
// generate hashed forkey keys in join table
|
|
for idx, joinTableField := range joinTableFields {
|
|
if !joinTableField.Field.IsNil() {
|
|
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
|
|
}
|
|
}
|
|
hashedSourceKeys := toString(foreignKeys)
|
|
|
|
if isPtr {
|
|
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
|
|
} else {
|
|
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
scope.Err(err)
|
|
}
|
|
|
|
// assign find results
|
|
var (
|
|
indirectScopeValue = scope.IndirectValue()
|
|
fieldsSourceMap = map[string][]reflect.Value{}
|
|
foreignFieldNames = []string{}
|
|
)
|
|
|
|
for _, dbName := range relation.ForeignFieldNames {
|
|
if field, ok := scope.FieldByName(dbName); ok {
|
|
foreignFieldNames = append(foreignFieldNames, field.Name)
|
|
}
|
|
}
|
|
|
|
if indirectScopeValue.Kind() == reflect.Slice {
|
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
|
object := indirect(indirectScopeValue.Index(j))
|
|
key := toString(getValueFromFields(object, foreignFieldNames))
|
|
fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
|
|
}
|
|
} else if indirectScopeValue.IsValid() {
|
|
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
|
|
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
|
|
}
|
|
for source, link := range linkHash {
|
|
for i, field := range fieldsSourceMap[source] {
|
|
//If not 0 this means Value is a pointer and we already added preloaded models to it
|
|
if fieldsSourceMap[source][i].Len() != 0 {
|
|
continue
|
|
}
|
|
field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
|
|
}
|
|
|
|
}
|
|
}
|