diff --git a/command.go b/command.go index 602ff61b..642352c9 100644 --- a/command.go +++ b/command.go @@ -372,11 +372,12 @@ func (cmd *SliceCmd) String() string { return cmdString(cmd, cmd.val) } -// Scan scans the results from a key-value Redis map result set ([]interface{}) -// like HMGET and HGETALL to a destination struct. -// The Redis keys are matched to the struct's field with the `redis` tag. -func (cmd *SliceCmd) Scan(val interface{}) error { - return hscan.Scan(cmd.val, val) +// Scan scans the results from the map into a destination struct. The map keys +// are matched in the Redis struct fields by the `redis:"field"` tag. +func (cmd *SliceCmd) Scan(dest interface{}) error { + // Pass the list of keys and values. Skip the first to args (command, key), + // eg: HMGET map. + return hscan.Scan(cmd.args[2:], cmd.val, dest) } func (cmd *SliceCmd) readReply(rd *proto.Reader) error { @@ -925,6 +926,23 @@ func (cmd *StringStringMapCmd) String() string { return cmdString(cmd, cmd.val) } +// Scan scans the results from the map into a destination struct. The map keys +// are matched in the Redis struct fields by the `redis:"field"` tag. +func (cmd *StringStringMapCmd) Scan(dest interface{}) error { + // Pass the list of keys and values. Skip the first to args (command, key), + // eg: HGETALL map. + var ( + keys = make([]interface{}, 0, len(cmd.val)) + vals = make([]interface{}, 0, len(cmd.val)) + ) + for k, v := range cmd.val { + keys = append(keys, k) + vals = append(vals, v) + } + + return hscan.Scan(keys, vals, dest) +} + func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make(map[string]string, n/2) diff --git a/commands_test.go b/commands_test.go index 707aea1b..19af05f0 100644 --- a/commands_test.go +++ b/commands_test.go @@ -1375,6 +1375,22 @@ var _ = Describe("Commands", func() { Expect(m).To(Equal(map[string]string{"key1": "hello1", "key2": "hello2"})) }) + It("should scan", func() { + err := client.HMSet(ctx, "hash", "key1", "hello1", "key2", 123).Err() + Expect(err).NotTo(HaveOccurred()) + + res := client.HGetAll(ctx, "hash") + Expect(res.Err()).NotTo(HaveOccurred()) + + type data struct { + Key1 string `redis:"key1"` + Key2 int `redis:"key2"` + } + var d data + Expect(res.Scan(&d)).NotTo(HaveOccurred()) + Expect(d).To(Equal(data{Key1: "hello1", Key2: 123})) + }) + It("should HIncrBy", func() { hSet := client.HSet(ctx, "hash", "key", "5") Expect(hSet.Err()).NotTo(HaveOccurred()) diff --git a/internal/hscan/hscan.go b/internal/hscan/hscan.go index 51346b83..a4b9e80e 100644 --- a/internal/hscan/hscan.go +++ b/internal/hscan/hscan.go @@ -46,12 +46,11 @@ var ( structSpecs = newStructMap() ) -// Scan scans the results from a key-value Redis map result set ([]interface{}) -// to a destination struct. The Redis keys are matched to the struct's field -// with the `redis` tag. -func Scan(vals []interface{}, dest interface{}) error { - if len(vals)%2 != 0 { - return errors.New("args should have an even number of items (key-val)") +// Scan scans the results from a key-value Redis map result set to a destination struct. +// The Redis keys are matched to the struct's field with the `redis` tag. +func Scan(keys []interface{}, vals []interface{}, dest interface{}) error { + if len(keys) != len(vals) { + return errors.New("args should have the same number of keys and vals") } // The destination to scan into should be a struct pointer. @@ -72,13 +71,13 @@ func Scan(vals []interface{}, dest interface{}) error { fMap := structSpecs.get(typ) // Iterate through the (key, value) sequence. - for i := 0; i < len(vals); i += 2 { - key, ok := vals[i].(string) + for i := 0; i < len(vals); i++ { + key, ok := keys[i].(string) if !ok { continue } - val, ok := vals[i+1].(string) + val, ok := vals[i].(string) if !ok { continue } diff --git a/internal/hscan/hscan_test.go b/internal/hscan/hscan_test.go index 7a4d0f5b..e14b4e91 100644 --- a/internal/hscan/hscan_test.go +++ b/internal/hscan/hscan_test.go @@ -19,6 +19,8 @@ type data struct { Bool bool `redis:"bool"` } +type i []interface{} + func TestGinkgoSuite(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "hscan") @@ -26,35 +28,33 @@ func TestGinkgoSuite(t *testing.T) { var _ = Describe("Scan", func() { It("catches bad args", func() { - var d data + var ( + d data + ) - Expect(Scan([]interface{}{}, &d)).NotTo(HaveOccurred()) + Expect(Scan(i{}, i{}, &d)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{})) - Expect(Scan([]interface{}{"key"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"key", "1", "2"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"key", "1"}, nil)).To(HaveOccurred()) + Expect(Scan(i{"key"}, i{}, &d)).To(HaveOccurred()) + Expect(Scan(i{"key"}, i{"1", "2"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"key", "1"}, i{}, nil)).To(HaveOccurred()) - var i map[string]interface{} - Expect(Scan([]interface{}{"key", "1"}, &i)).To(HaveOccurred()) - Expect(Scan([]interface{}{"key", "1"}, data{})).To(HaveOccurred()) - Expect(Scan([]interface{}{"key", nil, "string", nil}, data{})).To(HaveOccurred()) + var m map[string]interface{} + Expect(Scan(i{"key"}, i{"1"}, &m)).To(HaveOccurred()) + Expect(Scan(i{"key"}, i{"1"}, data{})).To(HaveOccurred()) + Expect(Scan(i{"key", "string"}, i{nil, nil}, data{})).To(HaveOccurred()) }) It("scans good values", func() { var d data // non-tagged fields. - Expect(Scan([]interface{}{"key", "value"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(i{"key"}, i{"value"}, &d)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{})) - res := []interface{}{"string", "str!", - "byte", "bytes!", - "int", "123", - "uint", "456", - "float", "123.456", - "bool", "1"} - Expect(Scan(res, &d)).NotTo(HaveOccurred()) + keys := i{"string", "byte", "int", "uint", "float", "bool"} + vals := i{"str!", "bytes!", "123", "456", "123.456", "1"} + Expect(Scan(keys, vals, &d)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "str!", Bytes: []byte("bytes!"), @@ -75,7 +75,7 @@ var _ = Describe("Scan", func() { Bool bool `redis:"bool"` } var d2 data2 - Expect(Scan(res, &d2)).NotTo(HaveOccurred()) + Expect(Scan(keys, vals, &d2)).NotTo(HaveOccurred()) Expect(d2).To(Equal(data2{ String: "str!", Bytes: []byte("bytes!"), @@ -85,10 +85,7 @@ var _ = Describe("Scan", func() { Bool: true, })) - Expect(Scan([]interface{}{ - "string", "", - "float", "1", - "bool", "t"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(i{"string", "float", "bool"}, i{"", "1", "t"}, &d)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "", Bytes: []byte("bytes!"), @@ -102,10 +99,7 @@ var _ = Describe("Scan", func() { It("omits untagged fields", func() { var d data - Expect(Scan([]interface{}{ - "empty", "value", - "omit", "value", - "string", "str!"}, &d)).NotTo(HaveOccurred()) + Expect(Scan(i{"empty", "omit", "string"}, i{"value", "value", "str!"}, &d)).NotTo(HaveOccurred()) Expect(d).To(Equal(data{ String: "str!", })) @@ -114,12 +108,12 @@ var _ = Describe("Scan", func() { It("catches bad values", func() { var d data - Expect(Scan([]interface{}{"int", "a"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"uint", "a"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"uint", ""}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"float", "b"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"bool", "-1"}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"bool", ""}, &d)).To(HaveOccurred()) - Expect(Scan([]interface{}{"bool", "123"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"int"}, i{"a"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"uint"}, i{"a"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"uint"}, i{""}, &d)).To(HaveOccurred()) + Expect(Scan(i{"float"}, i{"b"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"bool"}, i{"-1"}, &d)).To(HaveOccurred()) + Expect(Scan(i{"bool"}, i{""}, &d)).To(HaveOccurred()) + Expect(Scan(i{"bool"}, i{"123"}, &d)).To(HaveOccurred()) }) })