From aac52fdcf8c0cef3b4bc4139a7a7ef72319ca6d1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 11 Feb 2015 17:58:19 +0800 Subject: [PATCH] Fix Preload with slice of pointer --- preload.go | 16 ++++++++++------ preload_test.go | 8 ++++++++ scope.go | 7 ++++++- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/preload.go b/preload.go index 1e6c26b3..4e48db53 100644 --- a/preload.go +++ b/preload.go @@ -8,7 +8,7 @@ import ( ) func getFieldValue(value reflect.Value, field string) interface{} { - result := value.FieldByName(field).Interface() + result := reflect.Indirect(value).FieldByName(field).Interface() if r, ok := result.(driver.Valuer); ok { result, _ = r.Value() } @@ -25,7 +25,11 @@ func Preload(scope *Scope) { var isSlice bool if scope.IndirectValue().Kind() == reflect.Slice { isSlice = true - elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem() + typ := scope.IndirectValue().Type().Elem() + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + elem := reflect.New(typ).Elem() fields = scope.New(elem.Addr().Interface()).Fields() } else { fields = scope.Fields() @@ -53,7 +57,7 @@ func Preload(scope *Scope) { objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { if equalAsString(getFieldValue(objects.Index(j), primaryName), value) { - objects.Index(j).FieldByName(field.Name).Set(result) + reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) break } } @@ -71,7 +75,7 @@ func Preload(scope *Scope) { value := getFieldValue(result, relation.ForeignKey) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { - object := objects.Index(j) + object := reflect.Indirect(objects.Index(j)) if equalAsString(getFieldValue(object, primaryName), value) { f := object.FieldByName(field.Name) f.Set(reflect.Append(f, result)) @@ -91,7 +95,7 @@ func Preload(scope *Scope) { value := getFieldValue(result, associationPrimaryKey) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { - object := objects.Index(j) + object := reflect.Indirect(objects.Index(j)) if equalAsString(getFieldValue(object, relation.ForeignKey), value) { object.FieldByName(field.Name).Set(result) break @@ -130,7 +134,7 @@ func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) case reflect.Slice: for i := 0; i < values.Len(); i++ { value := values.Index(i) - primaryKeys = append(primaryKeys, value.FieldByName(column).Interface()) + primaryKeys = append(primaryKeys, reflect.Indirect(value).FieldByName(column).Interface()) } case reflect.Struct: return []interface{}{values.FieldByName(column).Interface()} diff --git a/preload_test.go b/preload_test.go index 1107f234..2a0e967b 100644 --- a/preload_test.go +++ b/preload_test.go @@ -76,4 +76,12 @@ func TestPreload(t *testing.T) { for _, user := range users { checkUserHasPreloadData(user, t) } + + var users2 []*User + DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Find(&users2) + + for _, user := range users2 { + checkUserHasPreloadData(*user, t) + } } diff --git a/scope.go b/scope.go index 076fc336..d1b11615 100644 --- a/scope.go +++ b/scope.go @@ -101,8 +101,13 @@ func (scope *Scope) PrimaryKeyField() *Field { var indirectValue = scope.IndirectValue() clone := scope + // FIXME if indirectValue.Kind() == reflect.Slice { - clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface()) + typ := indirectValue.Type().Elem() + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + clone = scope.New(reflect.New(typ).Elem().Interface()) } for _, field := range clone.Fields() {