diff --git a/command.go b/command.go index 73194ecc..2932035e 100644 --- a/command.go +++ b/command.go @@ -374,7 +374,7 @@ func (cmd *SliceCmd) String() string { // Scan scans the results from the map into a destination struct. The map keys // are matched in the Redis struct fields by the `redis:"field"` tag. -func (cmd *SliceCmd) Scan(dest interface{}) error { +func (cmd *SliceCmd) Scan(dst interface{}) error { if cmd.err != nil { return cmd.err } @@ -389,7 +389,7 @@ func (cmd *SliceCmd) Scan(dest interface{}) error { args = cmd.args[1:] } - return hscan.Scan(args, cmd.val, dest) + return hscan.Scan(dst, args, cmd.val) } func (cmd *SliceCmd) readReply(rd *proto.Reader) error { @@ -940,23 +940,23 @@ func (cmd *StringStringMapCmd) String() string { // Scan scans the results from the map into a destination struct. The map keys // are matched in the Redis struct fields by the `redis:"field"` tag. -func (cmd *StringStringMapCmd) Scan(dest interface{}) error { +func (cmd *StringStringMapCmd) Scan(dst interface{}) error { if cmd.err != nil { return cmd.err } - // Pass the list of keys and values. Skip the first to args (command, key), - // eg: HGETALL map. - var ( - keys = make([]interface{}, 0, len(cmd.val)) - vals = make([]interface{}, 0, len(cmd.val)) - ) - for k, v := range cmd.val { - keys = append(keys, k) - vals = append(vals, v) + strct, err := hscan.Struct(dst) + if err != nil { + return err } - return hscan.Scan(keys, vals, dest) + for k, v := range cmd.val { + if err := strct.Scan(k, v); err != nil { + return err + } + } + + return nil } func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { diff --git a/internal/hscan/hscan.go b/internal/hscan/hscan.go index a4b9e80e..181260b8 100644 --- a/internal/hscan/hscan.go +++ b/internal/hscan/hscan.go @@ -43,32 +43,39 @@ var ( // Global map of struct field specs that is populated once for every new // struct type that is scanned. This caches the field types and the corresponding // decoder functions to avoid iterating through struct fields on subsequent scans. - structSpecs = newStructMap() + globalStructMap = newStructMap() ) +func Struct(dst interface{}) (StructValue, error) { + v := reflect.ValueOf(dst) + + // The dstination to scan into should be a struct pointer. + if v.Kind() != reflect.Ptr || v.IsNil() { + return StructValue{}, fmt.Errorf("redis.Scan(non-pointer %T)", dst) + } + + v = v.Elem() + if v.Kind() != reflect.Struct { + return StructValue{}, fmt.Errorf("redis.Scan(non-struct %T)", dst) + } + + return StructValue{ + spec: globalStructMap.get(v.Type()), + value: v, + }, nil +} + // Scan scans the results from a key-value Redis map result set to a destination struct. // The Redis keys are matched to the struct's field with the `redis` tag. -func Scan(keys []interface{}, vals []interface{}, dest interface{}) error { +func Scan(dst interface{}, keys []interface{}, vals []interface{}) error { if len(keys) != len(vals) { return errors.New("args should have the same number of keys and vals") } - // The destination to scan into should be a struct pointer. - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr || v.IsNil() { - return fmt.Errorf("redis.Scan(non-pointer %T)", dest) + strct, err := Struct(dst) + if err != nil { + return err } - v = v.Elem() - - if v.Kind() != reflect.Struct { - return fmt.Errorf("redis.Scan(non-struct %T)", dest) - } - - // If the struct field spec is not cached, build and cache it to avoid - // iterating through the fields of a struct type every time values are - // scanned into it. - typ := v.Type() - fMap := structSpecs.get(typ) // Iterate through the (key, value) sequence. for i := 0; i < len(vals); i++ { @@ -82,13 +89,7 @@ func Scan(keys []interface{}, vals []interface{}, dest interface{}) error { continue } - // Check if the field name is in the field spec map. - field, ok := fMap.get(key) - if !ok { - continue - } - - if err := field.fn(v.Field(field.index), val); err != nil { + if err := strct.Scan(key, val); err != nil { return err } } diff --git a/internal/hscan/hscan_test.go b/internal/hscan/hscan_test.go index e14b4e91..f7a88f0f 100644 --- a/internal/hscan/hscan_test.go +++ b/internal/hscan/hscan_test.go @@ -28,33 +28,31 @@ func TestGinkgoSuite(t *testing.T) { var _ = Describe("Scan", func() { It("catches bad args", func() { - var ( - d data - ) + var d data - Expect(Scan(i{}, i{}, &d)).NotTo(HaveOccurred()) + Expect(Scan(&d, i{}, i{})).NotTo(HaveOccurred()) Expect(d).To(Equal(data{})) - Expect(Scan(i{"key"}, i{}, &d)).To(HaveOccurred()) - Expect(Scan(i{"key"}, i{"1", "2"}, &d)).To(HaveOccurred()) - Expect(Scan(i{"key", "1"}, i{}, nil)).To(HaveOccurred()) + Expect(Scan(&d, i{"key"}, i{})).To(HaveOccurred()) + Expect(Scan(&d, i{"key"}, i{"1", "2"})).To(HaveOccurred()) + Expect(Scan(nil, i{"key", "1"}, i{})).To(HaveOccurred()) var m map[string]interface{} - Expect(Scan(i{"key"}, i{"1"}, &m)).To(HaveOccurred()) - Expect(Scan(i{"key"}, i{"1"}, data{})).To(HaveOccurred()) - Expect(Scan(i{"key", "string"}, i{nil, nil}, data{})).To(HaveOccurred()) + Expect(Scan(&m, i{"key"}, i{"1"})).To(HaveOccurred()) + Expect(Scan(data{}, i{"key"}, i{"1"})).To(HaveOccurred()) + Expect(Scan(data{}, i{"key", "string"}, i{nil, nil})).To(HaveOccurred()) }) It("scans good values", func() { var d data // non-tagged fields. - Expect(Scan(i{"key"}, i{"value"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(&d, i{"key"}, i{"value"})).NotTo(HaveOccurred()) Expect(d).To(Equal(data{})) keys := i{"string", "byte", "int", "uint", "float", "bool"} vals := i{"str!", "bytes!", "123", "456", "123.456", "1"} - Expect(Scan(keys, vals, &d)).NotTo(HaveOccurred()) + Expect(Scan(&d, keys, vals)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "str!", Bytes: []byte("bytes!"), @@ -75,7 +73,7 @@ var _ = Describe("Scan", func() { Bool bool `redis:"bool"` } var d2 data2 - Expect(Scan(keys, vals, &d2)).NotTo(HaveOccurred()) + Expect(Scan(&d2, keys, vals)).NotTo(HaveOccurred()) Expect(d2).To(Equal(data2{ String: "str!", Bytes: []byte("bytes!"), @@ -85,7 +83,7 @@ var _ = Describe("Scan", func() { Bool: true, })) - Expect(Scan(i{"string", "float", "bool"}, i{"", "1", "t"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(&d, i{"string", "float", "bool"}, i{"", "1", "t"})).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "", Bytes: []byte("bytes!"), @@ -99,7 +97,7 @@ var _ = Describe("Scan", func() { It("omits untagged fields", func() { var d data - Expect(Scan(i{"empty", "omit", "string"}, i{"value", "value", "str!"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(&d, i{"empty", "omit", "string"}, i{"value", "value", "str!"})).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "str!", })) @@ -108,12 +106,12 @@ var _ = Describe("Scan", func() { It("catches bad values", func() { var d data - Expect(Scan(i{"int"}, i{"a"}, &d)).To(HaveOccurred()) - Expect(Scan(i{"uint"}, i{"a"}, &d)).To(HaveOccurred()) - Expect(Scan(i{"uint"}, i{""}, &d)).To(HaveOccurred()) - Expect(Scan(i{"float"}, i{"b"}, &d)).To(HaveOccurred()) - Expect(Scan(i{"bool"}, i{"-1"}, &d)).To(HaveOccurred()) - Expect(Scan(i{"bool"}, i{""}, &d)).To(HaveOccurred()) - Expect(Scan(i{"bool"}, i{"123"}, &d)).To(HaveOccurred()) + Expect(Scan(&d, i{"int"}, i{"a"})).To(HaveOccurred()) + Expect(Scan(&d, i{"uint"}, i{"a"})).To(HaveOccurred()) + Expect(Scan(&d, i{"uint"}, i{""})).To(HaveOccurred()) + Expect(Scan(&d, i{"float"}, i{"b"})).To(HaveOccurred()) + Expect(Scan(&d, i{"bool"}, i{"-1"})).To(HaveOccurred()) + Expect(Scan(&d, i{"bool"}, i{""})).To(HaveOccurred()) + Expect(Scan(&d, i{"bool"}, i{"123"})).To(HaveOccurred()) }) }) diff --git a/internal/hscan/structmap.go b/internal/hscan/structmap.go index da97f907..c913c61b 100644 --- a/internal/hscan/structmap.go +++ b/internal/hscan/structmap.go @@ -6,17 +6,6 @@ import ( "sync" ) -// structField represents a single field in a target struct. -type structField struct { - index int - fn decoderFunc -} - -// structFields contains the list of all fields in a target struct. -type structFields struct { - m map[string]*structField -} - // structMap contains the map of struct fields for target structs // indexed by the struct type. type structMap struct { @@ -27,37 +16,38 @@ func newStructMap() *structMap { return new(structMap) } -func (s *structMap) get(t reflect.Type) *structFields { +func (s *structMap) get(t reflect.Type) *structSpec { if v, ok := s.m.Load(t); ok { - return v.(*structFields) + return v.(*structSpec) } - fMap := getStructFields(t, "redis") - s.m.Store(t, fMap) - return fMap + spec := newStructSpec(t, "redis") + s.m.Store(t, spec) + return spec } -func newStructFields() *structFields { - return &structFields{ - m: make(map[string]*structField), - } +//------------------------------------------------------------------------------ + +// structSpec contains the list of all fields in a target struct. +type structSpec struct { + m map[string]*structField } -func (s *structFields) set(tag string, sf *structField) { +func (s *structSpec) set(tag string, sf *structField) { s.m[tag] = sf } -func (s *structFields) get(tag string) (*structField, bool) { +func (s *structSpec) get(tag string) (*structField, bool) { f, ok := s.m[tag] return f, ok } -func getStructFields(t reflect.Type, fieldTag string) *structFields { - var ( - num = t.NumField() - out = newStructFields() - ) +func newStructSpec(t reflect.Type, fieldTag string) *structSpec { + out := &structSpec{ + m: make(map[string]*structField), + } + num := t.NumField() for i := 0; i < num; i++ { f := t.Field(i) @@ -77,3 +67,26 @@ func getStructFields(t reflect.Type, fieldTag string) *structFields { return out } + +//------------------------------------------------------------------------------ + +// structField represents a single field in a target struct. +type structField struct { + index int + fn decoderFunc +} + +//------------------------------------------------------------------------------ + +type StructValue struct { + spec *structSpec + value reflect.Value +} + +func (s StructValue) Scan(key string, value string) error { + field, ok := s.spec.m[key] + if !ok { + return nil + } + return field.fn(s.value.Field(field.index), value) +}