diff --git a/command.go b/command.go index d7c76cf..68c438c 100644 --- a/command.go +++ b/command.go @@ -21,6 +21,7 @@ var ( _ Cmder = (*StringSliceCmd)(nil) _ Cmder = (*BoolSliceCmd)(nil) _ Cmder = (*StringStringMapCmd)(nil) + _ Cmder = (*StringIntMapCmd)(nil) _ Cmder = (*ZSliceCmd)(nil) _ Cmder = (*ScanCmd)(nil) ) @@ -514,6 +515,42 @@ func (cmd *StringStringMapCmd) parseReply(rd *bufio.Reader) error { //------------------------------------------------------------------------------ +type StringIntMapCmd struct { + *baseCmd + + val map[string]int64 +} + +func NewStringIntMapCmd(args ...string) *StringIntMapCmd { + return &StringIntMapCmd{ + baseCmd: newBaseCmd(args...), + } +} + +func (cmd *StringIntMapCmd) Val() map[string]int64 { + return cmd.val +} + +func (cmd *StringIntMapCmd) Result() (map[string]int64, error) { + return cmd.val, cmd.err +} + +func (cmd *StringIntMapCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *StringIntMapCmd) parseReply(rd *bufio.Reader) error { + v, err := parseReply(rd, parseStringIntMap) + if err != nil { + cmd.err = err + return err + } + cmd.val = v.(map[string]int64) + return nil +} + +//------------------------------------------------------------------------------ + type ZSliceCmd struct { *baseCmd diff --git a/commands.go b/commands.go index 6068bab..796b04a 100644 --- a/commands.go +++ b/commands.go @@ -1231,10 +1231,10 @@ func (c *Client) PubSubChannels(pattern string) *StringSliceCmd { return cmd } -func (c *Client) PubSubNumSub(channels ...string) *SliceCmd { +func (c *Client) PubSubNumSub(channels ...string) *StringIntMapCmd { args := []string{"PUBSUB", "NUMSUB"} args = append(args, channels...) - cmd := NewSliceCmd(args...) + cmd := NewStringIntMapCmd(args...) c.Process(cmd) return cmd } diff --git a/parser.go b/parser.go index b4c380c..b5e30a4 100644 --- a/parser.go +++ b/parser.go @@ -229,6 +229,38 @@ func parseStringStringMap(rd *bufio.Reader, n int64) (interface{}, error) { return m, nil } +func parseStringIntMap(rd *bufio.Reader, n int64) (interface{}, error) { + m := make(map[string]int64, n/2) + for i := int64(0); i < n; i += 2 { + keyiface, err := parseReply(rd, nil) + if err != nil { + return nil, err + } + key, ok := keyiface.(string) + if !ok { + return nil, fmt.Errorf("got %T, expected string", keyiface) + } + + valueiface, err := parseReply(rd, nil) + if err != nil { + return nil, err + } + switch value := valueiface.(type) { + case int64: + m[key] = value + case string: + m[key], err = strconv.ParseInt(value, 10, 64) + if err != nil { + return nil, fmt.Errorf("got %v, expected number", value) + } + default: + return nil, fmt.Errorf("got %T, expected number or string", valueiface) + + } + } + return m, nil +} + func parseZSlice(rd *bufio.Reader, n int64) (interface{}, error) { zz := make([]Z, n/2) for i := int64(0); i < n; i += 2 { diff --git a/pubsub_test.go b/pubsub_test.go index 76e3fa9..d67f745 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -101,10 +101,10 @@ var _ = Describe("PubSub", func() { channels, err := client.PubSubNumSub("mychannel", "mychannel2", "mychannel3").Result() Expect(err).NotTo(HaveOccurred()) - Expect(channels).To(Equal([]interface{}{ - "mychannel", int64(1), - "mychannel2", int64(1), - "mychannel3", int64(0), + Expect(channels).To(Equal(map[string]int64{ + "mychannel": 1, + "mychannel2": 1, + "mychannel3": 0, })) })