mirror of https://github.com/go-gorm/gorm.git
Fix Preload with slice of pointer
This commit is contained in:
parent
f97e2c088e
commit
aac52fdcf8
16
preload.go
16
preload.go
|
@ -8,7 +8,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func getFieldValue(value reflect.Value, field string) interface{} {
|
func getFieldValue(value reflect.Value, field string) interface{} {
|
||||||
result := value.FieldByName(field).Interface()
|
result := reflect.Indirect(value).FieldByName(field).Interface()
|
||||||
if r, ok := result.(driver.Valuer); ok {
|
if r, ok := result.(driver.Valuer); ok {
|
||||||
result, _ = r.Value()
|
result, _ = r.Value()
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,11 @@ func Preload(scope *Scope) {
|
||||||
var isSlice bool
|
var isSlice bool
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||||
isSlice = true
|
isSlice = true
|
||||||
elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem()
|
typ := scope.IndirectValue().Type().Elem()
|
||||||
|
if typ.Kind() == reflect.Ptr {
|
||||||
|
typ = typ.Elem()
|
||||||
|
}
|
||||||
|
elem := reflect.New(typ).Elem()
|
||||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||||
} else {
|
} else {
|
||||||
fields = scope.Fields()
|
fields = scope.Fields()
|
||||||
|
@ -53,7 +57,7 @@ func Preload(scope *Scope) {
|
||||||
objects := scope.IndirectValue()
|
objects := scope.IndirectValue()
|
||||||
for j := 0; j < objects.Len(); j++ {
|
for j := 0; j < objects.Len(); j++ {
|
||||||
if equalAsString(getFieldValue(objects.Index(j), primaryName), value) {
|
if equalAsString(getFieldValue(objects.Index(j), primaryName), value) {
|
||||||
objects.Index(j).FieldByName(field.Name).Set(result)
|
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -71,7 +75,7 @@ func Preload(scope *Scope) {
|
||||||
value := getFieldValue(result, relation.ForeignKey)
|
value := getFieldValue(result, relation.ForeignKey)
|
||||||
objects := scope.IndirectValue()
|
objects := scope.IndirectValue()
|
||||||
for j := 0; j < objects.Len(); j++ {
|
for j := 0; j < objects.Len(); j++ {
|
||||||
object := objects.Index(j)
|
object := reflect.Indirect(objects.Index(j))
|
||||||
if equalAsString(getFieldValue(object, primaryName), value) {
|
if equalAsString(getFieldValue(object, primaryName), value) {
|
||||||
f := object.FieldByName(field.Name)
|
f := object.FieldByName(field.Name)
|
||||||
f.Set(reflect.Append(f, result))
|
f.Set(reflect.Append(f, result))
|
||||||
|
@ -91,7 +95,7 @@ func Preload(scope *Scope) {
|
||||||
value := getFieldValue(result, associationPrimaryKey)
|
value := getFieldValue(result, associationPrimaryKey)
|
||||||
objects := scope.IndirectValue()
|
objects := scope.IndirectValue()
|
||||||
for j := 0; j < objects.Len(); j++ {
|
for j := 0; j < objects.Len(); j++ {
|
||||||
object := objects.Index(j)
|
object := reflect.Indirect(objects.Index(j))
|
||||||
if equalAsString(getFieldValue(object, relation.ForeignKey), value) {
|
if equalAsString(getFieldValue(object, relation.ForeignKey), value) {
|
||||||
object.FieldByName(field.Name).Set(result)
|
object.FieldByName(field.Name).Set(result)
|
||||||
break
|
break
|
||||||
|
@ -130,7 +134,7 @@ func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{})
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
value := values.Index(i)
|
value := values.Index(i)
|
||||||
primaryKeys = append(primaryKeys, value.FieldByName(column).Interface())
|
primaryKeys = append(primaryKeys, reflect.Indirect(value).FieldByName(column).Interface())
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
return []interface{}{values.FieldByName(column).Interface()}
|
return []interface{}{values.FieldByName(column).Interface()}
|
||||||
|
|
|
@ -76,4 +76,12 @@ func TestPreload(t *testing.T) {
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
checkUserHasPreloadData(user, t)
|
checkUserHasPreloadData(user, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var users2 []*User
|
||||||
|
DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
|
Preload("CreditCard").Preload("Emails").Find(&users2)
|
||||||
|
|
||||||
|
for _, user := range users2 {
|
||||||
|
checkUserHasPreloadData(*user, t)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
7
scope.go
7
scope.go
|
@ -101,8 +101,13 @@ func (scope *Scope) PrimaryKeyField() *Field {
|
||||||
var indirectValue = scope.IndirectValue()
|
var indirectValue = scope.IndirectValue()
|
||||||
|
|
||||||
clone := scope
|
clone := scope
|
||||||
|
// FIXME
|
||||||
if indirectValue.Kind() == reflect.Slice {
|
if indirectValue.Kind() == reflect.Slice {
|
||||||
clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface())
|
typ := indirectValue.Type().Elem()
|
||||||
|
if typ.Kind() == reflect.Ptr {
|
||||||
|
typ = typ.Elem()
|
||||||
|
}
|
||||||
|
clone = scope.New(reflect.New(typ).Elem().Interface())
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range clone.Fields() {
|
for _, field := range clone.Fields() {
|
||||||
|
|
Loading…
Reference in New Issue