mirror of https://github.com/go-redis/redis.git
Merge pull request #1631 from knadh/feat-decoder
Add redis.Scan() to scan results from redis maps into structs.
This commit is contained in:
commit
35f6ccd0bb
42
command.go
42
command.go
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8/internal"
|
||||
"github.com/go-redis/redis/v8/internal/hscan"
|
||||
"github.com/go-redis/redis/v8/internal/proto"
|
||||
"github.com/go-redis/redis/v8/internal/util"
|
||||
)
|
||||
|
@ -371,6 +372,26 @@ func (cmd *SliceCmd) String() string {
|
|||
return cmdString(cmd, cmd.val)
|
||||
}
|
||||
|
||||
// Scan scans the results from the map into a destination struct. The map keys
|
||||
// are matched in the Redis struct fields by the `redis:"field"` tag.
|
||||
func (cmd *SliceCmd) Scan(dst interface{}) error {
|
||||
if cmd.err != nil {
|
||||
return cmd.err
|
||||
}
|
||||
|
||||
// Pass the list of keys and values.
|
||||
// Skip the first two args for: HMGET key
|
||||
var args []interface{}
|
||||
if cmd.args[0] == "hmget" {
|
||||
args = cmd.args[2:]
|
||||
} else {
|
||||
// Otherwise, it's: MGET field field ...
|
||||
args = cmd.args[1:]
|
||||
}
|
||||
|
||||
return hscan.Scan(dst, args, cmd.val)
|
||||
}
|
||||
|
||||
func (cmd *SliceCmd) readReply(rd *proto.Reader) error {
|
||||
v, err := rd.ReadArrayReply(sliceParser)
|
||||
if err != nil {
|
||||
|
@ -917,6 +938,27 @@ func (cmd *StringStringMapCmd) String() string {
|
|||
return cmdString(cmd, cmd.val)
|
||||
}
|
||||
|
||||
// Scan scans the results from the map into a destination struct. The map keys
|
||||
// are matched in the Redis struct fields by the `redis:"field"` tag.
|
||||
func (cmd *StringStringMapCmd) Scan(dst interface{}) error {
|
||||
if cmd.err != nil {
|
||||
return cmd.err
|
||||
}
|
||||
|
||||
strct, err := hscan.Struct(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for k, v := range cmd.val {
|
||||
if err := strct.Scan(k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error {
|
||||
_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
|
||||
cmd.val = make(map[string]string, n/2)
|
||||
|
|
|
@ -1134,6 +1134,22 @@ var _ = Describe("Commands", func() {
|
|||
Expect(mGet.Val()).To(Equal([]interface{}{"hello1", "hello2", nil}))
|
||||
})
|
||||
|
||||
It("should scan Mget", func() {
|
||||
err := client.MSet(ctx, "key1", "hello1", "key2", 123).Err()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
res := client.MGet(ctx, "key1", "key2", "_")
|
||||
Expect(res.Err()).NotTo(HaveOccurred())
|
||||
|
||||
type data struct {
|
||||
Key1 string `redis:"key1"`
|
||||
Key2 int `redis:"key2"`
|
||||
}
|
||||
var d data
|
||||
Expect(res.Scan(&d)).NotTo(HaveOccurred())
|
||||
Expect(d).To(Equal(data{Key1: "hello1", Key2: 123}))
|
||||
})
|
||||
|
||||
It("should MSetNX", func() {
|
||||
mSetNX := client.MSetNX(ctx, "key1", "hello1", "key2", "hello2")
|
||||
Expect(mSetNX.Err()).NotTo(HaveOccurred())
|
||||
|
@ -1375,6 +1391,22 @@ var _ = Describe("Commands", func() {
|
|||
Expect(m).To(Equal(map[string]string{"key1": "hello1", "key2": "hello2"}))
|
||||
})
|
||||
|
||||
It("should scan", func() {
|
||||
err := client.HMSet(ctx, "hash", "key1", "hello1", "key2", 123).Err()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
res := client.HGetAll(ctx, "hash")
|
||||
Expect(res.Err()).NotTo(HaveOccurred())
|
||||
|
||||
type data struct {
|
||||
Key1 string `redis:"key1"`
|
||||
Key2 int `redis:"key2"`
|
||||
}
|
||||
var d data
|
||||
Expect(res.Scan(&d)).NotTo(HaveOccurred())
|
||||
Expect(d).To(Equal(data{Key1: "hello1", Key2: 123}))
|
||||
})
|
||||
|
||||
It("should HIncrBy", func() {
|
||||
hSet := client.HSet(ctx, "hash", "key", "5")
|
||||
Expect(hSet.Err()).NotTo(HaveOccurred())
|
||||
|
|
|
@ -276,6 +276,73 @@ func ExampleClient_ScanType() {
|
|||
// Output: found 33 keys
|
||||
}
|
||||
|
||||
// ExampleStringStringMapCmd_Scan shows how to scan the results of a map fetch
|
||||
// into a struct.
|
||||
func ExampleStringStringMapCmd_Scan() {
|
||||
rdb.FlushDB(ctx)
|
||||
err := rdb.HMSet(ctx, "map",
|
||||
"name", "hello",
|
||||
"count", 123,
|
||||
"correct", true).Err()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Get the map. The same approach works for HmGet().
|
||||
res := rdb.HGetAll(ctx, "map")
|
||||
if res.Err() != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
type data struct {
|
||||
Name string `redis:"name"`
|
||||
Count int `redis:"count"`
|
||||
Correct bool `redis:"correct"`
|
||||
}
|
||||
|
||||
// Scan the results into the struct.
|
||||
var d data
|
||||
if err := res.Scan(&d); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println(d)
|
||||
// Output: {hello 123 true}
|
||||
}
|
||||
|
||||
// ExampleSliceCmd_Scan shows how to scan the results of a multi key fetch
|
||||
// into a struct.
|
||||
func ExampleSliceCmd_Scan() {
|
||||
rdb.FlushDB(ctx)
|
||||
err := rdb.MSet(ctx,
|
||||
"name", "hello",
|
||||
"count", 123,
|
||||
"correct", true).Err()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
res := rdb.MGet(ctx, "name", "count", "empty", "correct")
|
||||
if res.Err() != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
type data struct {
|
||||
Name string `redis:"name"`
|
||||
Count int `redis:"count"`
|
||||
Correct bool `redis:"correct"`
|
||||
}
|
||||
|
||||
// Scan the results into the struct.
|
||||
var d data
|
||||
if err := res.Scan(&d); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println(d)
|
||||
// Output: {hello 123 true}
|
||||
}
|
||||
|
||||
func ExampleClient_Pipelined() {
|
||||
var incr *redis.IntCmd
|
||||
_, err := rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
package hscan
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// decoderFunc represents decoding functions for default built-in types.
|
||||
type decoderFunc func(reflect.Value, string) error
|
||||
|
||||
var (
|
||||
// List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1).
|
||||
decoders = []decoderFunc{
|
||||
reflect.Bool: decodeBool,
|
||||
reflect.Int: decodeInt,
|
||||
reflect.Int8: decodeInt,
|
||||
reflect.Int16: decodeInt,
|
||||
reflect.Int32: decodeInt,
|
||||
reflect.Int64: decodeInt,
|
||||
reflect.Uint: decodeUint,
|
||||
reflect.Uint8: decodeUint,
|
||||
reflect.Uint16: decodeUint,
|
||||
reflect.Uint32: decodeUint,
|
||||
reflect.Uint64: decodeUint,
|
||||
reflect.Float32: decodeFloat,
|
||||
reflect.Float64: decodeFloat,
|
||||
reflect.Complex64: decodeUnsupported,
|
||||
reflect.Complex128: decodeUnsupported,
|
||||
reflect.Array: decodeUnsupported,
|
||||
reflect.Chan: decodeUnsupported,
|
||||
reflect.Func: decodeUnsupported,
|
||||
reflect.Interface: decodeUnsupported,
|
||||
reflect.Map: decodeUnsupported,
|
||||
reflect.Ptr: decodeUnsupported,
|
||||
reflect.Slice: decodeSlice,
|
||||
reflect.String: decodeString,
|
||||
reflect.Struct: decodeUnsupported,
|
||||
reflect.UnsafePointer: decodeUnsupported,
|
||||
}
|
||||
|
||||
// Global map of struct field specs that is populated once for every new
|
||||
// struct type that is scanned. This caches the field types and the corresponding
|
||||
// decoder functions to avoid iterating through struct fields on subsequent scans.
|
||||
globalStructMap = newStructMap()
|
||||
)
|
||||
|
||||
func Struct(dst interface{}) (StructValue, error) {
|
||||
v := reflect.ValueOf(dst)
|
||||
|
||||
// The dstination to scan into should be a struct pointer.
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return StructValue{}, fmt.Errorf("redis.Scan(non-pointer %T)", dst)
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return StructValue{}, fmt.Errorf("redis.Scan(non-struct %T)", dst)
|
||||
}
|
||||
|
||||
return StructValue{
|
||||
spec: globalStructMap.get(v.Type()),
|
||||
value: v,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Scan scans the results from a key-value Redis map result set to a destination struct.
|
||||
// The Redis keys are matched to the struct's field with the `redis` tag.
|
||||
func Scan(dst interface{}, keys []interface{}, vals []interface{}) error {
|
||||
if len(keys) != len(vals) {
|
||||
return errors.New("args should have the same number of keys and vals")
|
||||
}
|
||||
|
||||
strct, err := Struct(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Iterate through the (key, value) sequence.
|
||||
for i := 0; i < len(vals); i++ {
|
||||
key, ok := keys[i].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
val, ok := vals[i].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := strct.Scan(key, val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeBool(f reflect.Value, s string) error {
|
||||
b, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.SetBool(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeInt(f reflect.Value, s string) error {
|
||||
v, err := strconv.ParseInt(s, 10, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.SetInt(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeUint(f reflect.Value, s string) error {
|
||||
v, err := strconv.ParseUint(s, 10, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.SetUint(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeFloat(f reflect.Value, s string) error {
|
||||
v, err := strconv.ParseFloat(s, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.SetFloat(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeString(f reflect.Value, s string) error {
|
||||
f.SetString(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeSlice(f reflect.Value, s string) error {
|
||||
// []byte slice ([]uint8).
|
||||
if f.Type().Elem().Kind() == reflect.Uint8 {
|
||||
f.SetBytes([]byte(s))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeUnsupported(v reflect.Value, s string) error {
|
||||
return fmt.Errorf("redis.Scan(unsupported %s)", v.Type())
|
||||
}
|
|
@ -0,0 +1,117 @@
|
|||
package hscan
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type data struct {
|
||||
Omit string `redis:"-"`
|
||||
Empty string
|
||||
|
||||
String string `redis:"string"`
|
||||
Bytes []byte `redis:"byte"`
|
||||
Int int `redis:"int"`
|
||||
Uint uint `redis:"uint"`
|
||||
Float float32 `redis:"float"`
|
||||
Bool bool `redis:"bool"`
|
||||
}
|
||||
|
||||
type i []interface{}
|
||||
|
||||
func TestGinkgoSuite(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "hscan")
|
||||
}
|
||||
|
||||
var _ = Describe("Scan", func() {
|
||||
It("catches bad args", func() {
|
||||
var d data
|
||||
|
||||
Expect(Scan(&d, i{}, i{})).NotTo(HaveOccurred())
|
||||
Expect(d).To(Equal(data{}))
|
||||
|
||||
Expect(Scan(&d, i{"key"}, i{})).To(HaveOccurred())
|
||||
Expect(Scan(&d, i{"key"}, i{"1", "2"})).To(HaveOccurred())
|
||||
Expect(Scan(nil, i{"key", "1"}, i{})).To(HaveOccurred())
|
||||
|
||||
var m map[string]interface{}
|
||||
Expect(Scan(&m, i{"key"}, i{"1"})).To(HaveOccurred())
|
||||
Expect(Scan(data{}, i{"key"}, i{"1"})).To(HaveOccurred())
|
||||
Expect(Scan(data{}, i{"key", "string"}, i{nil, nil})).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("scans good values", func() {
|
||||
var d data
|
||||
|
||||
// non-tagged fields.
|
||||
Expect(Scan(&d, i{"key"}, i{"value"})).NotTo(HaveOccurred())
|
||||
Expect(d).To(Equal(data{}))
|
||||
|
||||
keys := i{"string", "byte", "int", "uint", "float", "bool"}
|
||||
vals := i{"str!", "bytes!", "123", "456", "123.456", "1"}
|
||||
Expect(Scan(&d, keys, vals)).NotTo(HaveOccurred())
|
||||
Expect(d).To(Equal(data{
|
||||
String: "str!",
|
||||
Bytes: []byte("bytes!"),
|
||||
Int: 123,
|
||||
Uint: 456,
|
||||
Float: 123.456,
|
||||
Bool: true,
|
||||
}))
|
||||
|
||||
// Scan a different type with the same values to test that
|
||||
// the struct spec maps don't conflict.
|
||||
type data2 struct {
|
||||
String string `redis:"string"`
|
||||
Bytes []byte `redis:"byte"`
|
||||
Int int `redis:"int"`
|
||||
Uint uint `redis:"uint"`
|
||||
Float float32 `redis:"float"`
|
||||
Bool bool `redis:"bool"`
|
||||
}
|
||||
var d2 data2
|
||||
Expect(Scan(&d2, keys, vals)).NotTo(HaveOccurred())
|
||||
Expect(d2).To(Equal(data2{
|
||||
String: "str!",
|
||||
Bytes: []byte("bytes!"),
|
||||
Int: 123,
|
||||
Uint: 456,
|
||||
Float: 123.456,
|
||||
Bool: true,
|
||||
}))
|
||||
|
||||
Expect(Scan(&d, i{"string", "float", "bool"}, i{"", "1", "t"})).NotTo(HaveOccurred())
|
||||
Expect(d).To(Equal(data{
|
||||
String: "",
|
||||
Bytes: []byte("bytes!"),
|
||||
Int: 123,
|
||||
Uint: 456,
|
||||
Float: 1.0,
|
||||
Bool: true,
|
||||
}))
|
||||
})
|
||||
|
||||
It("omits untagged fields", func() {
|
||||
var d data
|
||||
|
||||
Expect(Scan(&d, i{"empty", "omit", "string"}, i{"value", "value", "str!"})).NotTo(HaveOccurred())
|
||||
Expect(d).To(Equal(data{
|
||||
String: "str!",
|
||||
}))
|
||||
})
|
||||
|
||||
It("catches bad values", func() {
|
||||
var d data
|
||||
|
||||
Expect(Scan(&d, i{"int"}, i{"a"})).To(HaveOccurred())
|
||||
Expect(Scan(&d, i{"uint"}, i{"a"})).To(HaveOccurred())
|
||||
Expect(Scan(&d, i{"uint"}, i{""})).To(HaveOccurred())
|
||||
Expect(Scan(&d, i{"float"}, i{"b"})).To(HaveOccurred())
|
||||
Expect(Scan(&d, i{"bool"}, i{"-1"})).To(HaveOccurred())
|
||||
Expect(Scan(&d, i{"bool"}, i{""})).To(HaveOccurred())
|
||||
Expect(Scan(&d, i{"bool"}, i{"123"})).To(HaveOccurred())
|
||||
})
|
||||
})
|
|
@ -0,0 +1,87 @@
|
|||
package hscan
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// structMap contains the map of struct fields for target structs
|
||||
// indexed by the struct type.
|
||||
type structMap struct {
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
func newStructMap() *structMap {
|
||||
return new(structMap)
|
||||
}
|
||||
|
||||
func (s *structMap) get(t reflect.Type) *structSpec {
|
||||
if v, ok := s.m.Load(t); ok {
|
||||
return v.(*structSpec)
|
||||
}
|
||||
|
||||
spec := newStructSpec(t, "redis")
|
||||
s.m.Store(t, spec)
|
||||
return spec
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// structSpec contains the list of all fields in a target struct.
|
||||
type structSpec struct {
|
||||
m map[string]*structField
|
||||
}
|
||||
|
||||
func (s *structSpec) set(tag string, sf *structField) {
|
||||
s.m[tag] = sf
|
||||
}
|
||||
|
||||
func newStructSpec(t reflect.Type, fieldTag string) *structSpec {
|
||||
out := &structSpec{
|
||||
m: make(map[string]*structField),
|
||||
}
|
||||
|
||||
num := t.NumField()
|
||||
for i := 0; i < num; i++ {
|
||||
f := t.Field(i)
|
||||
|
||||
tag := f.Tag.Get(fieldTag)
|
||||
if tag == "" || tag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
tag = strings.Split(tag, ",")[0]
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use the built-in decoder.
|
||||
out.set(tag, &structField{index: i, fn: decoders[f.Type.Kind()]})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// structField represents a single field in a target struct.
|
||||
type structField struct {
|
||||
index int
|
||||
fn decoderFunc
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type StructValue struct {
|
||||
spec *structSpec
|
||||
value reflect.Value
|
||||
}
|
||||
|
||||
func (s StructValue) Scan(key string, value string) error {
|
||||
field, ok := s.spec.m[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return field.fn(s.value.Field(field.index), value)
|
||||
}
|
Loading…
Reference in New Issue