Support slice of pointers

This commit is contained in:
Jinzhu 2014-07-08 10:45:31 +08:00
parent a9ac3e10a7
commit 3bd5131132
3 changed files with 33 additions and 3 deletions

View File

@ -11,6 +11,7 @@ func Query(scope *Scope) {
var (
isSlice bool
isPtr bool
anyRecordFound bool
destType reflect.Type
)
@ -23,6 +24,10 @@ func Query(scope *Scope) {
if dest.Kind() == reflect.Slice {
isSlice = true
destType = dest.Type().Elem()
if destType.Kind() == reflect.Ptr {
isPtr = true
destType = destType.Elem()
}
} else {
scope.Search = scope.Search.clone().limit(1)
}
@ -58,7 +63,11 @@ func Query(scope *Scope) {
scope.Err(rows.Scan(values...))
if isSlice {
dest.Set(reflect.Append(dest, elem))
if isPtr {
dest.Set(reflect.Append(dest, elem.Addr()))
} else {
dest.Set(reflect.Append(dest, elem))
}
}
}

View File

@ -269,6 +269,18 @@ func TestFirstAndLast(t *testing.T) {
}
}
func TestFindSliceOfPointers(t *testing.T) {
var users []User
db.Find(&users)
var userPointers []*User
db.Find(&userPointers)
if len(users) == 0 || len(users) != len(userPointers) {
t.Errorf("Find slice of pointers")
}
}
func TestFirstAndLastWithJoins(t *testing.T) {
var user1, user2, user3, user4 User
db.Joins("left join emails on emails.user_id = users.id").First(&user1)

View File

@ -118,7 +118,12 @@ func (scope *Scope) FieldByName(name string) (interface{}, bool) {
return field.Interface(), true
}
} else if data.Kind() == reflect.Slice {
return nil, reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
elem := data.Type().Elem()
if elem.Kind() == reflect.Ptr {
return nil, reflect.New(data.Type().Elem().Elem()).Elem().FieldByName(name).IsValid()
} else {
return nil, reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
}
}
return nil, false
}
@ -190,7 +195,11 @@ func (scope *Scope) TableName() string {
data := reflect.Indirect(reflect.ValueOf(scope.Value))
if data.Kind() == reflect.Slice {
data = reflect.New(data.Type().Elem()).Elem()
elem := data.Type().Elem()
if elem.Kind() == reflect.Ptr {
elem = elem.Elem()
}
data = reflect.New(elem).Elem()
}
if fm := data.MethodByName("TableName"); fm.IsValid() {