From 6f96bebac7ad7fc4f23d539f4ac1ef297a0cef19 Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Thu, 28 Jan 2021 10:58:24 +0530 Subject: [PATCH 01/10] Add redis.Scan() to scan results from redis maps into structs. The package uses reflection to decode default types (int, string etc.) from Redis map results (key-value pair sequences) into struct fields where the fields are matched to Redis keys by tags. Similar to how `encoding/json` allows custom decoders using `UnmarshalJSON()`, the package supports decoding of arbitrary types into struct fields by defining a `Decode(string) error` function on types. The field/type spec of every struct that's passed to Scan() is cached in the package so that subsequent scans avoid iteration and reflection of the struct's fields. --- internal/hscan/hscan.go | 156 ++++++++++++++++++++++++++++++++++++ internal/hscan/structmap.go | 87 ++++++++++++++++++++ 2 files changed, 243 insertions(+) create mode 100644 internal/hscan/hscan.go create mode 100644 internal/hscan/structmap.go diff --git a/internal/hscan/hscan.go b/internal/hscan/hscan.go new file mode 100644 index 00000000..a2b3c92d --- /dev/null +++ b/internal/hscan/hscan.go @@ -0,0 +1,156 @@ +package hscan + +import ( + "errors" + "fmt" + "reflect" + "strconv" +) + +// decoderFunc represents decoding functions for default built-in types. +type decoderFunc func(reflect.Value, string) error + +var ( + // List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1) + decoders = []decoderFunc{ + reflect.Bool: decodeBool, + reflect.Int: decodeInt, + reflect.Int8: decodeInt, + reflect.Int16: decodeInt, + reflect.Int32: decodeInt, + reflect.Int64: decodeInt, + reflect.Uint: decodeUint, + reflect.Uint8: decodeUint, + reflect.Uint16: decodeUint, + reflect.Uint32: decodeUint, + reflect.Uint64: decodeUint, + reflect.Float32: decodeFloat, + reflect.Float64: decodeFloat, + reflect.Complex64: decodeUnsupported, + reflect.Complex128: decodeUnsupported, + reflect.Array: decodeUnsupported, + reflect.Chan: decodeUnsupported, + reflect.Func: decodeUnsupported, + reflect.Interface: decodeUnsupported, + reflect.Map: decodeUnsupported, + reflect.Ptr: decodeUnsupported, + reflect.Slice: decodeStringSlice, + reflect.String: decodeString, + reflect.Struct: decodeUnsupported, + reflect.UnsafePointer: decodeUnsupported, + } + + // Global map of struct field specs that is populated once for every new + // struct type that is scanned. This caches the field types and the corresponding + // decoder functions to avoid iterating through struct fields on subsequent scans. + structSpecs = newStructMap() +) + +// Scan scans the results from a key-value Redis map result set ([]interface{}) +// to a destination struct. The Redis keys are matched to the struct's field +// with the `redis` tag. +func Scan(vals []interface{}, dest interface{}) error { + if len(vals)%2 != 0 { + return errors.New("args should have an even number of items (key-val)") + } + + // The destination to scan into should be a struct pointer. + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr || v.IsNil() { + return fmt.Errorf("redis.Scan(non-pointer %T)", dest) + } + v = v.Elem() + + if v.Kind() != reflect.Struct { + return fmt.Errorf("redis.Scan(non-struct %T)", dest) + } + + // If the struct field spec is not cached, build and cache it to avoid + // iterating through the fields of a struct type every time values are + // scanned into it. + typ := v.Type() + fMap, ok := structSpecs.get(typ) + + if !ok { + fMap = makeStructSpecs(v, "redis") + structSpecs.set(typ, fMap) + } + + // Iterate through the (key, value) sequence. + for i := 0; i < len(vals); i += 2 { + key, ok := vals[i].(string) + if !ok { + continue + } + + val, ok := vals[i+1].(string) + if !ok { + continue + } + + // Check if the field name is in the field spec map. + field, ok := fMap.get(key) + if !ok { + continue + } + + if err := field.fn(v.Field(field.index), val); err != nil { + return err + } + } + + return nil +} + +func decodeBool(f reflect.Value, s string) error { + b, err := strconv.ParseBool(s) + if err != nil { + return err + } + f.SetBool(b) + return nil +} + +func decodeInt(f reflect.Value, s string) error { + v, err := strconv.ParseInt(s, 10, 0) + if err != nil { + return err + } + f.SetInt(v) + return nil +} + +func decodeUint(f reflect.Value, s string) error { + v, err := strconv.ParseUint(s, 10, 0) + if err != nil { + return err + } + f.SetUint(v) + return nil +} + +func decodeFloat(f reflect.Value, s string) error { + v, err := strconv.ParseFloat(s, 0) + if err != nil { + return err + } + f.SetFloat(v) + return nil +} + +func decodeString(f reflect.Value, s string) error { + f.SetString(s) + return nil +} + +func decodeStringSlice(f reflect.Value, s string) error { + // []byte slice ([]uint8). + if f.Type().Elem().Kind() == reflect.Uint8 { + f.SetBytes([]byte(s)) + } + return nil +} + +func decodeUnsupported(f reflect.Value, s string) error { + return fmt.Errorf("redis.Scan(unsupported type %v)", f) +} diff --git a/internal/hscan/structmap.go b/internal/hscan/structmap.go new file mode 100644 index 00000000..130b07bc --- /dev/null +++ b/internal/hscan/structmap.go @@ -0,0 +1,87 @@ +package hscan + +import ( + "reflect" + "strings" + "sync" +) + +// structField represents a single field in a target struct. +type structField struct { + index int + fn decoderFunc +} + +// structFields contains the list of all fields in a target struct. +type structFields struct { + m map[string]*structField +} + +// structMap contains the map of struct fields for target structs +// indexed by the struct type. +type structMap struct { + m sync.Map +} + +func newStructMap() *structMap { + return &structMap{ + m: sync.Map{}, + } +} + +func (s *structMap) get(t reflect.Type) (*structFields, bool) { + m, ok := s.m.Load(t) + if !ok { + return nil, ok + } + + return m.(*structFields), true +} + +func (s *structMap) set(t reflect.Type, sf *structFields) { + s.m.Store(t, sf) +} + +func newStructFields() *structFields { + return &structFields{ + m: make(map[string]*structField), + } +} + +func (s *structFields) set(tag string, sf *structField) { + s.m[tag] = sf +} + +func (s *structFields) get(tag string) (*structField, bool) { + f, ok := s.m[tag] + return f, ok +} + +func makeStructSpecs(ob reflect.Value, fieldTag string) *structFields { + var ( + num = ob.NumField() + out = newStructFields() + ) + + for i := 0; i < num; i++ { + f := ob.Field(i) + if !f.IsValid() || !f.CanSet() { + continue + } + + tag := ob.Type().Field(i).Tag.Get(fieldTag) + if tag == "" || tag == "-" { + continue + } + + tag = strings.Split(tag, ",")[0] + if tag == "" { + continue + } + + // Use the built-in decoder. + out.set(tag, &structField{index: i, fn: decoders[f.Kind()]}) + } + + return out +} From 8926f2992ac1cc9ec72ba196e18934bae0579df4 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Fri, 29 Jan 2021 12:04:39 +0200 Subject: [PATCH 02/10] Cleanup --- internal/hscan/hscan.go | 15 +++++---------- internal/hscan/structmap.go | 21 ++++++++------------- 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/internal/hscan/hscan.go b/internal/hscan/hscan.go index a2b3c92d..1c46f2cd 100644 --- a/internal/hscan/hscan.go +++ b/internal/hscan/hscan.go @@ -34,7 +34,7 @@ var ( reflect.Interface: decodeUnsupported, reflect.Map: decodeUnsupported, reflect.Ptr: decodeUnsupported, - reflect.Slice: decodeStringSlice, + reflect.Slice: decodeSlice, reflect.String: decodeString, reflect.Struct: decodeUnsupported, reflect.UnsafePointer: decodeUnsupported, @@ -69,12 +69,7 @@ func Scan(vals []interface{}, dest interface{}) error { // iterating through the fields of a struct type every time values are // scanned into it. typ := v.Type() - fMap, ok := structSpecs.get(typ) - - if !ok { - fMap = makeStructSpecs(v, "redis") - structSpecs.set(typ, fMap) - } + fMap := structSpecs.get(typ) // Iterate through the (key, value) sequence. for i := 0; i < len(vals); i += 2 { @@ -143,7 +138,7 @@ func decodeString(f reflect.Value, s string) error { return nil } -func decodeStringSlice(f reflect.Value, s string) error { +func decodeSlice(f reflect.Value, s string) error { // []byte slice ([]uint8). if f.Type().Elem().Kind() == reflect.Uint8 { f.SetBytes([]byte(s)) @@ -151,6 +146,6 @@ func decodeStringSlice(f reflect.Value, s string) error { return nil } -func decodeUnsupported(f reflect.Value, s string) error { - return fmt.Errorf("redis.Scan(unsupported type %v)", f) +func decodeUnsupported(v reflect.Value, s string) error { + return fmt.Errorf("redis.Scan(unsupported %s)", v.Type()) } diff --git a/internal/hscan/structmap.go b/internal/hscan/structmap.go index 130b07bc..d66f76f0 100644 --- a/internal/hscan/structmap.go +++ b/internal/hscan/structmap.go @@ -24,22 +24,17 @@ type structMap struct { } func newStructMap() *structMap { - return &structMap{ - m: sync.Map{}, - } + return new(structMap) } -func (s *structMap) get(t reflect.Type) (*structFields, bool) { - m, ok := s.m.Load(t) - if !ok { - return nil, ok +func (s *structMap) get(t reflect.Type) *structFields { + if v, ok := s.m.Load(t); ok { + return m.(*structFields), ok } - return m.(*structFields), true -} - -func (s *structMap) set(t reflect.Type, sf *structFields) { - s.m.Store(t, sf) + fMap := getStructFields(v, "redis") + s.m.Store(t, fMap) + return fmap, true } func newStructFields() *structFields { @@ -57,7 +52,7 @@ func (s *structFields) get(tag string) (*structField, bool) { return f, ok } -func makeStructSpecs(ob reflect.Value, fieldTag string) *structFields { +func getStructFields(ob reflect.Value, fieldTag string) *structFields { var ( num = ob.NumField() out = newStructFields() From 380ab17274cf3f82dffefa02b8c37cbe59e9aa0d Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Fri, 29 Jan 2021 12:12:47 +0200 Subject: [PATCH 03/10] Fix cleanup --- internal/hscan/structmap.go | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/internal/hscan/structmap.go b/internal/hscan/structmap.go index d66f76f0..06d5d2fe 100644 --- a/internal/hscan/structmap.go +++ b/internal/hscan/structmap.go @@ -29,12 +29,12 @@ func newStructMap() *structMap { func (s *structMap) get(t reflect.Type) *structFields { if v, ok := s.m.Load(t); ok { - return m.(*structFields), ok + return v.(*structFields) } - fMap := getStructFields(v, "redis") + fMap := getStructFields(t, "redis") s.m.Store(t, fMap) - return fmap, true + return fMap } func newStructFields() *structFields { @@ -52,19 +52,16 @@ func (s *structFields) get(tag string) (*structField, bool) { return f, ok } -func getStructFields(ob reflect.Value, fieldTag string) *structFields { +func getStructFields(t reflect.Type, fieldTag string) *structFields { var ( - num = ob.NumField() + num = t.NumField() out = newStructFields() ) for i := 0; i < num; i++ { - f := ob.Field(i) - if !f.IsValid() || !f.CanSet() { - continue - } + f := t.Field(i) - tag := ob.Type().Field(i).Tag.Get(fieldTag) + tag := t.Field(i).Tag.Get(fieldTag) if tag == "" || tag == "-" { continue } @@ -75,7 +72,7 @@ func getStructFields(ob reflect.Value, fieldTag string) *structFields { } // Use the built-in decoder. - out.set(tag, &structField{index: i, fn: decoders[f.Kind()]}) + out.set(tag, &structField{index: i, fn: decoders[f.Type.Kind()]}) } return out From a4144ea98eb00abe3eb5f78621d93c1c8c229493 Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Tue, 2 Feb 2021 13:04:52 +0530 Subject: [PATCH 04/10] Add SliceCmd.Scan() (hscan pkg) and tests --- command.go | 8 +++ internal/hscan/hscan.go | 2 +- internal/hscan/hscan_test.go | 125 +++++++++++++++++++++++++++++++++++ internal/hscan/structmap.go | 2 +- 4 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 internal/hscan/hscan_test.go diff --git a/command.go b/command.go index 5dd55332..602ff61b 100644 --- a/command.go +++ b/command.go @@ -8,6 +8,7 @@ import ( "time" "github.com/go-redis/redis/v8/internal" + "github.com/go-redis/redis/v8/internal/hscan" "github.com/go-redis/redis/v8/internal/proto" "github.com/go-redis/redis/v8/internal/util" ) @@ -371,6 +372,13 @@ func (cmd *SliceCmd) String() string { return cmdString(cmd, cmd.val) } +// Scan scans the results from a key-value Redis map result set ([]interface{}) +// like HMGET and HGETALL to a destination struct. +// The Redis keys are matched to the struct's field with the `redis` tag. +func (cmd *SliceCmd) Scan(val interface{}) error { + return hscan.Scan(cmd.val, val) +} + func (cmd *SliceCmd) readReply(rd *proto.Reader) error { v, err := rd.ReadArrayReply(sliceParser) if err != nil { diff --git a/internal/hscan/hscan.go b/internal/hscan/hscan.go index 1c46f2cd..51346b83 100644 --- a/internal/hscan/hscan.go +++ b/internal/hscan/hscan.go @@ -11,7 +11,7 @@ import ( type decoderFunc func(reflect.Value, string) error var ( - // List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1) + // List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1). decoders = []decoderFunc{ reflect.Bool: decodeBool, reflect.Int: decodeInt, diff --git a/internal/hscan/hscan_test.go b/internal/hscan/hscan_test.go new file mode 100644 index 00000000..7a4d0f5b --- /dev/null +++ b/internal/hscan/hscan_test.go @@ -0,0 +1,125 @@ +package hscan + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type data struct { + Omit string `redis:"-"` + Empty string + + String string `redis:"string"` + Bytes []byte `redis:"byte"` + Int int `redis:"int"` + Uint uint `redis:"uint"` + Float float32 `redis:"float"` + Bool bool `redis:"bool"` +} + +func TestGinkgoSuite(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "hscan") +} + +var _ = Describe("Scan", func() { + It("catches bad args", func() { + var d data + + Expect(Scan([]interface{}{}, &d)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{})) + + Expect(Scan([]interface{}{"key"}, &d)).To(HaveOccurred()) + Expect(Scan([]interface{}{"key", "1", "2"}, &d)).To(HaveOccurred()) + Expect(Scan([]interface{}{"key", "1"}, nil)).To(HaveOccurred()) + + var i map[string]interface{} + Expect(Scan([]interface{}{"key", "1"}, &i)).To(HaveOccurred()) + Expect(Scan([]interface{}{"key", "1"}, data{})).To(HaveOccurred()) + Expect(Scan([]interface{}{"key", nil, "string", nil}, data{})).To(HaveOccurred()) + }) + + It("scans good values", func() { + var d data + + // non-tagged fields. + Expect(Scan([]interface{}{"key", "value"}, &d)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{})) + + res := []interface{}{"string", "str!", + "byte", "bytes!", + "int", "123", + "uint", "456", + "float", "123.456", + "bool", "1"} + Expect(Scan(res, &d)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{ + String: "str!", + Bytes: []byte("bytes!"), + Int: 123, + Uint: 456, + Float: 123.456, + Bool: true, + })) + + // Scan a different type with the same values to test that + // the struct spec maps don't conflict. + type data2 struct { + String string `redis:"string"` + Bytes []byte `redis:"byte"` + Int int `redis:"int"` + Uint uint `redis:"uint"` + Float float32 `redis:"float"` + Bool bool `redis:"bool"` + } + var d2 data2 + Expect(Scan(res, &d2)).NotTo(HaveOccurred()) + Expect(d2).To(Equal(data2{ + String: "str!", + Bytes: []byte("bytes!"), + Int: 123, + Uint: 456, + Float: 123.456, + Bool: true, + })) + + Expect(Scan([]interface{}{ + "string", "", + "float", "1", + "bool", "t"}, &d)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{ + String: "", + Bytes: []byte("bytes!"), + Int: 123, + Uint: 456, + Float: 1.0, + Bool: true, + })) + }) + + It("omits untagged fields", func() { + var d data + + Expect(Scan([]interface{}{ + "empty", "value", + "omit", "value", + "string", "str!"}, &d)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{ + String: "str!", + })) + }) + + It("catches bad values", func() { + var d data + + Expect(Scan([]interface{}{"int", "a"}, &d)).To(HaveOccurred()) + Expect(Scan([]interface{}{"uint", "a"}, &d)).To(HaveOccurred()) + Expect(Scan([]interface{}{"uint", ""}, &d)).To(HaveOccurred()) + Expect(Scan([]interface{}{"float", "b"}, &d)).To(HaveOccurred()) + Expect(Scan([]interface{}{"bool", "-1"}, &d)).To(HaveOccurred()) + Expect(Scan([]interface{}{"bool", ""}, &d)).To(HaveOccurred()) + Expect(Scan([]interface{}{"bool", "123"}, &d)).To(HaveOccurred()) + }) +}) diff --git a/internal/hscan/structmap.go b/internal/hscan/structmap.go index 06d5d2fe..da97f907 100644 --- a/internal/hscan/structmap.go +++ b/internal/hscan/structmap.go @@ -61,7 +61,7 @@ func getStructFields(t reflect.Type, fieldTag string) *structFields { for i := 0; i < num; i++ { f := t.Field(i) - tag := t.Field(i).Tag.Get(fieldTag) + tag := f.Tag.Get(fieldTag) if tag == "" || tag == "-" { continue } From f9dfc7a949741988b7b59ea2a21dd01e11b9e404 Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Tue, 2 Feb 2021 16:28:10 +0530 Subject: [PATCH 05/10] Refactor scan signature to work with Slice and StringMap cmds --- command.go | 28 ++++++++++++++--- commands_test.go | 16 ++++++++++ internal/hscan/hscan.go | 17 +++++----- internal/hscan/hscan_test.go | 60 ++++++++++++++++-------------------- 4 files changed, 74 insertions(+), 47 deletions(-) diff --git a/command.go b/command.go index 602ff61b..642352c9 100644 --- a/command.go +++ b/command.go @@ -372,11 +372,12 @@ func (cmd *SliceCmd) String() string { return cmdString(cmd, cmd.val) } -// Scan scans the results from a key-value Redis map result set ([]interface{}) -// like HMGET and HGETALL to a destination struct. -// The Redis keys are matched to the struct's field with the `redis` tag. -func (cmd *SliceCmd) Scan(val interface{}) error { - return hscan.Scan(cmd.val, val) +// Scan scans the results from the map into a destination struct. The map keys +// are matched in the Redis struct fields by the `redis:"field"` tag. +func (cmd *SliceCmd) Scan(dest interface{}) error { + // Pass the list of keys and values. Skip the first to args (command, key), + // eg: HMGET map. + return hscan.Scan(cmd.args[2:], cmd.val, dest) } func (cmd *SliceCmd) readReply(rd *proto.Reader) error { @@ -925,6 +926,23 @@ func (cmd *StringStringMapCmd) String() string { return cmdString(cmd, cmd.val) } +// Scan scans the results from the map into a destination struct. The map keys +// are matched in the Redis struct fields by the `redis:"field"` tag. +func (cmd *StringStringMapCmd) Scan(dest interface{}) error { + // Pass the list of keys and values. Skip the first to args (command, key), + // eg: HGETALL map. + var ( + keys = make([]interface{}, 0, len(cmd.val)) + vals = make([]interface{}, 0, len(cmd.val)) + ) + for k, v := range cmd.val { + keys = append(keys, k) + vals = append(vals, v) + } + + return hscan.Scan(keys, vals, dest) +} + func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make(map[string]string, n/2) diff --git a/commands_test.go b/commands_test.go index 707aea1b..19af05f0 100644 --- a/commands_test.go +++ b/commands_test.go @@ -1375,6 +1375,22 @@ var _ = Describe("Commands", func() { Expect(m).To(Equal(map[string]string{"key1": "hello1", "key2": "hello2"})) }) + It("should scan", func() { + err := client.HMSet(ctx, "hash", "key1", "hello1", "key2", 123).Err() + Expect(err).NotTo(HaveOccurred()) + + res := client.HGetAll(ctx, "hash") + Expect(res.Err()).NotTo(HaveOccurred()) + + type data struct { + Key1 string `redis:"key1"` + Key2 int `redis:"key2"` + } + var d data + Expect(res.Scan(&d)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{Key1: "hello1", Key2: 123})) + }) + It("should HIncrBy", func() { hSet := client.HSet(ctx, "hash", "key", "5") Expect(hSet.Err()).NotTo(HaveOccurred()) diff --git a/internal/hscan/hscan.go b/internal/hscan/hscan.go index 51346b83..a4b9e80e 100644 --- a/internal/hscan/hscan.go +++ b/internal/hscan/hscan.go @@ -46,12 +46,11 @@ var ( structSpecs = newStructMap() ) -// Scan scans the results from a key-value Redis map result set ([]interface{}) -// to a destination struct. The Redis keys are matched to the struct's field -// with the `redis` tag. -func Scan(vals []interface{}, dest interface{}) error { - if len(vals)%2 != 0 { - return errors.New("args should have an even number of items (key-val)") +// Scan scans the results from a key-value Redis map result set to a destination struct. +// The Redis keys are matched to the struct's field with the `redis` tag. +func Scan(keys []interface{}, vals []interface{}, dest interface{}) error { + if len(keys) != len(vals) { + return errors.New("args should have the same number of keys and vals") } // The destination to scan into should be a struct pointer. @@ -72,13 +71,13 @@ func Scan(vals []interface{}, dest interface{}) error { fMap := structSpecs.get(typ) // Iterate through the (key, value) sequence. - for i := 0; i < len(vals); i += 2 { - key, ok := vals[i].(string) + for i := 0; i < len(vals); i++ { + key, ok := keys[i].(string) if !ok { continue } - val, ok := vals[i+1].(string) + val, ok := vals[i].(string) if !ok { continue } diff --git a/internal/hscan/hscan_test.go b/internal/hscan/hscan_test.go index 7a4d0f5b..e14b4e91 100644 --- a/internal/hscan/hscan_test.go +++ b/internal/hscan/hscan_test.go @@ -19,6 +19,8 @@ type data struct { Bool bool `redis:"bool"` } +type i []interface{} + func TestGinkgoSuite(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "hscan") @@ -26,35 +28,33 @@ func TestGinkgoSuite(t *testing.T) { var _ = Describe("Scan", func() { It("catches bad args", func() { - var d data + var ( + d data + ) - Expect(Scan([]interface{}{}, &d)).NotTo(HaveOccurred()) + Expect(Scan(i{}, i{}, &d)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{})) - Expect(Scan([]interface{}{"key"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"key", "1", "2"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"key", "1"}, nil)).To(HaveOccurred()) + Expect(Scan(i{"key"}, i{}, &d)).To(HaveOccurred()) + Expect(Scan(i{"key"}, i{"1", "2"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"key", "1"}, i{}, nil)).To(HaveOccurred()) - var i map[string]interface{} - Expect(Scan([]interface{}{"key", "1"}, &i)).To(HaveOccurred()) - Expect(Scan([]interface{}{"key", "1"}, data{})).To(HaveOccurred()) - Expect(Scan([]interface{}{"key", nil, "string", nil}, data{})).To(HaveOccurred()) + var m map[string]interface{} + Expect(Scan(i{"key"}, i{"1"}, &m)).To(HaveOccurred()) + Expect(Scan(i{"key"}, i{"1"}, data{})).To(HaveOccurred()) + Expect(Scan(i{"key", "string"}, i{nil, nil}, data{})).To(HaveOccurred()) }) It("scans good values", func() { var d data // non-tagged fields. - Expect(Scan([]interface{}{"key", "value"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(i{"key"}, i{"value"}, &d)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{})) - res := []interface{}{"string", "str!", - "byte", "bytes!", - "int", "123", - "uint", "456", - "float", "123.456", - "bool", "1"} - Expect(Scan(res, &d)).NotTo(HaveOccurred()) + keys := i{"string", "byte", "int", "uint", "float", "bool"} + vals := i{"str!", "bytes!", "123", "456", "123.456", "1"} + Expect(Scan(keys, vals, &d)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "str!", Bytes: []byte("bytes!"), @@ -75,7 +75,7 @@ var _ = Describe("Scan", func() { Bool bool `redis:"bool"` } var d2 data2 - Expect(Scan(res, &d2)).NotTo(HaveOccurred()) + Expect(Scan(keys, vals, &d2)).NotTo(HaveOccurred()) Expect(d2).To(Equal(data2{ String: "str!", Bytes: []byte("bytes!"), @@ -85,10 +85,7 @@ var _ = Describe("Scan", func() { Bool: true, })) - Expect(Scan([]interface{}{ - "string", "", - "float", "1", - "bool", "t"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(i{"string", "float", "bool"}, i{"", "1", "t"}, &d)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "", Bytes: []byte("bytes!"), @@ -102,10 +99,7 @@ var _ = Describe("Scan", func() { It("omits untagged fields", func() { var d data - Expect(Scan([]interface{}{ - "empty", "value", - "omit", "value", - "string", "str!"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(i{"empty", "omit", "string"}, i{"value", "value", "str!"}, &d)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "str!", })) @@ -114,12 +108,12 @@ var _ = Describe("Scan", func() { It("catches bad values", func() { var d data - Expect(Scan([]interface{}{"int", "a"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"uint", "a"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"uint", ""}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"float", "b"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"bool", "-1"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"bool", ""}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"bool", "123"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"int"}, i{"a"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"uint"}, i{"a"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"uint"}, i{""}, &d)).To(HaveOccurred()) + Expect(Scan(i{"float"}, i{"b"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"bool"}, i{"-1"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"bool"}, i{""}, &d)).To(HaveOccurred()) + Expect(Scan(i{"bool"}, i{"123"}, &d)).To(HaveOccurred()) }) }) From e113512c184cfaebc6560782706bd02d93884cf7 Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Tue, 2 Feb 2021 16:28:35 +0530 Subject: [PATCH 06/10] Add example showing scanning to struct --- example_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/example_test.go b/example_test.go index 161eabf6..07eb5108 100644 --- a/example_test.go +++ b/example_test.go @@ -276,6 +276,40 @@ func ExampleClient_ScanType() { // Output: found 33 keys } +// ExampleStringStringMapCmd_Scan shows how to scan the results of a map fetch +// into a struct. +func ExampleStringStringMapCmd_Scan() { + rdb.FlushDB(ctx) + err := rdb.HMSet(ctx, "map", + "name", "hello", + "count", 123, + "correct", true).Err() + if err != nil { + panic(err) + } + + // Get the map. The same approach works for HmGet(). + res := rdb.HGetAll(ctx, "map") + if res.Err() != nil { + panic(err) + } + + type data struct { + Name string `redis:"name"` + Count int `redis:"count"` + Correct bool `redis:"correct"` + } + + // Scan the results into the struct. + var d data + if err := res.Scan(&d); err != nil { + panic(err) + } + + fmt.Println(d) + // Output: {hello 123 true} +} + func ExampleClient_Pipelined() { var incr *redis.IntCmd _, err := rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error { From 600f1665a0b58e69a81a1b4e29e2bd76b574e946 Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Wed, 3 Feb 2021 13:37:27 +0530 Subject: [PATCH 07/10] Add missing error checks and support for MGET in Scan() --- command.go | 22 +++++++++++++++++++--- example_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/command.go b/command.go index 642352c9..73194ecc 100644 --- a/command.go +++ b/command.go @@ -375,9 +375,21 @@ func (cmd *SliceCmd) String() string { // Scan scans the results from the map into a destination struct. The map keys // are matched in the Redis struct fields by the `redis:"field"` tag. func (cmd *SliceCmd) Scan(dest interface{}) error { - // Pass the list of keys and values. Skip the first to args (command, key), - // eg: HMGET map. - return hscan.Scan(cmd.args[2:], cmd.val, dest) + if cmd.err != nil { + return cmd.err + } + + // Pass the list of keys and values. + // Skip the first two args for: HMGET key + var args []interface{} + if cmd.args[0] == "hmget" { + args = cmd.args[2:] + } else { + // Otherwise, it's: MGET field field ... + args = cmd.args[1:] + } + + return hscan.Scan(args, cmd.val, dest) } func (cmd *SliceCmd) readReply(rd *proto.Reader) error { @@ -929,6 +941,10 @@ func (cmd *StringStringMapCmd) String() string { // Scan scans the results from the map into a destination struct. The map keys // are matched in the Redis struct fields by the `redis:"field"` tag. func (cmd *StringStringMapCmd) Scan(dest interface{}) error { + if cmd.err != nil { + return cmd.err + } + // Pass the list of keys and values. Skip the first to args (command, key), // eg: HGETALL map. var ( diff --git a/example_test.go b/example_test.go index 07eb5108..7d9f7405 100644 --- a/example_test.go +++ b/example_test.go @@ -310,6 +310,39 @@ func ExampleStringStringMapCmd_Scan() { // Output: {hello 123 true} } +// ExampleSliceCmd_Scan shows how to scan the results of a multi key fetch +// into a struct. +func ExampleSliceCmd_Scan() { + rdb.FlushDB(ctx) + err := rdb.MSet(ctx, + "name", "hello", + "count", 123, + "correct", true).Err() + if err != nil { + panic(err) + } + + res := rdb.MGet(ctx, "name", "count", "empty", "correct") + if res.Err() != nil { + panic(err) + } + + type data struct { + Name string `redis:"name"` + Count int `redis:"count"` + Correct bool `redis:"correct"` + } + + // Scan the results into the struct. + var d data + if err := res.Scan(&d); err != nil { + panic(err) + } + + fmt.Println(d) + // Output: {hello 123 true} +} + func ExampleClient_Pipelined() { var incr *redis.IntCmd _, err := rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error { From bd234b91fe4eb09e1c3281f9ecc0ec3bbdb3cf30 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 3 Feb 2021 12:45:02 +0200 Subject: [PATCH 08/10] Add StructValue so we don't need temp slices to pass keys and values --- command.go | 26 +++++++------- internal/hscan/hscan.go | 49 +++++++++++++------------- internal/hscan/hscan_test.go | 42 +++++++++++----------- internal/hscan/structmap.go | 67 +++++++++++++++++++++--------------- 4 files changed, 98 insertions(+), 86 deletions(-) diff --git a/command.go b/command.go index 73194ecc..2932035e 100644 --- a/command.go +++ b/command.go @@ -374,7 +374,7 @@ func (cmd *SliceCmd) String() string { // Scan scans the results from the map into a destination struct. The map keys // are matched in the Redis struct fields by the `redis:"field"` tag. -func (cmd *SliceCmd) Scan(dest interface{}) error { +func (cmd *SliceCmd) Scan(dst interface{}) error { if cmd.err != nil { return cmd.err } @@ -389,7 +389,7 @@ func (cmd *SliceCmd) Scan(dest interface{}) error { args = cmd.args[1:] } - return hscan.Scan(args, cmd.val, dest) + return hscan.Scan(dst, args, cmd.val) } func (cmd *SliceCmd) readReply(rd *proto.Reader) error { @@ -940,23 +940,23 @@ func (cmd *StringStringMapCmd) String() string { // Scan scans the results from the map into a destination struct. The map keys // are matched in the Redis struct fields by the `redis:"field"` tag. -func (cmd *StringStringMapCmd) Scan(dest interface{}) error { +func (cmd *StringStringMapCmd) Scan(dst interface{}) error { if cmd.err != nil { return cmd.err } - // Pass the list of keys and values. Skip the first to args (command, key), - // eg: HGETALL map. - var ( - keys = make([]interface{}, 0, len(cmd.val)) - vals = make([]interface{}, 0, len(cmd.val)) - ) - for k, v := range cmd.val { - keys = append(keys, k) - vals = append(vals, v) + strct, err := hscan.Struct(dst) + if err != nil { + return err } - return hscan.Scan(keys, vals, dest) + for k, v := range cmd.val { + if err := strct.Scan(k, v); err != nil { + return err + } + } + + return nil } func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { diff --git a/internal/hscan/hscan.go b/internal/hscan/hscan.go index a4b9e80e..181260b8 100644 --- a/internal/hscan/hscan.go +++ b/internal/hscan/hscan.go @@ -43,32 +43,39 @@ var ( // Global map of struct field specs that is populated once for every new // struct type that is scanned. This caches the field types and the corresponding // decoder functions to avoid iterating through struct fields on subsequent scans. - structSpecs = newStructMap() + globalStructMap = newStructMap() ) +func Struct(dst interface{}) (StructValue, error) { + v := reflect.ValueOf(dst) + + // The dstination to scan into should be a struct pointer. + if v.Kind() != reflect.Ptr || v.IsNil() { + return StructValue{}, fmt.Errorf("redis.Scan(non-pointer %T)", dst) + } + + v = v.Elem() + if v.Kind() != reflect.Struct { + return StructValue{}, fmt.Errorf("redis.Scan(non-struct %T)", dst) + } + + return StructValue{ + spec: globalStructMap.get(v.Type()), + value: v, + }, nil +} + // Scan scans the results from a key-value Redis map result set to a destination struct. // The Redis keys are matched to the struct's field with the `redis` tag. -func Scan(keys []interface{}, vals []interface{}, dest interface{}) error { +func Scan(dst interface{}, keys []interface{}, vals []interface{}) error { if len(keys) != len(vals) { return errors.New("args should have the same number of keys and vals") } - // The destination to scan into should be a struct pointer. - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr || v.IsNil() { - return fmt.Errorf("redis.Scan(non-pointer %T)", dest) + strct, err := Struct(dst) + if err != nil { + return err } - v = v.Elem() - - if v.Kind() != reflect.Struct { - return fmt.Errorf("redis.Scan(non-struct %T)", dest) - } - - // If the struct field spec is not cached, build and cache it to avoid - // iterating through the fields of a struct type every time values are - // scanned into it. - typ := v.Type() - fMap := structSpecs.get(typ) // Iterate through the (key, value) sequence. for i := 0; i < len(vals); i++ { @@ -82,13 +89,7 @@ func Scan(keys []interface{}, vals []interface{}, dest interface{}) error { continue } - // Check if the field name is in the field spec map. - field, ok := fMap.get(key) - if !ok { - continue - } - - if err := field.fn(v.Field(field.index), val); err != nil { + if err := strct.Scan(key, val); err != nil { return err } } diff --git a/internal/hscan/hscan_test.go b/internal/hscan/hscan_test.go index e14b4e91..f7a88f0f 100644 --- a/internal/hscan/hscan_test.go +++ b/internal/hscan/hscan_test.go @@ -28,33 +28,31 @@ func TestGinkgoSuite(t *testing.T) { var _ = Describe("Scan", func() { It("catches bad args", func() { - var ( - d data - ) + var d data - Expect(Scan(i{}, i{}, &d)).NotTo(HaveOccurred()) + Expect(Scan(&d, i{}, i{})).NotTo(HaveOccurred()) Expect(d).To(Equal(data{})) - Expect(Scan(i{"key"}, i{}, &d)).To(HaveOccurred()) - Expect(Scan(i{"key"}, i{"1", "2"}, &d)).To(HaveOccurred()) - Expect(Scan(i{"key", "1"}, i{}, nil)).To(HaveOccurred()) + Expect(Scan(&d, i{"key"}, i{})).To(HaveOccurred()) + Expect(Scan(&d, i{"key"}, i{"1", "2"})).To(HaveOccurred()) + Expect(Scan(nil, i{"key", "1"}, i{})).To(HaveOccurred()) var m map[string]interface{} - Expect(Scan(i{"key"}, i{"1"}, &m)).To(HaveOccurred()) - Expect(Scan(i{"key"}, i{"1"}, data{})).To(HaveOccurred()) - Expect(Scan(i{"key", "string"}, i{nil, nil}, data{})).To(HaveOccurred()) + Expect(Scan(&m, i{"key"}, i{"1"})).To(HaveOccurred()) + Expect(Scan(data{}, i{"key"}, i{"1"})).To(HaveOccurred()) + Expect(Scan(data{}, i{"key", "string"}, i{nil, nil})).To(HaveOccurred()) }) It("scans good values", func() { var d data // non-tagged fields. - Expect(Scan(i{"key"}, i{"value"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(&d, i{"key"}, i{"value"})).NotTo(HaveOccurred()) Expect(d).To(Equal(data{})) keys := i{"string", "byte", "int", "uint", "float", "bool"} vals := i{"str!", "bytes!", "123", "456", "123.456", "1"} - Expect(Scan(keys, vals, &d)).NotTo(HaveOccurred()) + Expect(Scan(&d, keys, vals)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "str!", Bytes: []byte("bytes!"), @@ -75,7 +73,7 @@ var _ = Describe("Scan", func() { Bool bool `redis:"bool"` } var d2 data2 - Expect(Scan(keys, vals, &d2)).NotTo(HaveOccurred()) + Expect(Scan(&d2, keys, vals)).NotTo(HaveOccurred()) Expect(d2).To(Equal(data2{ String: "str!", Bytes: []byte("bytes!"), @@ -85,7 +83,7 @@ var _ = Describe("Scan", func() { Bool: true, })) - Expect(Scan(i{"string", "float", "bool"}, i{"", "1", "t"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(&d, i{"string", "float", "bool"}, i{"", "1", "t"})).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "", Bytes: []byte("bytes!"), @@ -99,7 +97,7 @@ var _ = Describe("Scan", func() { It("omits untagged fields", func() { var d data - Expect(Scan(i{"empty", "omit", "string"}, i{"value", "value", "str!"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(&d, i{"empty", "omit", "string"}, i{"value", "value", "str!"})).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "str!", })) @@ -108,12 +106,12 @@ var _ = Describe("Scan", func() { It("catches bad values", func() { var d data - Expect(Scan(i{"int"}, i{"a"}, &d)).To(HaveOccurred()) - Expect(Scan(i{"uint"}, i{"a"}, &d)).To(HaveOccurred()) - Expect(Scan(i{"uint"}, i{""}, &d)).To(HaveOccurred()) - Expect(Scan(i{"float"}, i{"b"}, &d)).To(HaveOccurred()) - Expect(Scan(i{"bool"}, i{"-1"}, &d)).To(HaveOccurred()) - Expect(Scan(i{"bool"}, i{""}, &d)).To(HaveOccurred()) - Expect(Scan(i{"bool"}, i{"123"}, &d)).To(HaveOccurred()) + Expect(Scan(&d, i{"int"}, i{"a"})).To(HaveOccurred()) + Expect(Scan(&d, i{"uint"}, i{"a"})).To(HaveOccurred()) + Expect(Scan(&d, i{"uint"}, i{""})).To(HaveOccurred()) + Expect(Scan(&d, i{"float"}, i{"b"})).To(HaveOccurred()) + Expect(Scan(&d, i{"bool"}, i{"-1"})).To(HaveOccurred()) + Expect(Scan(&d, i{"bool"}, i{""})).To(HaveOccurred()) + Expect(Scan(&d, i{"bool"}, i{"123"})).To(HaveOccurred()) }) }) diff --git a/internal/hscan/structmap.go b/internal/hscan/structmap.go index da97f907..c913c61b 100644 --- a/internal/hscan/structmap.go +++ b/internal/hscan/structmap.go @@ -6,17 +6,6 @@ import ( "sync" ) -// structField represents a single field in a target struct. -type structField struct { - index int - fn decoderFunc -} - -// structFields contains the list of all fields in a target struct. -type structFields struct { - m map[string]*structField -} - // structMap contains the map of struct fields for target structs // indexed by the struct type. type structMap struct { @@ -27,37 +16,38 @@ func newStructMap() *structMap { return new(structMap) } -func (s *structMap) get(t reflect.Type) *structFields { +func (s *structMap) get(t reflect.Type) *structSpec { if v, ok := s.m.Load(t); ok { - return v.(*structFields) + return v.(*structSpec) } - fMap := getStructFields(t, "redis") - s.m.Store(t, fMap) - return fMap + spec := newStructSpec(t, "redis") + s.m.Store(t, spec) + return spec } -func newStructFields() *structFields { - return &structFields{ - m: make(map[string]*structField), - } +//------------------------------------------------------------------------------ + +// structSpec contains the list of all fields in a target struct. +type structSpec struct { + m map[string]*structField } -func (s *structFields) set(tag string, sf *structField) { +func (s *structSpec) set(tag string, sf *structField) { s.m[tag] = sf } -func (s *structFields) get(tag string) (*structField, bool) { +func (s *structSpec) get(tag string) (*structField, bool) { f, ok := s.m[tag] return f, ok } -func getStructFields(t reflect.Type, fieldTag string) *structFields { - var ( - num = t.NumField() - out = newStructFields() - ) +func newStructSpec(t reflect.Type, fieldTag string) *structSpec { + out := &structSpec{ + m: make(map[string]*structField), + } + num := t.NumField() for i := 0; i < num; i++ { f := t.Field(i) @@ -77,3 +67,26 @@ func getStructFields(t reflect.Type, fieldTag string) *structFields { return out } + +//------------------------------------------------------------------------------ + +// structField represents a single field in a target struct. +type structField struct { + index int + fn decoderFunc +} + +//------------------------------------------------------------------------------ + +type StructValue struct { + spec *structSpec + value reflect.Value +} + +func (s StructValue) Scan(key string, value string) error { + field, ok := s.spec.m[key] + if !ok { + return nil + } + return field.fn(s.value.Field(field.index), value) +} From f8a546b4820de9e52fe0d244e24dca155688abef Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Wed, 3 Feb 2021 17:10:01 +0530 Subject: [PATCH 09/10] Add test for MGet/struct scan --- commands_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/commands_test.go b/commands_test.go index 19af05f0..a73f4f8b 100644 --- a/commands_test.go +++ b/commands_test.go @@ -1134,6 +1134,22 @@ var _ = Describe("Commands", func() { Expect(mGet.Val()).To(Equal([]interface{}{"hello1", "hello2", nil})) }) + It("should scan Mget", func() { + err := client.MSet(ctx, "key1", "hello1", "key2", 123).Err() + Expect(err).NotTo(HaveOccurred()) + + res := client.MGet(ctx, "key1", "key2", "_") + Expect(res.Err()).NotTo(HaveOccurred()) + + type data struct { + Key1 string `redis:"key1"` + Key2 int `redis:"key2"` + } + var d data + Expect(res.Scan(&d)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{Key1: "hello1", Key2: 123})) + }) + It("should MSetNX", func() { mSetNX := client.MSetNX(ctx, "key1", "hello1", "key2", "hello2") Expect(mSetNX.Err()).NotTo(HaveOccurred()) From b358584bd3adbe34d03c2087ead5c01a31ea97cd Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Thu, 4 Feb 2021 09:34:10 +0200 Subject: [PATCH 10/10] Fix build --- internal/hscan/structmap.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/internal/hscan/structmap.go b/internal/hscan/structmap.go index c913c61b..37d86ba5 100644 --- a/internal/hscan/structmap.go +++ b/internal/hscan/structmap.go @@ -37,11 +37,6 @@ func (s *structSpec) set(tag string, sf *structField) { s.m[tag] = sf } -func (s *structSpec) get(tag string) (*structField, bool) { - f, ok := s.m[tag] - return f, ok -} - func newStructSpec(t reflect.Type, fieldTag string) *structSpec { out := &structSpec{ m: make(map[string]*structField),