From 3bd51311322cac785aba691a05b841696020a302 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 8 Jul 2014 10:45:31 +0800 Subject: [PATCH] Support slice of pointers --- callback_query.go | 11 ++++++++++- main_test.go | 12 ++++++++++++ scope.go | 13 +++++++++++-- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/callback_query.go b/callback_query.go index 61db20e5..b86f41f0 100644 --- a/callback_query.go +++ b/callback_query.go @@ -11,6 +11,7 @@ func Query(scope *Scope) { var ( isSlice bool + isPtr bool anyRecordFound bool destType reflect.Type ) @@ -23,6 +24,10 @@ func Query(scope *Scope) { if dest.Kind() == reflect.Slice { isSlice = true destType = dest.Type().Elem() + if destType.Kind() == reflect.Ptr { + isPtr = true + destType = destType.Elem() + } } else { scope.Search = scope.Search.clone().limit(1) } @@ -58,7 +63,11 @@ func Query(scope *Scope) { scope.Err(rows.Scan(values...)) if isSlice { - dest.Set(reflect.Append(dest, elem)) + if isPtr { + dest.Set(reflect.Append(dest, elem.Addr())) + } else { + dest.Set(reflect.Append(dest, elem)) + } } } diff --git a/main_test.go b/main_test.go index 6616b992..24e01181 100644 --- a/main_test.go +++ b/main_test.go @@ -269,6 +269,18 @@ func TestFirstAndLast(t *testing.T) { } } +func TestFindSliceOfPointers(t *testing.T) { + var users []User + db.Find(&users) + + var userPointers []*User + db.Find(&userPointers) + + if len(users) == 0 || len(users) != len(userPointers) { + t.Errorf("Find slice of pointers") + } +} + func TestFirstAndLastWithJoins(t *testing.T) { var user1, user2, user3, user4 User db.Joins("left join emails on emails.user_id = users.id").First(&user1) diff --git a/scope.go b/scope.go index 8571fef5..dd8c3acf 100644 --- a/scope.go +++ b/scope.go @@ -118,7 +118,12 @@ func (scope *Scope) FieldByName(name string) (interface{}, bool) { return field.Interface(), true } } else if data.Kind() == reflect.Slice { - return nil, reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() + elem := data.Type().Elem() + if elem.Kind() == reflect.Ptr { + return nil, reflect.New(data.Type().Elem().Elem()).Elem().FieldByName(name).IsValid() + } else { + return nil, reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() + } } return nil, false } @@ -190,7 +195,11 @@ func (scope *Scope) TableName() string { data := reflect.Indirect(reflect.ValueOf(scope.Value)) if data.Kind() == reflect.Slice { - data = reflect.New(data.Type().Elem()).Elem() + elem := data.Type().Elem() + if elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + data = reflect.New(elem).Elem() } if fm := data.MethodByName("TableName"); fm.IsValid() {