refactor preload and its tests

This commit is contained in:
bom_d_van 2015-04-21 16:51:52 +08:00
parent 6d58dc9f4e
commit 9e9367e815
2 changed files with 635 additions and 515 deletions

View File

@ -21,121 +21,49 @@ func equalAsString(a interface{}, b interface{}) bool {
} }
func Preload(scope *Scope) { func Preload(scope *Scope) {
preloadMap := map[string]bool{} if scope.Search.preload == nil {
if scope.Search.preload != nil { return
fields := scope.Fields() }
isSlice := scope.IndirectValue().Kind() == reflect.Slice
preloadMap := map[string]bool{}
fields := scope.Fields()
for _, preload := range scope.Search.preload { for _, preload := range scope.Search.preload {
schema, conditions := preload.schema, preload.conditions schema, conditions := preload.schema, preload.conditions
keys := strings.Split(schema, ".") keys := strings.Split(schema, ".")
currentScope := scope currentScope := scope
currentFields := fields currentFields := fields
currentIsSlice := isSlice
originalConditions := conditions originalConditions := conditions
conditions = []interface{}{} conditions = []interface{}{}
for i, key := range keys { for i, key := range keys {
// log.Printf("--> %+v\n", key) var found bool
if !preloadMap[strings.Join(keys[:i+1], ".")] { if preloadMap[strings.Join(keys[:i+1], ".")] {
goto nextLoop
}
if i == len(keys)-1 { if i == len(keys)-1 {
// log.Printf("--> %+v\n", originalConditions)
conditions = originalConditions conditions = originalConditions
} }
var found bool
for _, field := range currentFields { for _, field := range currentFields {
if field.Name == key && field.Relationship != nil { if field.Name != key || field.Relationship == nil {
continue
}
found = true found = true
// log.Printf("--> %+v\n", field.Name) switch field.Relationship.Kind {
results := makeSlice(field.Struct.Type)
relation := field.Relationship
primaryName := currentScope.PrimaryField().Name
associationPrimaryKey := currentScope.New(results).PrimaryField().Name
switch relation.Kind {
case "has_one": case "has_one":
if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { currentScope.handleHasOnePreload(field, conditions)
condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName))
currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if currentIsSlice {
value := getRealValue(result, relation.ForeignFieldName)
objects := currentScope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break
}
}
} else {
// log.Printf("--> %+v\n", result.Interface())
err := currentScope.SetColumn(field, result)
if err != nil {
scope.Err(err)
return
}
// printutils.PrettyPrint(currentScope.Value)
}
}
// printutils.PrettyPrint(currentScope.Value)
}
case "has_many": case "has_many":
// log.Printf("--> %+v\n", key) currentScope.handleHasManyPreload(field, conditions)
if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName))
currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
resultValues := reflect.Indirect(reflect.ValueOf(results))
if currentIsSlice {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
value := getRealValue(result, relation.ForeignFieldName)
objects := currentScope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, primaryName), value) {
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result))
break
}
}
}
// printutils.PrettyPrint(currentScope.IndirectValue().Interface())
} else {
currentScope.SetColumn(field, resultValues)
}
}
case "belongs_to": case "belongs_to":
if primaryKeys := currentScope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 { currentScope.handleBelongsToPreload(field, conditions)
currentScope.NewDB().Where(primaryKeys).Find(results, conditions...)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if currentIsSlice {
value := getRealValue(result, associationPrimaryKey)
objects := currentScope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
object.FieldByName(field.Name).Set(result)
}
}
} else {
currentScope.SetColumn(field, result)
}
}
}
case "many_to_many": case "many_to_many":
// currentScope.Err(errors.New("not supported relation"))
fallthrough fallthrough
default: default:
currentScope.Err(errors.New("not supported relation")) currentScope.Err(errors.New("not supported relation"))
} }
break break
} }
}
if !found { if !found {
value := reflect.ValueOf(currentScope.Value) value := reflect.ValueOf(currentScope.Value)
@ -147,17 +75,15 @@ func Preload(scope *Scope) {
} }
preloadMap[strings.Join(keys[:i+1], ".")] = true preloadMap[strings.Join(keys[:i+1], ".")] = true
}
nextLoop:
if i < len(keys)-1 { if i < len(keys)-1 {
// TODO: update current scope
currentScope = currentScope.getColumnsAsScope(key) currentScope = currentScope.getColumnsAsScope(key)
currentFields = currentScope.Fields() currentFields = currentScope.Fields()
currentIsSlice = currentScope.IndirectValue().Kind() == reflect.Slice
}
} }
} }
} }
} }
func makeSlice(typ reflect.Type) interface{} { func makeSlice(typ reflect.Type) interface{} {
@ -170,6 +96,105 @@ func makeSlice(typ reflect.Type) interface{} {
return slice.Interface() return slice.Interface()
} }
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
primaryName := scope.PrimaryField().Name
primaryKeys := scope.getColumnAsArray(primaryName)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
relation := field.Relationship
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
resultValues := reflect.Indirect(reflect.ValueOf(results))
// TODO: handle error?
scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break
}
}
} else {
err := scope.SetColumn(field, result)
if err != nil {
scope.Err(err)
return
}
}
}
}
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
primaryName := scope.PrimaryField().Name
primaryKeys := scope.getColumnAsArray(primaryName)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
relation := field.Relationship
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
resultValues := reflect.Indirect(reflect.ValueOf(results))
scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
if scope.IndirectValue().Kind() == reflect.Slice {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, primaryName), value) {
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result))
break
}
}
}
} else {
scope.SetColumn(field, resultValues)
}
}
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
resultValues := reflect.Indirect(reflect.ValueOf(results))
associationPrimaryKey := scope.New(results).PrimaryField().Name
scope.NewDB().Where(primaryKeys).Find(results, conditions...)
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, associationPrimaryKey)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
object.FieldByName(field.Name).Set(result)
}
}
} else {
scope.SetColumn(field, result)
}
}
}
func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) {
values := scope.IndirectValue() values := scope.IndirectValue()
switch values.Kind() { switch values.Kind() {
@ -185,10 +210,13 @@ func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) {
func (scope *Scope) getColumnsAsScope(column string) *Scope { func (scope *Scope) getColumnsAsScope(column string) *Scope {
values := scope.IndirectValue() values := scope.IndirectValue()
// log.Println(values.Type(), column)
switch values.Kind() { switch values.Kind() {
case reflect.Slice: case reflect.Slice:
fieldType, _ := values.Type().Elem().FieldByName(column) model := values.Type().Elem()
if model.Kind() == reflect.Ptr {
model = model.Elem()
}
fieldType, _ := model.FieldByName(column)
var columns reflect.Value var columns reflect.Value
if fieldType.Type.Kind() == reflect.Slice { if fieldType.Type.Kind() == reflect.Slice {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type.Elem()))).Elem() columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type.Elem()))).Elem()

View File

@ -2,7 +2,6 @@ package gorm_test
import ( import (
"encoding/json" "encoding/json"
"log"
"reflect" "reflect"
"testing" "testing"
) )
@ -91,10 +90,7 @@ func TestPreload(t *testing.T) {
} }
} }
func TestNestedPreload(t *testing.T) { func TestNestedPreload1(t *testing.T) {
log.SetFlags(log.Lshortfile)
// Struct: Level3
{
type ( type (
Level1 struct { Level1 struct {
ID uint ID uint
@ -131,8 +127,9 @@ func TestNestedPreload(t *testing.T) {
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
} }
{
func TestNestedPreload2(t *testing.T) {
type ( type (
Level1 struct { Level1 struct {
ID uint ID uint
@ -141,7 +138,7 @@ func TestNestedPreload(t *testing.T) {
} }
Level2 struct { Level2 struct {
ID uint ID uint
Level1s []Level1 Level1s []*Level1
Level3ID uint Level3ID uint
} }
Level3 struct { Level3 struct {
@ -159,14 +156,14 @@ func TestNestedPreload(t *testing.T) {
want := Level3{ want := Level3{
Level2s: []Level2{ Level2s: []Level2{
{ {
Level1s: []Level1{ Level1s: []*Level1{
{Value: "value1"}, &Level1{Value: "value1"},
{Value: "value2"}, &Level1{Value: "value2"},
}, },
}, },
{ {
Level1s: []Level1{ Level1s: []*Level1{
{Value: "value3"}, &Level1{Value: "value3"},
}, },
}, },
}, },
@ -183,8 +180,9 @@ func TestNestedPreload(t *testing.T) {
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
} }
{
func TestNestedPreload3(t *testing.T) {
type ( type (
Level1 struct { Level1 struct {
ID uint ID uint
@ -226,8 +224,9 @@ func TestNestedPreload(t *testing.T) {
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
} }
{
func TestNestedPreload4(t *testing.T) {
type ( type (
Level1 struct { Level1 struct {
ID uint ID uint
@ -271,10 +270,10 @@ func TestNestedPreload(t *testing.T) {
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
} }
// Slice: []Level3 // Slice: []Level3
{ func TestNestedPreload5(t *testing.T) {
type ( type (
Level1 struct { Level1 struct {
ID uint ID uint
@ -316,8 +315,9 @@ func TestNestedPreload(t *testing.T) {
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
} }
{
func TestNestedPreload6(t *testing.T) {
type ( type (
Level1 struct { Level1 struct {
ID uint ID uint
@ -387,8 +387,9 @@ func TestNestedPreload(t *testing.T) {
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
} }
{
func TestNestedPreload7(t *testing.T) {
type ( type (
Level1 struct { Level1 struct {
ID uint ID uint
@ -440,8 +441,9 @@ func TestNestedPreload(t *testing.T) {
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
} }
{
func TestNestedPreload8(t *testing.T) {
type ( type (
Level1 struct { Level1 struct {
ID uint ID uint
@ -497,6 +499,96 @@ func TestNestedPreload(t *testing.T) {
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
}
func TestNestedPreload9(t *testing.T) {
type (
Level0 struct {
ID uint
Value string
Level1ID uint
}
Level1 struct {
ID uint
Value string
Level2ID uint
Level2_1ID uint
Level0s []Level0
}
Level2 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level2_1 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level3 struct {
ID uint
Level2 Level2
Level2_1 Level2_1
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level2_1{})
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists(&Level0{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value1"},
Level1{Value: "value2"},
},
},
Level2_1: Level2_1{
Level1s: []Level1{
Level1{
Value: "value1-1",
Level0s: []Level0{{Value: "Level0-1"}},
},
Level1{
Value: "value2-2",
Level0s: []Level0{{Value: "Level0-2"}},
},
},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value3"},
Level1{Value: "value4"},
},
},
Level2_1: Level2_1{
Level1s: []Level1{
Level1{Value: "value3-3"},
Level1{Value: "value4-4"},
},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
} }