diff --git a/preload.go b/preload.go index 0c8d70ad..42836067 100644 --- a/preload.go +++ b/preload.go @@ -21,143 +21,69 @@ func equalAsString(a interface{}, b interface{}) bool { } func Preload(scope *Scope) { + if scope.Search.preload == nil { + return + } + preloadMap := map[string]bool{} - if scope.Search.preload != nil { - fields := scope.Fields() - isSlice := scope.IndirectValue().Kind() == reflect.Slice + fields := scope.Fields() + for _, preload := range scope.Search.preload { + schema, conditions := preload.schema, preload.conditions + keys := strings.Split(schema, ".") + currentScope := scope + currentFields := fields + originalConditions := conditions + conditions = []interface{}{} + for i, key := range keys { + var found bool + if preloadMap[strings.Join(keys[:i+1], ".")] { + goto nextLoop + } - for _, preload := range scope.Search.preload { - schema, conditions := preload.schema, preload.conditions - keys := strings.Split(schema, ".") - currentScope := scope - currentFields := fields - currentIsSlice := isSlice - originalConditions := conditions - conditions = []interface{}{} - for i, key := range keys { - // log.Printf("--> %+v\n", key) - if !preloadMap[strings.Join(keys[:i+1], ".")] { - if i == len(keys)-1 { - // log.Printf("--> %+v\n", originalConditions) - conditions = originalConditions - } + if i == len(keys)-1 { + conditions = originalConditions + } - var found bool - for _, field := range currentFields { - if field.Name == key && field.Relationship != nil { - found = true - // log.Printf("--> %+v\n", field.Name) - results := makeSlice(field.Struct.Type) - relation := field.Relationship - primaryName := currentScope.PrimaryField().Name - associationPrimaryKey := currentScope.New(results).PrimaryField().Name - - switch relation.Kind { - case "has_one": - if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { - condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName)) - currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if currentIsSlice { - value := getRealValue(result, relation.ForeignFieldName) - objects := currentScope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), primaryName), value) { - reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) - break - } - } - } else { - // log.Printf("--> %+v\n", result.Interface()) - err := currentScope.SetColumn(field, result) - if err != nil { - scope.Err(err) - return - } - // printutils.PrettyPrint(currentScope.Value) - } - } - // printutils.PrettyPrint(currentScope.Value) - } - case "has_many": - // log.Printf("--> %+v\n", key) - if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { - condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName)) - currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - if currentIsSlice { - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldName) - objects := currentScope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, primaryName), value) { - f := object.FieldByName(field.Name) - f.Set(reflect.Append(f, result)) - break - } - } - } - // printutils.PrettyPrint(currentScope.IndirectValue().Interface()) - } else { - currentScope.SetColumn(field, resultValues) - } - } - case "belongs_to": - if primaryKeys := currentScope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 { - currentScope.NewDB().Where(primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if currentIsSlice { - value := getRealValue(result, associationPrimaryKey) - objects := currentScope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { - object.FieldByName(field.Name).Set(result) - } - } - } else { - currentScope.SetColumn(field, result) - } - } - } - case "many_to_many": - // currentScope.Err(errors.New("not supported relation")) - fallthrough - default: - currentScope.Err(errors.New("not supported relation")) - } - break - } - } - - if !found { - value := reflect.ValueOf(currentScope.Value) - if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { - value = value.Index(0).Elem() - } - scope.Err(fmt.Errorf("can't found field %s in %s", key, value.Type())) - return - } - - preloadMap[strings.Join(keys[:i+1], ".")] = true + for _, field := range currentFields { + if field.Name != key || field.Relationship == nil { + continue } - if i < len(keys)-1 { - // TODO: update current scope - currentScope = currentScope.getColumnsAsScope(key) - currentFields = currentScope.Fields() - currentIsSlice = currentScope.IndirectValue().Kind() == reflect.Slice + found = true + switch field.Relationship.Kind { + case "has_one": + currentScope.handleHasOnePreload(field, conditions) + case "has_many": + currentScope.handleHasManyPreload(field, conditions) + case "belongs_to": + currentScope.handleBelongsToPreload(field, conditions) + case "many_to_many": + fallthrough + default: + currentScope.Err(errors.New("not supported relation")) } + break + } + + if !found { + value := reflect.ValueOf(currentScope.Value) + if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { + value = value.Index(0).Elem() + } + scope.Err(fmt.Errorf("can't found field %s in %s", key, value.Type())) + return + } + + preloadMap[strings.Join(keys[:i+1], ".")] = true + + nextLoop: + if i < len(keys)-1 { + currentScope = currentScope.getColumnsAsScope(key) + currentFields = currentScope.Fields() } } } + } func makeSlice(typ reflect.Type) interface{} { @@ -170,6 +96,105 @@ func makeSlice(typ reflect.Type) interface{} { return slice.Interface() } +func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { + primaryName := scope.PrimaryField().Name + primaryKeys := scope.getColumnAsArray(primaryName) + if len(primaryKeys) == 0 { + return + } + + results := makeSlice(field.Struct.Type) + relation := field.Relationship + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + + // TODO: handle error? + scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if scope.IndirectValue().Kind() == reflect.Slice { + value := getRealValue(result, relation.ForeignFieldName) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + if equalAsString(getRealValue(objects.Index(j), primaryName), value) { + reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) + break + } + } + } else { + err := scope.SetColumn(field, result) + if err != nil { + scope.Err(err) + return + } + } + } +} + +func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { + primaryName := scope.PrimaryField().Name + primaryKeys := scope.getColumnAsArray(primaryName) + if len(primaryKeys) == 0 { + return + } + + results := makeSlice(field.Struct.Type) + relation := field.Relationship + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + + scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + + if scope.IndirectValue().Kind() == reflect.Slice { + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + value := getRealValue(result, relation.ForeignFieldName) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, primaryName), value) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + break + } + } + } + } else { + scope.SetColumn(field, resultValues) + } +} + +func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { + relation := field.Relationship + primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName) + if len(primaryKeys) == 0 { + return + } + + results := makeSlice(field.Struct.Type) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + associationPrimaryKey := scope.New(results).PrimaryField().Name + + scope.NewDB().Where(primaryKeys).Find(results, conditions...) + + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if scope.IndirectValue().Kind() == reflect.Slice { + value := getRealValue(result, associationPrimaryKey) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { + object.FieldByName(field.Name).Set(result) + } + } + } else { + scope.SetColumn(field, result) + } + } +} + func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { values := scope.IndirectValue() switch values.Kind() { @@ -185,10 +210,13 @@ func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { func (scope *Scope) getColumnsAsScope(column string) *Scope { values := scope.IndirectValue() - // log.Println(values.Type(), column) switch values.Kind() { case reflect.Slice: - fieldType, _ := values.Type().Elem().FieldByName(column) + model := values.Type().Elem() + if model.Kind() == reflect.Ptr { + model = model.Elem() + } + fieldType, _ := model.FieldByName(column) var columns reflect.Value if fieldType.Type.Kind() == reflect.Slice { columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type.Elem()))).Elem() diff --git a/preload_test.go b/preload_test.go index c5d395d4..2929392c 100644 --- a/preload_test.go +++ b/preload_test.go @@ -2,7 +2,6 @@ package gorm_test import ( "encoding/json" - "log" "reflect" "testing" ) @@ -91,412 +90,505 @@ func TestPreload(t *testing.T) { } } -func TestNestedPreload(t *testing.T) { - log.SetFlags(log.Lshortfile) - // Struct: Level3 - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) +func TestNestedPreload1(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint } - - want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want).Error; err != nil { - panic(err) + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint } - - var got Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + Level3 struct { + ID uint + Level2 Level2 } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - want := Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value3"}, - }, + want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload2(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []*Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2s: []Level2{ + { + Level1s: []*Level1{ + &Level1{Value: "value1"}, + &Level1{Value: "value2"}, }, }, - } - if err := DB.Create(&want).Error; err != nil { - panic(err) - } - - var got Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - - want := Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, + { + Level1s: []*Level1{ + &Level1{Value: "value3"}, + }, }, - } - if err := DB.Create(&want).Error; err != nil { - panic(err) - } - - var got Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - want := Level3{ - Level2: Level2{ + var got Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload3(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload4(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +// Slice: []Level3 +func TestNestedPreload5(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload6(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + { Level1s: []Level1{ - Level1{Value: "value1"}, - Level1{Value: "value2"}, + {Value: "value1"}, + {Value: "value2"}, }, }, - } - if err := DB.Create(&want).Error; err != nil { - panic(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - } - - // Slice: []Level3 - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - - want := make([]Level3, 2) - want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want[0]).Error; err != nil { - panic(err) - } - want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} - if err := DB.Create(&want[1]).Error; err != nil { - panic(err) - } - - var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value3"}, - }, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - panic(err) - } - want[1] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value5"}, - }, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - panic(err) - } - - var got []Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - panic(err) - } - want[1] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value3"}}, - {Level1: Level1{Value: "value4"}}, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - panic(err) - } - - var got []Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2: Level2{ + { Level1s: []Level1{ - Level1{Value: "value1"}, - Level1{Value: "value2"}, + {Value: "value3"}, }, }, - } - if err := DB.Create(&want[0]).Error; err != nil { - panic(err) - } - want[1] = Level3{ - Level2: Level2{ + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2s: []Level2{ + { Level1s: []Level1{ - Level1{Value: "value3"}, - Level1{Value: "value4"}, + {Value: "value3"}, + {Value: "value4"}, }, }, - } - if err := DB.Create(&want[1]).Error; err != nil { - panic(err) - } + { + Level1s: []Level1{ + {Value: "value5"}, + }, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } - var got []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - panic(err) - } + var got []Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + panic(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload7(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value3"}}, + {Level1: Level1{Value: "value4"}}, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload8(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value3"}, + Level1{Value: "value4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload9(t *testing.T) { + type ( + Level0 struct { + ID uint + Value string + Level1ID uint + } + Level1 struct { + ID uint + Value string + Level2ID uint + Level2_1ID uint + Level0s []Level0 + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level2_1 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + Level2_1 Level2_1 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level2_1{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level0{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + Level1{ + Value: "value1-1", + Level0s: []Level0{{Value: "Level0-1"}}, + }, + Level1{ + Value: "value2-2", + Level0s: []Level0{{Value: "Level0-2"}}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value3"}, + Level1{Value: "value4"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + Level1{Value: "value3-3"}, + Level1{Value: "value4-4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } }