diff --git a/callbacks/preload.go b/callbacks/preload.go index 7e3810b5..f48777c2 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -22,7 +22,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { ) if len(rels) > 1 { - reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)]) + reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1]) } if rel.JoinTable != nil { @@ -107,9 +107,9 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { rel.Field.Set(data, reflectResults.Index(i).Interface()) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i).Addr()).Interface()) - } else { rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i)).Interface()) + } else { + rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i).Elem()).Interface()) } } } diff --git a/schema/schema.go b/schema/schema.go index 79faae12..caae55ac 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -49,7 +49,7 @@ func (schema Schema) String() string { } func (schema Schema) MakeSlice() reflect.Value { - slice := reflect.MakeSlice(reflect.SliceOf(schema.ModelType), 0, 0) + slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 0) results := reflect.New(slice.Type()) results.Elem().Set(slice) return results diff --git a/schema/utils.go b/schema/utils.go index c47f1984..f7808f0e 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -55,17 +55,21 @@ func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag // GetRelationsValues get relations's values from a reflect value func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { - reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0) + reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 0) appendToResults := func(value reflect.Value) { if _, isZero := rel.Field.ValueOf(value); !isZero { result := reflect.Indirect(rel.Field.ReflectValueOf(value)) switch result.Kind() { case reflect.Struct: - reflectResults = reflect.Append(reflectResults, result) + reflectResults = reflect.Append(reflectResults, result.Addr()) case reflect.Slice, reflect.Array: - for i := 0; i < value.Len(); i++ { - reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i))) + for i := 0; i < result.Len(); i++ { + if result.Index(i).Kind() == reflect.Ptr { + reflectResults = reflect.Append(reflectResults, result.Index(i)) + } else { + reflectResults = reflect.Append(reflectResults, result.Index(i).Addr()) + } } } } diff --git a/tests/preload_test.go b/tests/preload_test.go new file mode 100644 index 00000000..74f21f55 --- /dev/null +++ b/tests/preload_test.go @@ -0,0 +1,58 @@ +package tests_test + +import ( + "strconv" + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestNestedPreload(t *testing.T) { + var user = *GetUser("nested_preload", Config{Pets: 2}) + + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} + } + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var user2 User + DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) + + CheckUser(t, user2, user) +} + +func TestNestedPreloadForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice_nested_preload_1", Config{Pets: 2}), + *GetUser("slice_nested_preload_2", Config{Pets: 0}), + *GetUser("slice_nested_preload_3", Config{Pets: 3}), + } + + for _, user := range users { + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: user.Name + "_toy_nested_preload_" + strconv.Itoa(idx+1)} + } + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + DB.Preload("Pets.Toy").Find(&users2, "id IN ?", userIDs) + + for idx, user := range users2 { + CheckUser(t, user, users[idx]) + } +} + +func TestPreloadWithConds(t *testing.T) { +}