diff --git a/command.go b/command.go index 5dd55332..2932035e 100644 --- a/command.go +++ b/command.go @@ -8,6 +8,7 @@ import ( "time" "github.com/go-redis/redis/v8/internal" + "github.com/go-redis/redis/v8/internal/hscan" "github.com/go-redis/redis/v8/internal/proto" "github.com/go-redis/redis/v8/internal/util" ) @@ -371,6 +372,26 @@ func (cmd *SliceCmd) String() string { return cmdString(cmd, cmd.val) } +// Scan scans the results from the map into a destination struct. The map keys +// are matched in the Redis struct fields by the `redis:"field"` tag. +func (cmd *SliceCmd) Scan(dst interface{}) error { + if cmd.err != nil { + return cmd.err + } + + // Pass the list of keys and values. + // Skip the first two args for: HMGET key + var args []interface{} + if cmd.args[0] == "hmget" { + args = cmd.args[2:] + } else { + // Otherwise, it's: MGET field field ... + args = cmd.args[1:] + } + + return hscan.Scan(dst, args, cmd.val) +} + func (cmd *SliceCmd) readReply(rd *proto.Reader) error { v, err := rd.ReadArrayReply(sliceParser) if err != nil { @@ -917,6 +938,27 @@ func (cmd *StringStringMapCmd) String() string { return cmdString(cmd, cmd.val) } +// Scan scans the results from the map into a destination struct. The map keys +// are matched in the Redis struct fields by the `redis:"field"` tag. +func (cmd *StringStringMapCmd) Scan(dst interface{}) error { + if cmd.err != nil { + return cmd.err + } + + strct, err := hscan.Struct(dst) + if err != nil { + return err + } + + for k, v := range cmd.val { + if err := strct.Scan(k, v); err != nil { + return err + } + } + + return nil +} + func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make(map[string]string, n/2) diff --git a/commands_test.go b/commands_test.go index 707aea1b..a73f4f8b 100644 --- a/commands_test.go +++ b/commands_test.go @@ -1134,6 +1134,22 @@ var _ = Describe("Commands", func() { Expect(mGet.Val()).To(Equal([]interface{}{"hello1", "hello2", nil})) }) + It("should scan Mget", func() { + err := client.MSet(ctx, "key1", "hello1", "key2", 123).Err() + Expect(err).NotTo(HaveOccurred()) + + res := client.MGet(ctx, "key1", "key2", "_") + Expect(res.Err()).NotTo(HaveOccurred()) + + type data struct { + Key1 string `redis:"key1"` + Key2 int `redis:"key2"` + } + var d data + Expect(res.Scan(&d)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{Key1: "hello1", Key2: 123})) + }) + It("should MSetNX", func() { mSetNX := client.MSetNX(ctx, "key1", "hello1", "key2", "hello2") Expect(mSetNX.Err()).NotTo(HaveOccurred()) @@ -1375,6 +1391,22 @@ var _ = Describe("Commands", func() { Expect(m).To(Equal(map[string]string{"key1": "hello1", "key2": "hello2"})) }) + It("should scan", func() { + err := client.HMSet(ctx, "hash", "key1", "hello1", "key2", 123).Err() + Expect(err).NotTo(HaveOccurred()) + + res := client.HGetAll(ctx, "hash") + Expect(res.Err()).NotTo(HaveOccurred()) + + type data struct { + Key1 string `redis:"key1"` + Key2 int `redis:"key2"` + } + var d data + Expect(res.Scan(&d)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{Key1: "hello1", Key2: 123})) + }) + It("should HIncrBy", func() { hSet := client.HSet(ctx, "hash", "key", "5") Expect(hSet.Err()).NotTo(HaveOccurred()) diff --git a/example_test.go b/example_test.go index 161eabf6..7d9f7405 100644 --- a/example_test.go +++ b/example_test.go @@ -276,6 +276,73 @@ func ExampleClient_ScanType() { // Output: found 33 keys } +// ExampleStringStringMapCmd_Scan shows how to scan the results of a map fetch +// into a struct. +func ExampleStringStringMapCmd_Scan() { + rdb.FlushDB(ctx) + err := rdb.HMSet(ctx, "map", + "name", "hello", + "count", 123, + "correct", true).Err() + if err != nil { + panic(err) + } + + // Get the map. The same approach works for HmGet(). + res := rdb.HGetAll(ctx, "map") + if res.Err() != nil { + panic(err) + } + + type data struct { + Name string `redis:"name"` + Count int `redis:"count"` + Correct bool `redis:"correct"` + } + + // Scan the results into the struct. + var d data + if err := res.Scan(&d); err != nil { + panic(err) + } + + fmt.Println(d) + // Output: {hello 123 true} +} + +// ExampleSliceCmd_Scan shows how to scan the results of a multi key fetch +// into a struct. +func ExampleSliceCmd_Scan() { + rdb.FlushDB(ctx) + err := rdb.MSet(ctx, + "name", "hello", + "count", 123, + "correct", true).Err() + if err != nil { + panic(err) + } + + res := rdb.MGet(ctx, "name", "count", "empty", "correct") + if res.Err() != nil { + panic(err) + } + + type data struct { + Name string `redis:"name"` + Count int `redis:"count"` + Correct bool `redis:"correct"` + } + + // Scan the results into the struct. + var d data + if err := res.Scan(&d); err != nil { + panic(err) + } + + fmt.Println(d) + // Output: {hello 123 true} +} + func ExampleClient_Pipelined() { var incr *redis.IntCmd _, err := rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error { diff --git a/internal/hscan/hscan.go b/internal/hscan/hscan.go new file mode 100644 index 00000000..181260b8 --- /dev/null +++ b/internal/hscan/hscan.go @@ -0,0 +1,151 @@ +package hscan + +import ( + "errors" + "fmt" + "reflect" + "strconv" +) + +// decoderFunc represents decoding functions for default built-in types. +type decoderFunc func(reflect.Value, string) error + +var ( + // List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1). + decoders = []decoderFunc{ + reflect.Bool: decodeBool, + reflect.Int: decodeInt, + reflect.Int8: decodeInt, + reflect.Int16: decodeInt, + reflect.Int32: decodeInt, + reflect.Int64: decodeInt, + reflect.Uint: decodeUint, + reflect.Uint8: decodeUint, + reflect.Uint16: decodeUint, + reflect.Uint32: decodeUint, + reflect.Uint64: decodeUint, + reflect.Float32: decodeFloat, + reflect.Float64: decodeFloat, + reflect.Complex64: decodeUnsupported, + reflect.Complex128: decodeUnsupported, + reflect.Array: decodeUnsupported, + reflect.Chan: decodeUnsupported, + reflect.Func: decodeUnsupported, + reflect.Interface: decodeUnsupported, + reflect.Map: decodeUnsupported, + reflect.Ptr: decodeUnsupported, + reflect.Slice: decodeSlice, + reflect.String: decodeString, + reflect.Struct: decodeUnsupported, + reflect.UnsafePointer: decodeUnsupported, + } + + // Global map of struct field specs that is populated once for every new + // struct type that is scanned. This caches the field types and the corresponding + // decoder functions to avoid iterating through struct fields on subsequent scans. + globalStructMap = newStructMap() +) + +func Struct(dst interface{}) (StructValue, error) { + v := reflect.ValueOf(dst) + + // The dstination to scan into should be a struct pointer. + if v.Kind() != reflect.Ptr || v.IsNil() { + return StructValue{}, fmt.Errorf("redis.Scan(non-pointer %T)", dst) + } + + v = v.Elem() + if v.Kind() != reflect.Struct { + return StructValue{}, fmt.Errorf("redis.Scan(non-struct %T)", dst) + } + + return StructValue{ + spec: globalStructMap.get(v.Type()), + value: v, + }, nil +} + +// Scan scans the results from a key-value Redis map result set to a destination struct. +// The Redis keys are matched to the struct's field with the `redis` tag. +func Scan(dst interface{}, keys []interface{}, vals []interface{}) error { + if len(keys) != len(vals) { + return errors.New("args should have the same number of keys and vals") + } + + strct, err := Struct(dst) + if err != nil { + return err + } + + // Iterate through the (key, value) sequence. + for i := 0; i < len(vals); i++ { + key, ok := keys[i].(string) + if !ok { + continue + } + + val, ok := vals[i].(string) + if !ok { + continue + } + + if err := strct.Scan(key, val); err != nil { + return err + } + } + + return nil +} + +func decodeBool(f reflect.Value, s string) error { + b, err := strconv.ParseBool(s) + if err != nil { + return err + } + f.SetBool(b) + return nil +} + +func decodeInt(f reflect.Value, s string) error { + v, err := strconv.ParseInt(s, 10, 0) + if err != nil { + return err + } + f.SetInt(v) + return nil +} + +func decodeUint(f reflect.Value, s string) error { + v, err := strconv.ParseUint(s, 10, 0) + if err != nil { + return err + } + f.SetUint(v) + return nil +} + +func decodeFloat(f reflect.Value, s string) error { + v, err := strconv.ParseFloat(s, 0) + if err != nil { + return err + } + f.SetFloat(v) + return nil +} + +func decodeString(f reflect.Value, s string) error { + f.SetString(s) + return nil +} + +func decodeSlice(f reflect.Value, s string) error { + // []byte slice ([]uint8). + if f.Type().Elem().Kind() == reflect.Uint8 { + f.SetBytes([]byte(s)) + } + return nil +} + +func decodeUnsupported(v reflect.Value, s string) error { + return fmt.Errorf("redis.Scan(unsupported %s)", v.Type()) +} diff --git a/internal/hscan/hscan_test.go b/internal/hscan/hscan_test.go new file mode 100644 index 00000000..f7a88f0f --- /dev/null +++ b/internal/hscan/hscan_test.go @@ -0,0 +1,117 @@ +package hscan + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type data struct { + Omit string `redis:"-"` + Empty string + + String string `redis:"string"` + Bytes []byte `redis:"byte"` + Int int `redis:"int"` + Uint uint `redis:"uint"` + Float float32 `redis:"float"` + Bool bool `redis:"bool"` +} + +type i []interface{} + +func TestGinkgoSuite(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "hscan") +} + +var _ = Describe("Scan", func() { + It("catches bad args", func() { + var d data + + Expect(Scan(&d, i{}, i{})).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{})) + + Expect(Scan(&d, i{"key"}, i{})).To(HaveOccurred()) + Expect(Scan(&d, i{"key"}, i{"1", "2"})).To(HaveOccurred()) + Expect(Scan(nil, i{"key", "1"}, i{})).To(HaveOccurred()) + + var m map[string]interface{} + Expect(Scan(&m, i{"key"}, i{"1"})).To(HaveOccurred()) + Expect(Scan(data{}, i{"key"}, i{"1"})).To(HaveOccurred()) + Expect(Scan(data{}, i{"key", "string"}, i{nil, nil})).To(HaveOccurred()) + }) + + It("scans good values", func() { + var d data + + // non-tagged fields. + Expect(Scan(&d, i{"key"}, i{"value"})).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{})) + + keys := i{"string", "byte", "int", "uint", "float", "bool"} + vals := i{"str!", "bytes!", "123", "456", "123.456", "1"} + Expect(Scan(&d, keys, vals)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{ + String: "str!", + Bytes: []byte("bytes!"), + Int: 123, + Uint: 456, + Float: 123.456, + Bool: true, + })) + + // Scan a different type with the same values to test that + // the struct spec maps don't conflict. + type data2 struct { + String string `redis:"string"` + Bytes []byte `redis:"byte"` + Int int `redis:"int"` + Uint uint `redis:"uint"` + Float float32 `redis:"float"` + Bool bool `redis:"bool"` + } + var d2 data2 + Expect(Scan(&d2, keys, vals)).NotTo(HaveOccurred()) + Expect(d2).To(Equal(data2{ + String: "str!", + Bytes: []byte("bytes!"), + Int: 123, + Uint: 456, + Float: 123.456, + Bool: true, + })) + + Expect(Scan(&d, i{"string", "float", "bool"}, i{"", "1", "t"})).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{ + String: "", + Bytes: []byte("bytes!"), + Int: 123, + Uint: 456, + Float: 1.0, + Bool: true, + })) + }) + + It("omits untagged fields", func() { + var d data + + Expect(Scan(&d, i{"empty", "omit", "string"}, i{"value", "value", "str!"})).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{ + String: "str!", + })) + }) + + It("catches bad values", func() { + var d data + + Expect(Scan(&d, i{"int"}, i{"a"})).To(HaveOccurred()) + Expect(Scan(&d, i{"uint"}, i{"a"})).To(HaveOccurred()) + Expect(Scan(&d, i{"uint"}, i{""})).To(HaveOccurred()) + Expect(Scan(&d, i{"float"}, i{"b"})).To(HaveOccurred()) + Expect(Scan(&d, i{"bool"}, i{"-1"})).To(HaveOccurred()) + Expect(Scan(&d, i{"bool"}, i{""})).To(HaveOccurred()) + Expect(Scan(&d, i{"bool"}, i{"123"})).To(HaveOccurred()) + }) +}) diff --git a/internal/hscan/structmap.go b/internal/hscan/structmap.go new file mode 100644 index 00000000..37d86ba5 --- /dev/null +++ b/internal/hscan/structmap.go @@ -0,0 +1,87 @@ +package hscan + +import ( + "reflect" + "strings" + "sync" +) + +// structMap contains the map of struct fields for target structs +// indexed by the struct type. +type structMap struct { + m sync.Map +} + +func newStructMap() *structMap { + return new(structMap) +} + +func (s *structMap) get(t reflect.Type) *structSpec { + if v, ok := s.m.Load(t); ok { + return v.(*structSpec) + } + + spec := newStructSpec(t, "redis") + s.m.Store(t, spec) + return spec +} + +//------------------------------------------------------------------------------ + +// structSpec contains the list of all fields in a target struct. +type structSpec struct { + m map[string]*structField +} + +func (s *structSpec) set(tag string, sf *structField) { + s.m[tag] = sf +} + +func newStructSpec(t reflect.Type, fieldTag string) *structSpec { + out := &structSpec{ + m: make(map[string]*structField), + } + + num := t.NumField() + for i := 0; i < num; i++ { + f := t.Field(i) + + tag := f.Tag.Get(fieldTag) + if tag == "" || tag == "-" { + continue + } + + tag = strings.Split(tag, ",")[0] + if tag == "" { + continue + } + + // Use the built-in decoder. + out.set(tag, &structField{index: i, fn: decoders[f.Type.Kind()]}) + } + + return out +} + +//------------------------------------------------------------------------------ + +// structField represents a single field in a target struct. +type structField struct { + index int + fn decoderFunc +} + +//------------------------------------------------------------------------------ + +type StructValue struct { + spec *structSpec + value reflect.Value +} + +func (s StructValue) Scan(key string, value string) error { + field, ok := s.spec.m[key] + if !ok { + return nil + } + return field.fn(s.value.Field(field.index), value) +}