forked from mirror/gorm
Support slice of pointers
This commit is contained in:
parent
a9ac3e10a7
commit
3bd5131132
|
@ -11,6 +11,7 @@ func Query(scope *Scope) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
isSlice bool
|
isSlice bool
|
||||||
|
isPtr bool
|
||||||
anyRecordFound bool
|
anyRecordFound bool
|
||||||
destType reflect.Type
|
destType reflect.Type
|
||||||
)
|
)
|
||||||
|
@ -23,6 +24,10 @@ func Query(scope *Scope) {
|
||||||
if dest.Kind() == reflect.Slice {
|
if dest.Kind() == reflect.Slice {
|
||||||
isSlice = true
|
isSlice = true
|
||||||
destType = dest.Type().Elem()
|
destType = dest.Type().Elem()
|
||||||
|
if destType.Kind() == reflect.Ptr {
|
||||||
|
isPtr = true
|
||||||
|
destType = destType.Elem()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
scope.Search = scope.Search.clone().limit(1)
|
scope.Search = scope.Search.clone().limit(1)
|
||||||
}
|
}
|
||||||
|
@ -58,9 +63,13 @@ func Query(scope *Scope) {
|
||||||
scope.Err(rows.Scan(values...))
|
scope.Err(rows.Scan(values...))
|
||||||
|
|
||||||
if isSlice {
|
if isSlice {
|
||||||
|
if isPtr {
|
||||||
|
dest.Set(reflect.Append(dest, elem.Addr()))
|
||||||
|
} else {
|
||||||
dest.Set(reflect.Append(dest, elem))
|
dest.Set(reflect.Append(dest, elem))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !anyRecordFound {
|
if !anyRecordFound {
|
||||||
scope.Err(RecordNotFound)
|
scope.Err(RecordNotFound)
|
||||||
|
|
12
main_test.go
12
main_test.go
|
@ -269,6 +269,18 @@ func TestFirstAndLast(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFindSliceOfPointers(t *testing.T) {
|
||||||
|
var users []User
|
||||||
|
db.Find(&users)
|
||||||
|
|
||||||
|
var userPointers []*User
|
||||||
|
db.Find(&userPointers)
|
||||||
|
|
||||||
|
if len(users) == 0 || len(users) != len(userPointers) {
|
||||||
|
t.Errorf("Find slice of pointers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFirstAndLastWithJoins(t *testing.T) {
|
func TestFirstAndLastWithJoins(t *testing.T) {
|
||||||
var user1, user2, user3, user4 User
|
var user1, user2, user3, user4 User
|
||||||
db.Joins("left join emails on emails.user_id = users.id").First(&user1)
|
db.Joins("left join emails on emails.user_id = users.id").First(&user1)
|
||||||
|
|
11
scope.go
11
scope.go
|
@ -118,8 +118,13 @@ func (scope *Scope) FieldByName(name string) (interface{}, bool) {
|
||||||
return field.Interface(), true
|
return field.Interface(), true
|
||||||
}
|
}
|
||||||
} else if data.Kind() == reflect.Slice {
|
} else if data.Kind() == reflect.Slice {
|
||||||
|
elem := data.Type().Elem()
|
||||||
|
if elem.Kind() == reflect.Ptr {
|
||||||
|
return nil, reflect.New(data.Type().Elem().Elem()).Elem().FieldByName(name).IsValid()
|
||||||
|
} else {
|
||||||
return nil, reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
|
return nil, reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -190,7 +195,11 @@ func (scope *Scope) TableName() string {
|
||||||
data := reflect.Indirect(reflect.ValueOf(scope.Value))
|
data := reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||||
|
|
||||||
if data.Kind() == reflect.Slice {
|
if data.Kind() == reflect.Slice {
|
||||||
data = reflect.New(data.Type().Elem()).Elem()
|
elem := data.Type().Elem()
|
||||||
|
if elem.Kind() == reflect.Ptr {
|
||||||
|
elem = elem.Elem()
|
||||||
|
}
|
||||||
|
data = reflect.New(elem).Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if fm := data.MethodByName("TableName"); fm.IsValid() {
|
if fm := data.MethodByName("TableName"); fm.IsValid() {
|
||||||
|
|
Loading…
Reference in New Issue