diff --git a/finisher_api.go b/finisher_api.go index 49b08fa4..334aea58 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -273,6 +273,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { // db.Find(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.AddClause(clause.Select{Columns: []clause.Column{{Name: column}}}) tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return diff --git a/scan.go b/scan.go index 66cb0b94..4d328fde 100644 --- a/scan.go +++ b/scan.go @@ -58,7 +58,12 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { default: switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr + reflectValueType := db.Statement.ReflectValue.Type().Elem() + isPtr := reflectValueType.Kind() == reflect.Ptr + if isPtr { + reflectValueType = reflectValueType.Elem() + } + db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) fields := make([]*schema.Field, len(columns)) joinFields := make([][2]*schema.Field, len(columns)) @@ -81,17 +86,22 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { for initialized || rows.Next() { initialized = false - elem := reflect.New(db.Statement.Schema.ModelType).Elem() - for idx, field := range fields { - if field != nil { - values[idx] = field.ReflectValueOf(elem).Addr().Interface() - } else if joinFields[idx][0] != nil { - relValue := joinFields[idx][0].ReflectValueOf(elem) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - relValue.Set(reflect.New(relValue.Type().Elem())) - } + elem := reflect.New(reflectValueType).Elem() - values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() + if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { + values[0] = elem.Addr().Interface() + } else { + for idx, field := range fields { + if field != nil { + values[idx] = field.ReflectValueOf(elem).Addr().Interface() + } else if joinFields[idx][0] != nil { + relValue := joinFields[idx][0].ReflectValueOf(elem) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() + } } } diff --git a/tests/query_test.go b/tests/query_test.go index 4388066f..b7c619d7 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -80,3 +80,35 @@ func TestFind(t *testing.T) { } } } + +func TestPluck(t *testing.T) { + users := []*User{ + GetUser("pluck-user1", Config{}), + GetUser("pluck-user2", Config{}), + GetUser("pluck-user3", Config{}), + } + + DB.Create(&users) + + var names []string + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("name", &names).Error; err != nil { + t.Errorf("Raise error when pluck name, got %v", err) + } + + var ids []int + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("id", &ids).Error; err != nil { + t.Errorf("Raise error when pluck id, got %v", err) + } + + for idx, name := range names { + if name != users[idx].Name { + t.Errorf("Unexpected result on pluck name, got %+v", names) + } + } + + for idx, id := range ids { + if int(id) != int(users[idx].ID) { + t.Errorf("Unexpected result on pluck id, got %+v", ids) + } + } +}