Support slice of pointers

This commit is contained in:
Jinzhu 2014-07-08 10:45:31 +08:00
parent a9ac3e10a7
commit 3bd5131132
3 changed files with 33 additions and 3 deletions

View File

@ -11,6 +11,7 @@ func Query(scope *Scope) {
var ( var (
isSlice bool isSlice bool
isPtr bool
anyRecordFound bool anyRecordFound bool
destType reflect.Type destType reflect.Type
) )
@ -23,6 +24,10 @@ func Query(scope *Scope) {
if dest.Kind() == reflect.Slice { if dest.Kind() == reflect.Slice {
isSlice = true isSlice = true
destType = dest.Type().Elem() destType = dest.Type().Elem()
if destType.Kind() == reflect.Ptr {
isPtr = true
destType = destType.Elem()
}
} else { } else {
scope.Search = scope.Search.clone().limit(1) scope.Search = scope.Search.clone().limit(1)
} }
@ -58,7 +63,11 @@ func Query(scope *Scope) {
scope.Err(rows.Scan(values...)) scope.Err(rows.Scan(values...))
if isSlice { if isSlice {
dest.Set(reflect.Append(dest, elem)) if isPtr {
dest.Set(reflect.Append(dest, elem.Addr()))
} else {
dest.Set(reflect.Append(dest, elem))
}
} }
} }

View File

@ -269,6 +269,18 @@ func TestFirstAndLast(t *testing.T) {
} }
} }
func TestFindSliceOfPointers(t *testing.T) {
var users []User
db.Find(&users)
var userPointers []*User
db.Find(&userPointers)
if len(users) == 0 || len(users) != len(userPointers) {
t.Errorf("Find slice of pointers")
}
}
func TestFirstAndLastWithJoins(t *testing.T) { func TestFirstAndLastWithJoins(t *testing.T) {
var user1, user2, user3, user4 User var user1, user2, user3, user4 User
db.Joins("left join emails on emails.user_id = users.id").First(&user1) db.Joins("left join emails on emails.user_id = users.id").First(&user1)

View File

@ -118,7 +118,12 @@ func (scope *Scope) FieldByName(name string) (interface{}, bool) {
return field.Interface(), true return field.Interface(), true
} }
} else if data.Kind() == reflect.Slice { } else if data.Kind() == reflect.Slice {
return nil, reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() elem := data.Type().Elem()
if elem.Kind() == reflect.Ptr {
return nil, reflect.New(data.Type().Elem().Elem()).Elem().FieldByName(name).IsValid()
} else {
return nil, reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
}
} }
return nil, false return nil, false
} }
@ -190,7 +195,11 @@ func (scope *Scope) TableName() string {
data := reflect.Indirect(reflect.ValueOf(scope.Value)) data := reflect.Indirect(reflect.ValueOf(scope.Value))
if data.Kind() == reflect.Slice { if data.Kind() == reflect.Slice {
data = reflect.New(data.Type().Elem()).Elem() elem := data.Type().Elem()
if elem.Kind() == reflect.Ptr {
elem = elem.Elem()
}
data = reflect.New(elem).Elem()
} }
if fm := data.MethodByName("TableName"); fm.IsValid() { if fm := data.MethodByName("TableName"); fm.IsValid() {