Add indirect method

This commit is contained in:
Jinzhu 2016-01-18 12:20:27 +08:00
commit 896ee534e2
5 changed files with 82 additions and 18 deletions

View File

@ -89,7 +89,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
// assign find results // assign find results
var ( var (
resultsValue = reflect.Indirect(reflect.ValueOf(results)) resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue() indirectScopeValue = scope.IndirectValue()
) )
@ -98,7 +98,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
if indirectScopeValue.Kind() == reflect.Slice { if indirectScopeValue.Kind() == reflect.Slice {
foreignValues := getValueFromFields(result, relation.ForeignFieldNames) foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ { for j := 0; j < indirectScopeValue.Len(); j++ {
if indirectValue := reflect.Indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
indirectValue.FieldByName(field.Name).Set(result) indirectValue.FieldByName(field.Name).Set(result)
break break
} }
@ -125,7 +125,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
// assign find results // assign find results
var ( var (
resultsValue = reflect.Indirect(reflect.ValueOf(results)) resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue() indirectScopeValue = scope.IndirectValue()
) )
@ -134,7 +134,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
result := resultsValue.Index(i) result := resultsValue.Index(i)
foreignValues := getValueFromFields(result, relation.ForeignFieldNames) foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ { for j := 0; j < indirectScopeValue.Len(); j++ {
object := reflect.Indirect(indirectScopeValue.Index(j)) object := indirect(indirectScopeValue.Index(j))
if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) { if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) {
objectField := object.FieldByName(field.Name) objectField := object.FieldByName(field.Name)
objectField.Set(reflect.Append(objectField, result)) objectField.Set(reflect.Append(objectField, result))
@ -163,7 +163,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
// assign find results // assign find results
var ( var (
resultsValue = reflect.Indirect(reflect.ValueOf(results)) resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue() indirectScopeValue = scope.IndirectValue()
) )
@ -172,7 +172,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
if indirectScopeValue.Kind() == reflect.Slice { if indirectScopeValue.Kind() == reflect.Slice {
value := getValueFromFields(result, relation.AssociationForeignFieldNames) value := getValueFromFields(result, relation.AssociationForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ { for j := 0; j < indirectScopeValue.Len(); j++ {
object := reflect.Indirect(indirectScopeValue.Index(j)) object := indirect(indirectScopeValue.Index(j))
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
object.FieldByName(field.Name).Set(result) object.FieldByName(field.Name).Set(result)
} }
@ -265,7 +265,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
if indirectScopeValue.Kind() == reflect.Slice { if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ { for j := 0; j < indirectScopeValue.Len(); j++ {
object := reflect.Indirect(indirectScopeValue.Index(j)) object := indirect(indirectScopeValue.Index(j))
fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name) fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name)
} }
} else if indirectScopeValue.IsValid() { } else if indirectScopeValue.IsValid() {

View File

@ -611,6 +611,70 @@ func TestNestedPreload9(t *testing.T) {
} }
} }
type Level1A struct {
ID uint
Value string
}
type Level1B struct {
ID uint
Value string
Level2s []*Level2
}
type Level2 struct {
ID uint
Value string
Level1AID sql.NullInt64
Level1A *Level1A
Level1BID sql.NullInt64
Level1B *Level1B
}
func TestNestedPreload10(t *testing.T) {
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1B{})
DB.DropTableIfExists(&Level1A{})
if err := DB.AutoMigrate(&Level1A{}, &Level1B{}, &Level2{}).Error; err != nil {
t.Error(err)
}
level1A := &Level1A{Value: "foo"}
if err := DB.Save(&level1A).Error; err != nil {
t.Error(err)
}
want := []*Level1B{
&Level1B{
Value: "bar",
Level2s: []*Level2{
&Level2{
Value: "qux",
Level1A: level1A,
},
},
},
&Level1B{
Value: "bar 2",
},
}
for _, level1B := range want {
if err := DB.Save(level1B).Error; err != nil {
t.Error(err)
}
}
var got []*Level1B
if err := DB.Preload("Level2s.Level1A").Find(&got).Error; err != nil {
t.Error(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
return return

View File

@ -16,7 +16,6 @@ type Scope struct {
Sql string Sql string
SqlVars []interface{} SqlVars []interface{}
db *DB db *DB
indirectValue *reflect.Value
instanceID string instanceID string
primaryKeyField *Field primaryKeyField *Field
skipLeft bool skipLeft bool
@ -25,14 +24,7 @@ type Scope struct {
} }
func (scope *Scope) IndirectValue() reflect.Value { func (scope *Scope) IndirectValue() reflect.Value {
if scope.indirectValue == nil { return indirect(reflect.ValueOf(scope.Value))
value := reflect.Indirect(reflect.ValueOf(scope.Value))
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
scope.indirectValue = &value
}
return *scope.indirectValue
} }
// New create a new Scope without search information // New create a new Scope without search information

View File

@ -13,7 +13,7 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r
case reflect.Slice: case reflect.Slice:
for i := 0; i < indirectValue.Len(); i++ { for i := 0; i < indirectValue.Len(); i++ {
var result []interface{} var result []interface{}
var object = reflect.Indirect(indirectValue.Index(i)) var object = indirect(indirectValue.Index(i))
for _, column := range columns { for _, column := range columns {
result = append(result, object.FieldByName(column).Interface()) result = append(result, object.FieldByName(column).Interface())
} }
@ -44,7 +44,7 @@ func (scope *Scope) getColumnAsScope(column string) *Scope {
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
for i := 0; i < indirectScopeValue.Len(); i++ { for i := 0; i < indirectScopeValue.Len(); i++ {
result := reflect.Indirect(reflect.Indirect(indirectScopeValue.Index(i)).FieldByName(column)) result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
if result.Kind() == reflect.Slice { if result.Kind() == reflect.Slice {
for j := 0; j < result.Len(); j++ { for j := 0; j < result.Len(); j++ {

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"reflect"
"strings" "strings"
"sync" "sync"
) )
@ -102,6 +103,13 @@ func Expr(expression string, args ...interface{}) *expr {
return &expr{expr: expression, args: args} return &expr{expr: expression, args: args}
} }
func indirect(reflectValue reflect.Value) reflect.Value {
for reflectValue.Kind() == reflect.Ptr {
reflectValue = reflectValue.Elem()
}
return reflectValue
}
func toQueryMarks(primaryValues [][]interface{}) string { func toQueryMarks(primaryValues [][]interface{}) string {
var results []string var results []string