mirror of https://github.com/go-gorm/gorm.git
Add indirect method
This commit is contained in:
commit
896ee534e2
|
@ -89,7 +89,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
|
||||||
|
|
||||||
// assign find results
|
// assign find results
|
||||||
var (
|
var (
|
||||||
resultsValue = reflect.Indirect(reflect.ValueOf(results))
|
resultsValue = indirect(reflect.ValueOf(results))
|
||||||
indirectScopeValue = scope.IndirectValue()
|
indirectScopeValue = scope.IndirectValue()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
if indirectValue := reflect.Indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
||||||
indirectValue.FieldByName(field.Name).Set(result)
|
indirectValue.FieldByName(field.Name).Set(result)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -125,7 +125,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
|
||||||
|
|
||||||
// assign find results
|
// assign find results
|
||||||
var (
|
var (
|
||||||
resultsValue = reflect.Indirect(reflect.ValueOf(results))
|
resultsValue = indirect(reflect.ValueOf(results))
|
||||||
indirectScopeValue = scope.IndirectValue()
|
indirectScopeValue = scope.IndirectValue()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
|
||||||
result := resultsValue.Index(i)
|
result := resultsValue.Index(i)
|
||||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
object := reflect.Indirect(indirectScopeValue.Index(j))
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) {
|
if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) {
|
||||||
objectField := object.FieldByName(field.Name)
|
objectField := object.FieldByName(field.Name)
|
||||||
objectField.Set(reflect.Append(objectField, result))
|
objectField.Set(reflect.Append(objectField, result))
|
||||||
|
@ -163,7 +163,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
||||||
|
|
||||||
// assign find results
|
// assign find results
|
||||||
var (
|
var (
|
||||||
resultsValue = reflect.Indirect(reflect.ValueOf(results))
|
resultsValue = indirect(reflect.ValueOf(results))
|
||||||
indirectScopeValue = scope.IndirectValue()
|
indirectScopeValue = scope.IndirectValue()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
object := reflect.Indirect(indirectScopeValue.Index(j))
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
||||||
object.FieldByName(field.Name).Set(result)
|
object.FieldByName(field.Name).Set(result)
|
||||||
}
|
}
|
||||||
|
@ -265,7 +265,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
||||||
|
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
object := reflect.Indirect(indirectScopeValue.Index(j))
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name)
|
fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name)
|
||||||
}
|
}
|
||||||
} else if indirectScopeValue.IsValid() {
|
} else if indirectScopeValue.IsValid() {
|
||||||
|
|
|
@ -611,6 +611,70 @@ func TestNestedPreload9(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Level1A struct {
|
||||||
|
ID uint
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Level1B struct {
|
||||||
|
ID uint
|
||||||
|
Value string
|
||||||
|
Level2s []*Level2
|
||||||
|
}
|
||||||
|
|
||||||
|
type Level2 struct {
|
||||||
|
ID uint
|
||||||
|
Value string
|
||||||
|
Level1AID sql.NullInt64
|
||||||
|
Level1A *Level1A
|
||||||
|
Level1BID sql.NullInt64
|
||||||
|
Level1B *Level1B
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNestedPreload10(t *testing.T) {
|
||||||
|
DB.DropTableIfExists(&Level2{})
|
||||||
|
DB.DropTableIfExists(&Level1B{})
|
||||||
|
DB.DropTableIfExists(&Level1A{})
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(&Level1A{}, &Level1B{}, &Level2{}).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
level1A := &Level1A{Value: "foo"}
|
||||||
|
if err := DB.Save(&level1A).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []*Level1B{
|
||||||
|
&Level1B{
|
||||||
|
Value: "bar",
|
||||||
|
Level2s: []*Level2{
|
||||||
|
&Level2{
|
||||||
|
Value: "qux",
|
||||||
|
Level1A: level1A,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&Level1B{
|
||||||
|
Value: "bar 2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, level1B := range want {
|
||||||
|
if err := DB.Save(level1B).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []*Level1B
|
||||||
|
if err := DB.Preload("Level2s.Level1A").Find(&got).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
|
func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
|
||||||
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
|
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
|
||||||
return
|
return
|
||||||
|
|
10
scope.go
10
scope.go
|
@ -16,7 +16,6 @@ type Scope struct {
|
||||||
Sql string
|
Sql string
|
||||||
SqlVars []interface{}
|
SqlVars []interface{}
|
||||||
db *DB
|
db *DB
|
||||||
indirectValue *reflect.Value
|
|
||||||
instanceID string
|
instanceID string
|
||||||
primaryKeyField *Field
|
primaryKeyField *Field
|
||||||
skipLeft bool
|
skipLeft bool
|
||||||
|
@ -25,14 +24,7 @@ type Scope struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) IndirectValue() reflect.Value {
|
func (scope *Scope) IndirectValue() reflect.Value {
|
||||||
if scope.indirectValue == nil {
|
return indirect(reflect.ValueOf(scope.Value))
|
||||||
value := reflect.Indirect(reflect.ValueOf(scope.Value))
|
|
||||||
if value.Kind() == reflect.Ptr {
|
|
||||||
value = value.Elem()
|
|
||||||
}
|
|
||||||
scope.indirectValue = &value
|
|
||||||
}
|
|
||||||
return *scope.indirectValue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New create a new Scope without search information
|
// New create a new Scope without search information
|
||||||
|
|
|
@ -13,7 +13,7 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
for i := 0; i < indirectValue.Len(); i++ {
|
for i := 0; i < indirectValue.Len(); i++ {
|
||||||
var result []interface{}
|
var result []interface{}
|
||||||
var object = reflect.Indirect(indirectValue.Index(i))
|
var object = indirect(indirectValue.Index(i))
|
||||||
for _, column := range columns {
|
for _, column := range columns {
|
||||||
result = append(result, object.FieldByName(column).Interface())
|
result = append(result, object.FieldByName(column).Interface())
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,7 @@ func (scope *Scope) getColumnAsScope(column string) *Scope {
|
||||||
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
|
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
|
||||||
|
|
||||||
for i := 0; i < indirectScopeValue.Len(); i++ {
|
for i := 0; i < indirectScopeValue.Len(); i++ {
|
||||||
result := reflect.Indirect(reflect.Indirect(indirectScopeValue.Index(i)).FieldByName(column))
|
result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
|
||||||
|
|
||||||
if result.Kind() == reflect.Slice {
|
if result.Kind() == reflect.Slice {
|
||||||
for j := 0; j < result.Len(); j++ {
|
for j := 0; j < result.Len(); j++ {
|
||||||
|
|
8
utils.go
8
utils.go
|
@ -3,6 +3,7 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
@ -102,6 +103,13 @@ func Expr(expression string, args ...interface{}) *expr {
|
||||||
return &expr{expr: expression, args: args}
|
return &expr{expr: expression, args: args}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func indirect(reflectValue reflect.Value) reflect.Value {
|
||||||
|
for reflectValue.Kind() == reflect.Ptr {
|
||||||
|
reflectValue = reflectValue.Elem()
|
||||||
|
}
|
||||||
|
return reflectValue
|
||||||
|
}
|
||||||
|
|
||||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||||
var results []string
|
var results []string
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue