// gorm/preload.go
// (Captured from a web source view: 379 lines, 11 KiB, Go.)

package gorm
import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
2015-04-21 10:00:36 +03:00
"strings"
)
2015-07-30 13:19:49 +03:00
// getRealValue returns the values of the named struct fields of value, in the
// order given by columns. A pointer value is dereferenced first. If value is a
// nil pointer, nil is returned.
//
// Columns that do not name a field are skipped. A field whose value implements
// driver.Valuer is replaced by the result of its Value method. Any error
// returned by Value is discarded.
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
	// reflect.Indirect yields the zero Value for a nil pointer. Calling
	// FieldByName on it would panic, so stop early.
	indirect := reflect.Indirect(value)
	if !indirect.IsValid() {
		return nil
	}
	for _, column := range columns {
		// Look the field up once and reuse it.
		fieldValue := indirect.FieldByName(column)
		if !fieldValue.IsValid() {
			continue
		}
		result := fieldValue.Interface()
		if valuer, ok := result.(driver.Valuer); ok {
			// Best effort: the Valuer's error is intentionally ignored.
			result, _ = valuer.Value()
		}
		results = append(results, result)
	}
	return results
}
func equalAsString(a interface{}, b interface{}) bool {
2016-01-16 06:37:16 +03:00
return toString(a) == toString(b)
}
func Preload(scope *Scope) {
2016-01-03 09:21:21 +03:00
if scope.Search.preload == nil || scope.HasError() {
2015-04-21 11:51:52 +03:00
return
}
2015-04-21 10:00:36 +03:00
preloadMap := map[string]bool{}
2015-04-21 11:51:52 +03:00
fields := scope.Fields()
for _, preload := range scope.Search.preload {
schema, conditions := preload.schema, preload.conditions
keys := strings.Split(schema, ".")
currentScope := scope
currentFields := fields
originalConditions := conditions
conditions = []interface{}{}
for i, key := range keys {
var found bool
if preloadMap[strings.Join(keys[:i+1], ".")] {
goto nextLoop
}
if i == len(keys)-1 {
conditions = originalConditions
}
for _, field := range currentFields {
if field.Name != key || field.Relationship == nil {
continue
2015-04-21 10:00:36 +03:00
}
2015-04-21 11:51:52 +03:00
found = true
switch field.Relationship.Kind {
case "has_one":
currentScope.handleHasOnePreload(field, conditions)
case "has_many":
currentScope.handleHasManyPreload(field, conditions)
case "belongs_to":
currentScope.handleBelongsToPreload(field, conditions)
case "many_to_many":
currentScope.handleManyToManyPreload(field, conditions)
2015-04-21 11:51:52 +03:00
default:
currentScope.Err(errors.New("not supported relation"))
}
2015-04-21 11:51:52 +03:00
break
}
if !found {
value := reflect.ValueOf(currentScope.Value)
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
value = value.Index(0).Elem()
}
2015-06-11 17:14:36 +03:00
scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
2015-04-21 11:51:52 +03:00
return
}
preloadMap[strings.Join(keys[:i+1], ".")] = true
nextLoop:
if i < len(keys)-1 {
currentScope = currentScope.getColumnsAsScope(key)
currentFields = currentScope.Fields()
}
}
}
2015-04-21 11:51:52 +03:00
}
2015-02-17 17:55:14 +03:00
// makeSlice returns a pointer to a new, empty, non-nil slice.
//
// If typ is itself a slice type, the new slice has the same element type.
// Otherwise the new slice's elements are of type typ. The result is an
// interface{} holding a *[]E, suitable as a Find destination.
func makeSlice(typ reflect.Type) interface{} {
	elemType := typ
	if elemType.Kind() == reflect.Slice {
		elemType = elemType.Elem()
	}
	ptr := reflect.New(reflect.SliceOf(elemType))
	target := ptr.Elem()
	target.Set(reflect.MakeSlice(target.Type(), 0, 0))
	return ptr.Interface()
}
2015-04-21 11:51:52 +03:00
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
2015-07-30 13:19:49 +03:00
relation := field.Relationship
2015-07-30 17:36:04 +03:00
2015-07-30 13:19:49 +03:00
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
2015-04-21 11:51:52 +03:00
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
2015-07-30 13:19:49 +03:00
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
2015-04-22 10:36:10 +03:00
resultValues := reflect.Indirect(reflect.ValueOf(results))
2015-04-21 11:51:52 +03:00
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice {
2015-07-30 13:19:49 +03:00
value := getRealValue(result, relation.ForeignFieldNames)
2015-04-21 11:51:52 +03:00
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
2015-07-30 13:19:49 +03:00
if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) {
2015-04-21 11:51:52 +03:00
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break
}
}
} else {
2015-04-22 10:36:10 +03:00
if err := scope.SetColumn(field, result); err != nil {
2015-04-21 11:51:52 +03:00
scope.Err(err)
return
}
}
}
}
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
2015-07-30 13:19:49 +03:00
relation := field.Relationship
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
2015-04-21 11:51:52 +03:00
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
2015-07-30 13:19:49 +03:00
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
2015-04-22 10:36:10 +03:00
resultValues := reflect.Indirect(reflect.ValueOf(results))
2015-04-21 11:51:52 +03:00
if scope.IndirectValue().Kind() == reflect.Slice {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
2015-07-30 13:19:49 +03:00
value := getRealValue(result, relation.ForeignFieldNames)
2015-04-21 11:51:52 +03:00
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
2015-07-30 13:19:49 +03:00
if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) {
2015-04-21 11:51:52 +03:00
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result))
break
}
}
}
} else {
scope.SetColumn(field, resultValues)
}
}
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
2015-07-30 13:19:49 +03:00
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
2015-04-21 11:51:52 +03:00
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
2015-07-30 17:36:04 +03:00
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
2015-04-22 10:36:10 +03:00
resultValues := reflect.Indirect(reflect.ValueOf(results))
2015-04-21 11:51:52 +03:00
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice {
2015-07-30 13:19:49 +03:00
value := getRealValue(result, relation.AssociationForeignFieldNames)
2015-04-21 11:51:52 +03:00
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if object.Kind() == reflect.Ptr {
object = reflect.Indirect(objects.Index(j).Elem())
}
2015-07-30 13:19:49 +03:00
if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) {
2015-04-21 11:51:52 +03:00
object.FieldByName(field.Name).Set(result)
}
}
} else {
scope.SetColumn(field, result)
}
}
}
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
2015-08-16 10:10:11 +03:00
relation := field.Relationship
joinTableHandler := relation.JoinTableHandler
destType := field.StructField.Struct.Type.Elem()
var isPtr bool
if destType.Kind() == reflect.Ptr {
isPtr = true
destType = destType.Elem()
}
2015-08-16 10:10:11 +03:00
var sourceKeys []string
var linkHash = make(map[string][]reflect.Value)
2015-08-16 10:10:11 +03:00
for _, key := range joinTableHandler.SourceForeignKeys() {
sourceKeys = append(sourceKeys, key.DBName)
}
2015-10-01 02:43:38 +03:00
db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*")
2015-08-18 03:05:44 +03:00
preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value)
if len(conditions) > 0 {
preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
}
rows, err := preloadJoinDB.Rows()
2015-08-16 10:10:11 +03:00
if scope.Err(err) != nil {
return
}
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
elem := reflect.New(destType).Elem()
var values = make([]interface{}, len(columns))
fields := scope.New(elem.Addr().Interface()).Fields()
var foundFields = map[string]bool{}
2015-08-16 10:10:11 +03:00
for index, column := range columns {
if field, ok := fields[column]; ok && !foundFields[column] {
2015-08-16 10:10:11 +03:00
if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface()
} else {
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
}
foundFields[column] = true
2015-08-16 10:10:11 +03:00
} else {
var i interface{}
values[index] = &i
}
}
scope.Err(rows.Scan(values...))
var sourceKey []interface{}
var scannedFields = map[string]bool{}
2015-08-16 10:10:11 +03:00
for index, column := range columns {
value := values[index]
if field, ok := fields[column]; ok && !scannedFields[column] {
2015-08-16 10:10:11 +03:00
if field.Field.Kind() == reflect.Ptr {
field.Field.Set(reflect.ValueOf(value).Elem())
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
field.Field.Set(v)
}
scannedFields[column] = true
2015-08-16 10:10:11 +03:00
} else if strInSlice(column, sourceKeys) {
sourceKey = append(sourceKey, *(value.(*interface{})))
}
}
2015-08-17 22:28:40 +03:00
if len(sourceKey) != 0 {
if isPtr {
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem.Addr())
} else {
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem)
}
}
2015-08-16 10:10:11 +03:00
}
var foreignFieldNames []string
for _, dbName := range relation.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
2015-08-16 10:10:11 +03:00
if scope.IndirectValue().Kind() == reflect.Slice {
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
source := getRealValue(object, foreignFieldNames)
field := object.FieldByName(field.Name)
for _, link := range linkHash[toString(source)] {
field.Set(reflect.Append(field, link))
2015-08-16 10:10:11 +03:00
}
}
2015-08-16 12:25:25 +03:00
} else {
if object := scope.IndirectValue(); object.IsValid() {
source := getRealValue(object, foreignFieldNames)
field := object.FieldByName(field.Name)
for _, link := range linkHash[toString(source)] {
field.Set(reflect.Append(field, link))
}
2015-08-16 12:25:25 +03:00
}
2015-08-16 10:10:11 +03:00
}
}
2015-07-30 13:19:49 +03:00
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
for i := 0; i < values.Len(); i++ {
2015-07-30 13:19:49 +03:00
var result []interface{}
for _, column := range columns {
value := reflect.Indirect(values.Index(i))
if value.Kind() == reflect.Ptr {
value = reflect.Indirect(values.Index(i).Elem())
}
result = append(result, value.FieldByName(column).Interface())
2015-07-30 13:19:49 +03:00
}
results = append(results, result)
}
case reflect.Struct:
2015-07-30 13:19:49 +03:00
var result []interface{}
for _, column := range columns {
result = append(result, values.FieldByName(column).Interface())
}
return [][]interface{}{result}
}
return
}
2015-04-21 10:00:36 +03:00
func (scope *Scope) getColumnsAsScope(column string) *Scope {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
2015-04-22 10:36:10 +03:00
modelType := values.Type().Elem()
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
2015-04-21 11:51:52 +03:00
}
2015-04-22 10:36:10 +03:00
fieldStruct, _ := modelType.FieldByName(column)
2015-04-21 10:00:36 +03:00
var columns reflect.Value
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
2015-04-22 10:36:10 +03:00
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
2015-04-21 10:00:36 +03:00
} else {
2015-04-22 10:36:10 +03:00
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
2015-04-21 10:00:36 +03:00
}
for i := 0; i < values.Len(); i++ {
column := reflect.Indirect(values.Index(i)).FieldByName(column)
if column.Kind() == reflect.Ptr {
column = column.Elem()
}
2015-04-21 10:00:36 +03:00
if column.Kind() == reflect.Slice {
for i := 0; i < column.Len(); i++ {
elem := column.Index(i)
if elem.CanAddr() {
columns = reflect.Append(columns, elem.Addr())
}
2015-04-21 10:00:36 +03:00
}
} else {
if column.CanAddr() {
columns = reflect.Append(columns, column.Addr())
}
2015-04-21 10:00:36 +03:00
}
}
return scope.New(columns.Interface())
case reflect.Struct:
field := values.FieldByName(column)
if !field.CanAddr() {
return nil
}
return scope.New(field.Addr().Interface())
2015-04-21 10:00:36 +03:00
}
return nil
}