Fix Preload with slice of pointer

This commit is contained in:
Jinzhu 2015-02-11 17:58:19 +08:00
parent f97e2c088e
commit aac52fdcf8
3 changed files with 24 additions and 7 deletions

View File

@ -8,7 +8,7 @@ import (
) )
func getFieldValue(value reflect.Value, field string) interface{} { func getFieldValue(value reflect.Value, field string) interface{} {
result := value.FieldByName(field).Interface() result := reflect.Indirect(value).FieldByName(field).Interface()
if r, ok := result.(driver.Valuer); ok { if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value() result, _ = r.Value()
} }
@ -25,7 +25,11 @@ func Preload(scope *Scope) {
var isSlice bool var isSlice bool
if scope.IndirectValue().Kind() == reflect.Slice { if scope.IndirectValue().Kind() == reflect.Slice {
isSlice = true isSlice = true
elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem() typ := scope.IndirectValue().Type().Elem()
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
elem := reflect.New(typ).Elem()
fields = scope.New(elem.Addr().Interface()).Fields() fields = scope.New(elem.Addr().Interface()).Fields()
} else { } else {
fields = scope.Fields() fields = scope.Fields()
@ -53,7 +57,7 @@ func Preload(scope *Scope) {
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
if equalAsString(getFieldValue(objects.Index(j), primaryName), value) { if equalAsString(getFieldValue(objects.Index(j), primaryName), value) {
objects.Index(j).FieldByName(field.Name).Set(result) reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break break
} }
} }
@ -71,7 +75,7 @@ func Preload(scope *Scope) {
value := getFieldValue(result, relation.ForeignKey) value := getFieldValue(result, relation.ForeignKey)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := objects.Index(j) object := reflect.Indirect(objects.Index(j))
if equalAsString(getFieldValue(object, primaryName), value) { if equalAsString(getFieldValue(object, primaryName), value) {
f := object.FieldByName(field.Name) f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result)) f.Set(reflect.Append(f, result))
@ -91,7 +95,7 @@ func Preload(scope *Scope) {
value := getFieldValue(result, associationPrimaryKey) value := getFieldValue(result, associationPrimaryKey)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := objects.Index(j) object := reflect.Indirect(objects.Index(j))
if equalAsString(getFieldValue(object, relation.ForeignKey), value) { if equalAsString(getFieldValue(object, relation.ForeignKey), value) {
object.FieldByName(field.Name).Set(result) object.FieldByName(field.Name).Set(result)
break break
@ -130,7 +134,7 @@ func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{})
case reflect.Slice: case reflect.Slice:
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
value := values.Index(i) value := values.Index(i)
primaryKeys = append(primaryKeys, value.FieldByName(column).Interface()) primaryKeys = append(primaryKeys, reflect.Indirect(value).FieldByName(column).Interface())
} }
case reflect.Struct: case reflect.Struct:
return []interface{}{values.FieldByName(column).Interface()} return []interface{}{values.FieldByName(column).Interface()}

View File

@ -76,4 +76,12 @@ func TestPreload(t *testing.T) {
for _, user := range users { for _, user := range users {
checkUserHasPreloadData(user, t) checkUserHasPreloadData(user, t)
} }
var users2 []*User
DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Find(&users2)
for _, user := range users2 {
checkUserHasPreloadData(*user, t)
}
} }

View File

@ -101,8 +101,13 @@ func (scope *Scope) PrimaryKeyField() *Field {
var indirectValue = scope.IndirectValue() var indirectValue = scope.IndirectValue()
clone := scope clone := scope
// FIXME
if indirectValue.Kind() == reflect.Slice { if indirectValue.Kind() == reflect.Slice {
clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface()) typ := indirectValue.Type().Elem()
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
clone = scope.New(reflect.New(typ).Elem().Interface())
} }
for _, field := range clone.Fields() { for _, field := range clone.Fields() {