forked from mirror/gorm
support nested preloading
This commit is contained in:
parent
055bf79f8b
commit
6d58dc9f4e
245
preload.go
245
preload.go
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getRealValue(value reflect.Value, field string) interface{} {
|
||||
|
@ -20,90 +21,139 @@ func equalAsString(a interface{}, b interface{}) bool {
|
|||
}
|
||||
|
||||
func Preload(scope *Scope) {
|
||||
preloadMap := map[string]bool{}
|
||||
if scope.Search.preload != nil {
|
||||
fields := scope.Fields()
|
||||
isSlice := scope.IndirectValue().Kind() == reflect.Slice
|
||||
|
||||
for key, conditions := range scope.Search.preload {
|
||||
for _, field := range fields {
|
||||
if field.Name == key && field.Relationship != nil {
|
||||
results := makeSlice(field.Struct.Type)
|
||||
relation := field.Relationship
|
||||
primaryName := scope.PrimaryField().Name
|
||||
associationPrimaryKey := scope.New(results).PrimaryField().Name
|
||||
|
||||
switch relation.Kind {
|
||||
case "has_one":
|
||||
if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
|
||||
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
|
||||
scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
|
||||
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if isSlice {
|
||||
value := getRealValue(result, relation.ForeignFieldName)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
|
||||
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.SetColumn(field, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "has_many":
|
||||
if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
|
||||
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
|
||||
scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
if isSlice {
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
value := getRealValue(result, relation.ForeignFieldName)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
if equalAsString(getRealValue(object, primaryName), value) {
|
||||
f := object.FieldByName(field.Name)
|
||||
f.Set(reflect.Append(f, result))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.SetColumn(field, resultValues)
|
||||
}
|
||||
}
|
||||
case "belongs_to":
|
||||
if primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 {
|
||||
scope.NewDB().Where(primaryKeys).Find(results, conditions...)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if isSlice {
|
||||
value := getRealValue(result, associationPrimaryKey)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
|
||||
object.FieldByName(field.Name).Set(result)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.SetColumn(field, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "many_to_many":
|
||||
scope.Err(errors.New("not supported relation"))
|
||||
default:
|
||||
scope.Err(errors.New("not supported relation"))
|
||||
for _, preload := range scope.Search.preload {
|
||||
schema, conditions := preload.schema, preload.conditions
|
||||
keys := strings.Split(schema, ".")
|
||||
currentScope := scope
|
||||
currentFields := fields
|
||||
currentIsSlice := isSlice
|
||||
originalConditions := conditions
|
||||
conditions = []interface{}{}
|
||||
for i, key := range keys {
|
||||
// log.Printf("--> %+v\n", key)
|
||||
if !preloadMap[strings.Join(keys[:i+1], ".")] {
|
||||
if i == len(keys)-1 {
|
||||
// log.Printf("--> %+v\n", originalConditions)
|
||||
conditions = originalConditions
|
||||
}
|
||||
break
|
||||
|
||||
var found bool
|
||||
for _, field := range currentFields {
|
||||
if field.Name == key && field.Relationship != nil {
|
||||
found = true
|
||||
// log.Printf("--> %+v\n", field.Name)
|
||||
results := makeSlice(field.Struct.Type)
|
||||
relation := field.Relationship
|
||||
primaryName := currentScope.PrimaryField().Name
|
||||
associationPrimaryKey := currentScope.New(results).PrimaryField().Name
|
||||
|
||||
switch relation.Kind {
|
||||
case "has_one":
|
||||
if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
|
||||
condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName))
|
||||
currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
|
||||
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if currentIsSlice {
|
||||
value := getRealValue(result, relation.ForeignFieldName)
|
||||
objects := currentScope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
|
||||
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// log.Printf("--> %+v\n", result.Interface())
|
||||
err := currentScope.SetColumn(field, result)
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return
|
||||
}
|
||||
// printutils.PrettyPrint(currentScope.Value)
|
||||
}
|
||||
}
|
||||
// printutils.PrettyPrint(currentScope.Value)
|
||||
}
|
||||
case "has_many":
|
||||
// log.Printf("--> %+v\n", key)
|
||||
if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
|
||||
condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName))
|
||||
currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
if currentIsSlice {
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
value := getRealValue(result, relation.ForeignFieldName)
|
||||
objects := currentScope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
if equalAsString(getRealValue(object, primaryName), value) {
|
||||
f := object.FieldByName(field.Name)
|
||||
f.Set(reflect.Append(f, result))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// printutils.PrettyPrint(currentScope.IndirectValue().Interface())
|
||||
} else {
|
||||
currentScope.SetColumn(field, resultValues)
|
||||
}
|
||||
}
|
||||
case "belongs_to":
|
||||
if primaryKeys := currentScope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 {
|
||||
currentScope.NewDB().Where(primaryKeys).Find(results, conditions...)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if currentIsSlice {
|
||||
value := getRealValue(result, associationPrimaryKey)
|
||||
objects := currentScope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
|
||||
object.FieldByName(field.Name).Set(result)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
currentScope.SetColumn(field, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "many_to_many":
|
||||
// currentScope.Err(errors.New("not supported relation"))
|
||||
fallthrough
|
||||
default:
|
||||
currentScope.Err(errors.New("not supported relation"))
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
value := reflect.ValueOf(currentScope.Value)
|
||||
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
|
||||
value = value.Index(0).Elem()
|
||||
}
|
||||
scope.Err(fmt.Errorf("can't found field %s in %s", key, value.Type()))
|
||||
return
|
||||
}
|
||||
|
||||
preloadMap[strings.Join(keys[:i+1], ".")] = true
|
||||
}
|
||||
|
||||
if i < len(keys)-1 {
|
||||
// TODO: update current scope
|
||||
currentScope = currentScope.getColumnsAsScope(key)
|
||||
currentFields = currentScope.Fields()
|
||||
currentIsSlice = currentScope.IndirectValue().Kind() == reflect.Slice
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -120,19 +170,44 @@ func makeSlice(typ reflect.Type) interface{} {
|
|||
return slice.Interface()
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) {
|
||||
func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) {
|
||||
values := scope.IndirectValue()
|
||||
switch values.Kind() {
|
||||
case reflect.Slice:
|
||||
primaryKeyMap := map[interface{}]bool{}
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
primaryKeyMap[reflect.Indirect(values.Index(i)).FieldByName(column).Interface()] = true
|
||||
}
|
||||
for key := range primaryKeyMap {
|
||||
primaryKeys = append(primaryKeys, key)
|
||||
columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
return []interface{}{values.FieldByName(column).Interface()}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnsAsScope(column string) *Scope {
|
||||
values := scope.IndirectValue()
|
||||
// log.Println(values.Type(), column)
|
||||
switch values.Kind() {
|
||||
case reflect.Slice:
|
||||
fieldType, _ := values.Type().Elem().FieldByName(column)
|
||||
var columns reflect.Value
|
||||
if fieldType.Type.Kind() == reflect.Slice {
|
||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type.Elem()))).Elem()
|
||||
} else {
|
||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type))).Elem()
|
||||
}
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
column := reflect.Indirect(values.Index(i)).FieldByName(column)
|
||||
if column.Kind() == reflect.Slice {
|
||||
for i := 0; i < column.Len(); i++ {
|
||||
columns = reflect.Append(columns, column.Index(i).Addr())
|
||||
}
|
||||
} else {
|
||||
columns = reflect.Append(columns, column.Addr())
|
||||
}
|
||||
}
|
||||
return scope.New(columns.Interface())
|
||||
case reflect.Struct:
|
||||
return scope.New(values.FieldByName(column).Addr().Interface())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
421
preload_test.go
421
preload_test.go
|
@ -1,6 +1,11 @@
|
|||
package gorm_test
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getPreloadUser(name string) *User {
|
||||
return getPreparedUser(name, "Preload")
|
||||
|
@ -85,3 +90,417 @@ func TestPreload(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreload(t *testing.T) {
|
||||
log.SetFlags(log.Lshortfile)
|
||||
// Struct: Level3
|
||||
{
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1 Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Level2 Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
|
||||
if err := DB.Create(&want).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got Level3
|
||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
{
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1s []Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Level2s []Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := Level3{
|
||||
Level2s: []Level2{
|
||||
{
|
||||
Level1s: []Level1{
|
||||
{Value: "value1"},
|
||||
{Value: "value2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Level1s: []Level1{
|
||||
{Value: "value3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got Level3
|
||||
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
{
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1 Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Level2s []Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := Level3{
|
||||
Level2s: []Level2{
|
||||
{Level1: Level1{Value: "value1"}},
|
||||
{Level1: Level1{Value: "value2"}},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got Level3
|
||||
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
{
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1s []Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Level2 Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value1"},
|
||||
Level1{Value: "value2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got Level3
|
||||
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
|
||||
// Slice: []Level3
|
||||
{
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1 Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Level2 Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := make([]Level3, 2)
|
||||
want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
|
||||
if err := DB.Create(&want[0]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}}
|
||||
if err := DB.Create(&want[1]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got []Level3
|
||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
{
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1s []Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Level2s []Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := make([]Level3, 2)
|
||||
want[0] = Level3{
|
||||
Level2s: []Level2{
|
||||
{
|
||||
Level1s: []Level1{
|
||||
{Value: "value1"},
|
||||
{Value: "value2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Level1s: []Level1{
|
||||
{Value: "value3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[0]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
want[1] = Level3{
|
||||
Level2s: []Level2{
|
||||
{
|
||||
Level1s: []Level1{
|
||||
{Value: "value3"},
|
||||
{Value: "value4"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Level1s: []Level1{
|
||||
{Value: "value5"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[1]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got []Level3
|
||||
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
{
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1 Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Level2s []Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := make([]Level3, 2)
|
||||
want[0] = Level3{
|
||||
Level2s: []Level2{
|
||||
{Level1: Level1{Value: "value1"}},
|
||||
{Level1: Level1{Value: "value2"}},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[0]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
want[1] = Level3{
|
||||
Level2s: []Level2{
|
||||
{Level1: Level1{Value: "value3"}},
|
||||
{Level1: Level1{Value: "value4"}},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[1]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got []Level3
|
||||
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
{
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1s []Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Level2 Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := make([]Level3, 2)
|
||||
want[0] = Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value1"},
|
||||
Level1{Value: "value2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[0]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
want[1] = Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value3"},
|
||||
Level1{Value: "value4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[1]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got []Level3
|
||||
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toJSONString(v interface{}) []byte {
|
||||
r, _ := json.MarshalIndent(v, "", " ")
|
||||
return r
|
||||
}
|
||||
|
|
19
search.go
19
search.go
|
@ -14,7 +14,7 @@ type search struct {
|
|||
omits []string
|
||||
orders []string
|
||||
joins string
|
||||
preload map[string][]interface{}
|
||||
preload []searchPreload
|
||||
offset string
|
||||
limit string
|
||||
group string
|
||||
|
@ -23,6 +23,11 @@ type search struct {
|
|||
Unscoped bool
|
||||
}
|
||||
|
||||
type searchPreload struct {
|
||||
schema string
|
||||
conditions []interface{}
|
||||
}
|
||||
|
||||
func (s *search) clone() *search {
|
||||
clone := *s
|
||||
return &clone
|
||||
|
@ -97,11 +102,15 @@ func (s *search) Joins(query string) *search {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *search) Preload(column string, values ...interface{}) *search {
|
||||
if s.preload == nil {
|
||||
s.preload = map[string][]interface{}{}
|
||||
func (s *search) Preload(schema string, values ...interface{}) *search {
|
||||
var preloads []searchPreload
|
||||
for _, preload := range s.preload {
|
||||
if preload.schema != schema {
|
||||
preloads = append(preloads, preload)
|
||||
}
|
||||
}
|
||||
s.preload[column] = values
|
||||
preloads = append(preloads, searchPreload{schema, values})
|
||||
s.preload = preloads
|
||||
return s
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue