From 8ad01240a494fe590a414df0c50fdb612af127e4 Mon Sep 17 00:00:00 2001 From: monkey92t Date: Tue, 27 Apr 2021 15:04:46 +0800 Subject: [PATCH] Add support for resp3 protocol (#1739) * support resp3 protocol Signed-off-by: monkey * Upgrade mod version limit go1.14 https://github.com/go-redis/redis/issues/1715#issuecomment-820685614 Signed-off-by: monkey * Remove the redundant check of ReadReply Signed-off-by: monkey * fix the problem Signed-off-by: monkey * workflows add v9 Signed-off-by: monkey * update StringStringMapCmd to MapStringStringCmd Signed-off-by: monkey --- .github/workflows/build.yml | 2 +- bench_decode_test.go | 18 +- cluster.go | 7 +- cluster_test.go | 11 +- command.go | 1616 +++++++++++++++++---------------- commands.go | 79 +- commands_test.go | 44 +- example_test.go | 6 +- go.mod | 2 +- internal/proto/reader.go | 549 +++++++---- internal/proto/reader_test.go | 66 +- internal/proto/writer.go | 4 +- pubsub_test.go | 30 - redis.go | 26 +- result.go | 4 +- ring_test.go | 3 +- sentinel.go | 67 +- sentinel_test.go | 4 +- 18 files changed, 1402 insertions(+), 1136 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f58a11c..e6858a3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,7 +4,7 @@ on: push: branches: [master] pull_request: - branches: [master] + branches: [master, v9] jobs: build: diff --git a/bench_decode_test.go b/bench_decode_test.go index 8382806..b07ad4e 100644 --- a/bench_decode_test.go +++ b/bench_decode_test.go @@ -18,14 +18,17 @@ type ClientStub struct { resp []byte } +var initHello = []byte("%1\r\n+proto\r\n:3\r\n") + func NewClientStub(resp []byte) *ClientStub { stub := &ClientStub{ resp: resp, } + stub.Cmdable = NewClient(&Options{ PoolSize: 128, Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { - return stub.stubConn(), nil + return stub.stubConn(initHello), nil }, }) return stub @@ -40,7 +43,7 @@ func NewClusterClientStub(resp []byte) *ClientStub { PoolSize: 128, Addrs: []string{"127.0.0.1:6379"}, Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { - return stub.stubConn(), nil + return stub.stubConn(initHello), nil }, ClusterSlots: func(_ context.Context) ([]ClusterSlot, error) { return []ClusterSlot{ @@ -65,18 +68,27 @@ func NewClusterClientStub(resp []byte) *ClientStub { return stub } -func (c *ClientStub) stubConn() *ConnStub { +func (c *ClientStub) stubConn(init []byte) *ConnStub { return &ConnStub{ + init: init, resp: c.resp, } } type ConnStub struct { + init []byte resp []byte pos int } func (c *ConnStub) Read(b []byte) (n int, err error) { + // Return conn.init() + if len(c.init) > 0 { + n = copy(b, c.init) + c.init = c.init[n:] + return n, nil + } + if len(c.resp) == 0 { return 0, io.EOF } diff --git a/cluster.go b/cluster.go index e5d49dd..738d50a 100644 --- a/cluster.go +++ b/cluster.go @@ -1392,12 +1392,7 @@ func (c *ClusterClient) txPipelineReadQueued( return err } - switch line[0] { - case proto.ErrorReply: - return proto.ParseErrorReply(line) - case proto.ArrayReply: - // ok - default: + if line[0] != proto.RespArray { return fmt.Errorf("redis: expected '*', but got line %q", line) } diff --git a/cluster_test.go b/cluster_test.go index 3880d43..4c4e4d3 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1182,16 +1182,17 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() { var client *redis.ClusterClient BeforeEach(func() { - for _, node := range cluster.clients { - err := node.ClientPause(ctx, 5*time.Second).Err() - Expect(err).NotTo(HaveOccurred()) - } - opt := redisClusterOptions() opt.ReadTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond opt.MaxRedirects = 1 client = cluster.newClusterClientUnstable(opt) + Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred()) + + for _, node := range cluster.clients { + err := node.ClientPause(ctx, 5*time.Second).Err() + Expect(err).NotTo(HaveOccurred()) + } }) AfterEach(func() { diff --git a/command.go b/command.go index f10c478..28fd8c1 100644 --- a/command.go +++ b/command.go @@ -316,31 +316,10 @@ func (cmd *Cmd) Bool() (bool, error) { } func (cmd *Cmd) readReply(rd *proto.Reader) (err error) { - cmd.val, err = rd.ReadReply(sliceParser) + cmd.val, err = rd.ReadReply() return err } -// sliceParser implements proto.MultiBulkParse. -func sliceParser(rd *proto.Reader, n int64) (interface{}, error) { - vals := make([]interface{}, n) - for i := 0; i < len(vals); i++ { - v, err := rd.ReadReply(sliceParser) - if err != nil { - if err == Nil { - vals[i] = nil - continue - } - if err, ok := err.(proto.RedisError); ok { - vals[i] = err - continue - } - return nil, err - } - vals[i] = v - } - return vals, nil -} - //------------------------------------------------------------------------------ type SliceCmd struct { @@ -392,13 +371,9 @@ func (cmd *SliceCmd) Scan(dst interface{}) error { return hscan.Scan(dst, args, cmd.val) } -func (cmd *SliceCmd) readReply(rd *proto.Reader) error { - v, err := rd.ReadArrayReply(sliceParser) - if err != nil { - return err - } - cmd.val = v.([]interface{}) - return nil +func (cmd *SliceCmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadSlice() + return err } //------------------------------------------------------------------------------ @@ -473,7 +448,7 @@ func (cmd *IntCmd) String() string { } func (cmd *IntCmd) readReply(rd *proto.Reader) (err error) { - cmd.val, err = rd.ReadIntReply() + cmd.val, err = rd.ReadInt() return err } @@ -509,18 +484,17 @@ func (cmd *IntSliceCmd) String() string { } func (cmd *IntSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]int64, n) - for i := 0; i < len(cmd.val); i++ { - num, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.val[i] = num + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]int64, n) + for i := 0; i < len(cmd.val); i++ { + if cmd.val[i], err = rd.ReadInt(); err != nil { + return err } - return nil, nil - }) - return err + } + return nil } //------------------------------------------------------------------------------ @@ -557,7 +531,7 @@ func (cmd *DurationCmd) String() string { } func (cmd *DurationCmd) readReply(rd *proto.Reader) error { - n, err := rd.ReadIntReply() + n, err := rd.ReadInt() if err != nil { return err } @@ -604,25 +578,19 @@ func (cmd *TimeCmd) String() string { } func (cmd *TimeCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 2 { - return nil, fmt.Errorf("got %d elements, expected 2", n) - } - - sec, err := rd.ReadInt() - if err != nil { - return nil, err - } - - microsec, err := rd.ReadInt() - if err != nil { - return nil, err - } - - cmd.val = time.Unix(sec, microsec*1000) - return nil, nil - }) - return err + if err := rd.ReadFixedArrayLen(2); err != nil { + return err + } + second, err := rd.ReadInt() + if err != nil { + return err + } + microsecond, err := rd.ReadInt() + if err != nil { + return err + } + cmd.val = time.Unix(second, microsecond*1000) + return nil } //------------------------------------------------------------------------------ @@ -656,27 +624,16 @@ func (cmd *BoolCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *BoolCmd) readReply(rd *proto.Reader) error { - v, err := rd.ReadReply(nil) +func (cmd *BoolCmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadBool() + // `SET key value NX` returns nil when key already exists. But // `SETNX key value` returns bool (0/1). So convert nil to bool. if err == Nil { cmd.val = false - return nil - } - if err != nil { - return err - } - switch v := v.(type) { - case int64: - cmd.val = v == 1 - return nil - case string: - cmd.val = v == "OK" - return nil - default: - return fmt.Errorf("got %T, wanted int64 or string", v) + err = nil } + return err } //------------------------------------------------------------------------------ @@ -811,7 +768,7 @@ func (cmd *FloatCmd) String() string { } func (cmd *FloatCmd) readReply(rd *proto.Reader) (err error) { - cmd.val, err = rd.ReadFloatReply() + cmd.val, err = rd.ReadFloat() return err } @@ -847,21 +804,23 @@ func (cmd *FloatSliceCmd) String() string { } func (cmd *FloatSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]float64, n) - for i := 0; i < len(cmd.val); i++ { - switch num, err := rd.ReadFloatReply(); { - case err == Nil: - cmd.val[i] = 0 - case err != nil: - return nil, err - default: - cmd.val[i] = num - } + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + cmd.val = make([]float64, n) + for i := 0; i < len(cmd.val); i++ { + switch num, err := rd.ReadFloat(); { + case err == Nil: + cmd.val[i] = 0 + case err != nil: + return err + default: + cmd.val[i] = num } - return nil, nil - }) - return err + } + return nil } //------------------------------------------------------------------------------ @@ -900,21 +859,111 @@ func (cmd *StringSliceCmd) ScanSlice(container interface{}) error { } func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]string, n) - for i := 0; i < len(cmd.val); i++ { - switch s, err := rd.ReadString(); { - case err == Nil: - cmd.val[i] = "" - case err != nil: - return nil, err - default: - cmd.val[i] = s + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]string, n) + for i := 0; i < len(cmd.val); i++ { + switch s, err := rd.ReadString(); { + case err == Nil: + cmd.val[i] = "" + case err != nil: + return err + default: + cmd.val[i] = s + } + } + return nil +} + +//------------------------------------------------------------------------------ + +type KeyValue struct { + Key string + Value string +} + +type KeyValueSliceCmd struct { + baseCmd + + val []KeyValue +} + +var _ Cmder = (*KeyValueSliceCmd)(nil) + +func NewKeyValueSliceCmd(ctx context.Context, args ...interface{}) *KeyValueSliceCmd { + return &KeyValueSliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *KeyValueSliceCmd) Val() []KeyValue { + return cmd.val +} + +func (cmd *KeyValueSliceCmd) Result() ([]KeyValue, error) { + return cmd.val, cmd.err +} + +func (cmd *KeyValueSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +// Many commands will respond to two formats: +// 1) 1) "one" +// 2) (double) 1 +// 2) 1) "two" +// 2) (double) 2 +// OR: +// 1) "two" +// 2) (double) 2 +// 3) "one" +// 4) (double) 1 +func (cmd *KeyValueSliceCmd) readReply(rd *proto.Reader) error { // nolint:dupl + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + // If the n is 0, can't continue reading. + if n == 0 { + cmd.val = make([]KeyValue, 0) + return nil + } + + typ, err := rd.PeekReplyType() + if err != nil { + return err + } + array := typ == proto.RespArray + + if array { + cmd.val = make([]KeyValue, n) + } else { + cmd.val = make([]KeyValue, n/2) + } + + for i := 0; i < len(cmd.val); i++ { + if array { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err } } - return nil, nil - }) - return err + + if cmd.val[i].Key, err = rd.ReadString(); err != nil { + return err + } + + if cmd.val[i].Value, err = rd.ReadString(); err != nil { + return err + } + } + + return nil } //------------------------------------------------------------------------------ @@ -949,32 +998,31 @@ func (cmd *BoolSliceCmd) String() string { } func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]bool, n) - for i := 0; i < len(cmd.val); i++ { - n, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.val[i] = n == 1 + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]bool, n) + for i := 0; i < len(cmd.val); i++ { + if cmd.val[i], err = rd.ReadBool(); err != nil { + return err } - return nil, nil - }) - return err + } + return nil } //------------------------------------------------------------------------------ -type StringStringMapCmd struct { +type MapStringStringCmd struct { baseCmd val map[string]string } -var _ Cmder = (*StringStringMapCmd)(nil) +var _ Cmder = (*MapStringStringCmd)(nil) -func NewStringStringMapCmd(ctx context.Context, args ...interface{}) *StringStringMapCmd { - return &StringStringMapCmd{ +func NewMapStringStringCmd(ctx context.Context, args ...interface{}) *MapStringStringCmd { + return &MapStringStringCmd{ baseCmd: baseCmd{ ctx: ctx, args: args, @@ -982,21 +1030,21 @@ func NewStringStringMapCmd(ctx context.Context, args ...interface{}) *StringStri } } -func (cmd *StringStringMapCmd) Val() map[string]string { +func (cmd *MapStringStringCmd) Val() map[string]string { return cmd.val } -func (cmd *StringStringMapCmd) Result() (map[string]string, error) { +func (cmd *MapStringStringCmd) Result() (map[string]string, error) { return cmd.val, cmd.err } -func (cmd *StringStringMapCmd) String() string { +func (cmd *MapStringStringCmd) String() string { return cmdString(cmd, cmd.val) } // Scan scans the results from the map into a destination struct. The map keys // are matched in the Redis struct fields by the `redis:"field"` tag. -func (cmd *StringStringMapCmd) Scan(dst interface{}) error { +func (cmd *MapStringStringCmd) Scan(dst interface{}) error { if cmd.err != nil { return cmd.err } @@ -1015,25 +1063,27 @@ func (cmd *StringStringMapCmd) Scan(dst interface{}) error { return nil } -func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make(map[string]string, n/2) - for i := int64(0); i < n; i += 2 { - key, err := rd.ReadString() - if err != nil { - return nil, err - } +func (cmd *MapStringStringCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadMapLen() + if err != nil { + return err + } - value, err := rd.ReadString() - if err != nil { - return nil, err - } - - cmd.val[key] = value + cmd.val = make(map[string]string, n) + for i := 0; i < n; i++ { + key, err := rd.ReadString() + if err != nil { + return err } - return nil, nil - }) - return err + + value, err := rd.ReadString() + if err != nil { + return err + } + + cmd.val[key] = value + } + return nil } //------------------------------------------------------------------------------ @@ -1068,24 +1118,25 @@ func (cmd *StringIntMapCmd) String() string { } func (cmd *StringIntMapCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make(map[string]int64, n/2) - for i := int64(0); i < n; i += 2 { - key, err := rd.ReadString() - if err != nil { - return nil, err - } + n, err := rd.ReadMapLen() + if err != nil { + return err + } - n, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - - cmd.val[key] = n + cmd.val = make(map[string]int64, n) + for i := 0; i < n; i++ { + key, err := rd.ReadString() + if err != nil { + return err } - return nil, nil - }) - return err + + nn, err := rd.ReadInt() + if err != nil { + return err + } + cmd.val[key] = nn + } + return nil } //------------------------------------------------------------------------------ @@ -1120,18 +1171,20 @@ func (cmd *StringStructMapCmd) String() string { } func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make(map[string]struct{}, n) - for i := int64(0); i < n; i++ { - key, err := rd.ReadString() - if err != nil { - return nil, err - } - cmd.val[key] = struct{}{} + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + cmd.val = make(map[string]struct{}, n) + for i := 0; i < n; i++ { + key, err := rd.ReadString() + if err != nil { + return err } - return nil, nil - }) - return err + cmd.val[key] = struct{}{} + } + return nil } //------------------------------------------------------------------------------ @@ -1170,8 +1223,7 @@ func (cmd *XMessageSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) error { - var err error +func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) (err error) { cmd.val, err = readXMessageSlice(rd) return err } @@ -1183,10 +1235,8 @@ func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) { } msgs := make([]XMessage, n) - for i := 0; i < n; i++ { - var err error - msgs[i], err = readXMessage(rd) - if err != nil { + for i := 0; i < len(msgs); i++ { + if msgs[i], err = readXMessage(rd); err != nil { return nil, err } } @@ -1194,40 +1244,36 @@ func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) { } func readXMessage(rd *proto.Reader) (XMessage, error) { - n, err := rd.ReadArrayLen() - if err != nil { + if err := rd.ReadFixedArrayLen(2); err != nil { return XMessage{}, err } - if n != 2 { - return XMessage{}, fmt.Errorf("got %d, wanted 2", n) - } id, err := rd.ReadString() if err != nil { return XMessage{}, err } - var values map[string]interface{} - - v, err := rd.ReadArrayReply(stringInterfaceMapParser) + v, err := stringInterfaceMapParser(rd) if err != nil { if err != proto.Nil { return XMessage{}, err } - } else { - values = v.(map[string]interface{}) } return XMessage{ ID: id, - Values: values, + Values: v, }, nil } -// stringInterfaceMapParser implements proto.MultiBulkParse. -func stringInterfaceMapParser(rd *proto.Reader, n int64) (interface{}, error) { - m := make(map[string]interface{}, n/2) - for i := int64(0); i < n; i += 2 { +func stringInterfaceMapParser(rd *proto.Reader) (map[string]interface{}, error) { + n, err := rd.ReadMapLen() + if err != nil { + return nil, err + } + + m := make(map[string]interface{}, n) + for i := 0; i < n; i++ { key, err := rd.ReadString() if err != nil { return nil, err @@ -1280,38 +1326,35 @@ func (cmd *XStreamSliceCmd) String() string { } func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]XStream, n) - for i := 0; i < len(cmd.val); i++ { - i := i - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 2 { - return nil, fmt.Errorf("got %d, wanted 2", n) - } + typ, err := rd.PeekReplyType() + if err != nil { + return err + } - stream, err := rd.ReadString() - if err != nil { - return nil, err - } - - msgs, err := readXMessageSlice(rd) - if err != nil { - return nil, err - } - - cmd.val[i] = XStream{ - Stream: stream, - Messages: msgs, - } - return nil, nil - }) - if err != nil { - return nil, err + var n int + if typ == proto.RespMap { + n, err = rd.ReadMapLen() + } else { + n, err = rd.ReadArrayLen() + } + if err != nil { + return err + } + cmd.val = make([]XStream, n) + for i := 0; i < len(cmd.val); i++ { + if typ != proto.RespMap { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err } } - return nil, nil - }) - return err + if cmd.val[i].Stream, err = rd.ReadString(); err != nil { + return err + } + if cmd.val[i].Messages, err = readXMessageSlice(rd); err != nil { + return err + } + } + return nil } //------------------------------------------------------------------------------ @@ -1352,68 +1395,45 @@ func (cmd *XPendingCmd) String() string { } func (cmd *XPendingCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 4 { - return nil, fmt.Errorf("got %d, wanted 4", n) + var err error + if err = rd.ReadFixedArrayLen(4); err != nil { + return err + } + cmd.val = &XPending{} + + if cmd.val.Count, err = rd.ReadInt(); err != nil { + return err + } + + if cmd.val.Lower, err = rd.ReadString(); err != nil && err != Nil { + return err + } + + if cmd.val.Higher, err = rd.ReadString(); err != nil && err != Nil { + return err + } + + n, err := rd.ReadArrayLen() + if err != nil && err != Nil { + return err + } + cmd.val.Consumers = make(map[string]int64, n) + for i := 0; i < n; i++ { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err } - count, err := rd.ReadIntReply() + consumerName, err := rd.ReadString() if err != nil { - return nil, err + return err } - - lower, err := rd.ReadString() - if err != nil && err != Nil { - return nil, err + consumerPending, err := rd.ReadInt() + if err != nil { + return err } - - higher, err := rd.ReadString() - if err != nil && err != Nil { - return nil, err - } - - cmd.val = &XPending{ - Count: count, - Lower: lower, - Higher: higher, - } - _, err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - for i := int64(0); i < n; i++ { - _, err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 2 { - return nil, fmt.Errorf("got %d, wanted 2", n) - } - - consumerName, err := rd.ReadString() - if err != nil { - return nil, err - } - - consumerPending, err := rd.ReadInt() - if err != nil { - return nil, err - } - - if cmd.val.Consumers == nil { - cmd.val.Consumers = make(map[string]int64) - } - cmd.val.Consumers[consumerName] = consumerPending - - return nil, nil - }) - if err != nil { - return nil, err - } - } - return nil, nil - }) - if err != nil && err != Nil { - return nil, err - } - - return nil, nil - }) - return err + cmd.val.Consumers[consumerName] = consumerPending + } + return nil } //------------------------------------------------------------------------------ @@ -1454,49 +1474,37 @@ func (cmd *XPendingExtCmd) String() string { } func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]XPendingExt, 0, n) - for i := int64(0); i < n; i++ { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 4 { - return nil, fmt.Errorf("got %d, wanted 4", n) - } + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]XPendingExt, n) - id, err := rd.ReadString() - if err != nil { - return nil, err - } - - consumer, err := rd.ReadString() - if err != nil && err != Nil { - return nil, err - } - - idle, err := rd.ReadIntReply() - if err != nil && err != Nil { - return nil, err - } - - retryCount, err := rd.ReadIntReply() - if err != nil && err != Nil { - return nil, err - } - - cmd.val = append(cmd.val, XPendingExt{ - ID: id, - Consumer: consumer, - Idle: time.Duration(idle) * time.Millisecond, - RetryCount: retryCount, - }) - return nil, nil - }) - if err != nil { - return nil, err - } + for i := 0; i < len(cmd.val); i++ { + if err = rd.ReadFixedArrayLen(4); err != nil { + return err } - return nil, nil - }) - return err + + if cmd.val[i].ID, err = rd.ReadString(); err != nil { + return err + } + + if cmd.val[i].Consumer, err = rd.ReadString(); err != nil && err != Nil { + return err + } + + idle, err := rd.ReadInt() + if err != nil && err != Nil { + return err + } + cmd.val[i].Idle = time.Duration(idle) * time.Millisecond + + if cmd.val[i].RetryCount, err = rd.ReadInt(); err != nil && err != Nil { + return err + } + } + + return nil } //------------------------------------------------------------------------------ @@ -1540,62 +1548,39 @@ func (cmd *XInfoConsumersCmd) readReply(rd *proto.Reader) error { if err != nil { return err } - cmd.val = make([]XInfoConsumer, n) - for i := 0; i < n; i++ { - cmd.val[i], err = readXConsumerInfo(rd) - if err != nil { + for i := 0; i < len(cmd.val); i++ { + if err = rd.ReadFixedMapLen(3); err != nil { return err } + + var key string + for f := 0; f < 3; f++ { + key, err = rd.ReadString() + if err != nil { + return err + } + + switch key { + case "name": + cmd.val[i].Name, err = rd.ReadString() + case "pending": + cmd.val[i].Pending, err = rd.ReadInt() + case "idle": + cmd.val[i].Idle, err = rd.ReadInt() + default: + return fmt.Errorf("redis: unexpected content %s in XINFO CONSUMERS reply", key) + } + if err != nil { + return err + } + } } return nil } -func readXConsumerInfo(rd *proto.Reader) (XInfoConsumer, error) { - var consumer XInfoConsumer - - n, err := rd.ReadArrayLen() - if err != nil { - return consumer, err - } - if n != 6 { - return consumer, fmt.Errorf("redis: got %d elements in XINFO CONSUMERS reply, wanted 6", n) - } - - for i := 0; i < 3; i++ { - key, err := rd.ReadString() - if err != nil { - return consumer, err - } - - val, err := rd.ReadString() - if err != nil { - return consumer, err - } - - switch key { - case "name": - consumer.Name = val - case "pending": - consumer.Pending, err = strconv.ParseInt(val, 0, 64) - if err != nil { - return consumer, err - } - case "idle": - consumer.Idle, err = strconv.ParseInt(val, 0, 64) - if err != nil { - return consumer, err - } - default: - return consumer, fmt.Errorf("redis: unexpected content %s in XINFO CONSUMERS reply", key) - } - } - - return consumer, nil -} - //------------------------------------------------------------------------------ type XInfoGroupsCmd struct { @@ -1638,64 +1623,41 @@ func (cmd *XInfoGroupsCmd) readReply(rd *proto.Reader) error { if err != nil { return err } - cmd.val = make([]XInfoGroup, n) - for i := 0; i < n; i++ { - cmd.val[i], err = readXGroupInfo(rd) - if err != nil { + for i := 0; i < len(cmd.val); i++ { + if err = rd.ReadFixedMapLen(4); err != nil { return err } + + var key string + for f := 0; f < 4; f++ { + key, err = rd.ReadString() + if err != nil { + return err + } + + switch key { + case "name": + cmd.val[i].Name, err = rd.ReadString() + case "consumers": + cmd.val[i].Consumers, err = rd.ReadInt() + case "pending": + cmd.val[i].Pending, err = rd.ReadInt() + case "last-delivered-id": + cmd.val[i].LastDeliveredID, err = rd.ReadString() + default: + return fmt.Errorf("redis: unexpected content %s in XINFO GROUPS reply", key) + } + if err != nil { + return err + } + } } return nil } -func readXGroupInfo(rd *proto.Reader) (XInfoGroup, error) { - var group XInfoGroup - - n, err := rd.ReadArrayLen() - if err != nil { - return group, err - } - if n != 8 { - return group, fmt.Errorf("redis: got %d elements in XINFO GROUPS reply, wanted 8", n) - } - - for i := 0; i < 4; i++ { - key, err := rd.ReadString() - if err != nil { - return group, err - } - - val, err := rd.ReadString() - if err != nil { - return group, err - } - - switch key { - case "name": - group.Name = val - case "consumers": - group.Consumers, err = strconv.ParseInt(val, 0, 64) - if err != nil { - return group, err - } - case "pending": - group.Pending, err = strconv.ParseInt(val, 0, 64) - if err != nil { - return group, err - } - case "last-delivered-id": - group.LastDeliveredID = val - default: - return group, fmt.Errorf("redis: unexpected content %s in XINFO GROUPS reply", key) - } - } - - return group, nil -} - //------------------------------------------------------------------------------ type XInfoStreamCmd struct { @@ -1737,49 +1699,40 @@ func (cmd *XInfoStreamCmd) String() string { } func (cmd *XInfoStreamCmd) readReply(rd *proto.Reader) error { - v, err := rd.ReadReply(xStreamInfoParser) - if err != nil { + if err := rd.ReadFixedMapLen(7); err != nil { return err } - cmd.val = v.(*XInfoStream) - return nil -} + cmd.val = &XInfoStream{} -func xStreamInfoParser(rd *proto.Reader, n int64) (interface{}, error) { - if n != 14 { - return nil, fmt.Errorf("redis: got %d elements in XINFO STREAM reply,"+ - "wanted 14", n) - } - var info XInfoStream for i := 0; i < 7; i++ { key, err := rd.ReadString() if err != nil { - return nil, err + return err } switch key { case "length": - info.Length, err = rd.ReadIntReply() + cmd.val.Length, err = rd.ReadInt() case "radix-tree-keys": - info.RadixTreeKeys, err = rd.ReadIntReply() + cmd.val.RadixTreeKeys, err = rd.ReadInt() case "radix-tree-nodes": - info.RadixTreeNodes, err = rd.ReadIntReply() + cmd.val.RadixTreeNodes, err = rd.ReadInt() case "groups": - info.Groups, err = rd.ReadIntReply() + cmd.val.Groups, err = rd.ReadInt() case "last-generated-id": - info.LastGeneratedID, err = rd.ReadString() + cmd.val.LastGeneratedID, err = rd.ReadString() case "first-entry": - info.FirstEntry, err = readXMessage(rd) + cmd.val.FirstEntry, err = readXMessage(rd) case "last-entry": - info.LastEntry, err = readXMessage(rd) + cmd.val.LastEntry, err = readXMessage(rd) default: - return nil, fmt.Errorf("redis: unexpected content %s "+ + return fmt.Errorf("redis: unexpected content %s "+ "in XINFO STREAM reply", key) } if err != nil { - return nil, err + return err } } - return &info, nil + return nil } //------------------------------------------------------------------------------ @@ -1813,28 +1766,47 @@ func (cmd *ZSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { +func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { // nolint:dupl + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + // If the n is 0, can't continue reading. + if n == 0 { + cmd.val = make([]Z, 0) + return nil + } + + typ, err := rd.PeekReplyType() + if err != nil { + return err + } + array := typ == proto.RespArray + + if array { + cmd.val = make([]Z, n) + } else { cmd.val = make([]Z, n/2) - for i := 0; i < len(cmd.val); i++ { - member, err := rd.ReadString() - if err != nil { - return nil, err - } + } - score, err := rd.ReadFloatReply() - if err != nil { - return nil, err - } - - cmd.val[i] = Z{ - Member: member, - Score: score, + for i := 0; i < len(cmd.val); i++ { + if array { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err } } - return nil, nil - }) - return err + + if cmd.val[i].Member, err = rd.ReadString(); err != nil { + return err + } + + if cmd.val[i].Score, err = rd.ReadFloat(); err != nil { + return err + } + } + + return nil } //------------------------------------------------------------------------------ @@ -1868,33 +1840,23 @@ func (cmd *ZWithKeyCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 3 { - return nil, fmt.Errorf("got %d elements, expected 3", n) - } +func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) (err error) { + if err = rd.ReadFixedArrayLen(3); err != nil { + return err + } + cmd.val = &ZWithKey{} - cmd.val = &ZWithKey{} - var err error + if cmd.val.Key, err = rd.ReadString(); err != nil { + return err + } + if cmd.val.Member, err = rd.ReadString(); err != nil { + return err + } + if cmd.val.Score, err = rd.ReadFloat(); err != nil { + return err + } - cmd.val.Key, err = rd.ReadString() - if err != nil { - return nil, err - } - - cmd.val.Member, err = rd.ReadString() - if err != nil { - return nil, err - } - - cmd.val.Score, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - - return nil, nil - }) - return err + return nil } //------------------------------------------------------------------------------ @@ -1932,9 +1894,29 @@ func (cmd *ScanCmd) String() string { return cmdString(cmd, cmd.page) } -func (cmd *ScanCmd) readReply(rd *proto.Reader) (err error) { - cmd.page, cmd.cursor, err = rd.ReadScanReply() - return err +func (cmd *ScanCmd) readReply(rd *proto.Reader) error { + if err := rd.ReadFixedArrayLen(2); err != nil { + return err + } + + cursor, err := rd.ReadInt() + if err != nil { + return err + } + cmd.cursor = uint64(cursor) + + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.page = make([]string, n) + + for i := 0; i < len(cmd.page); i++ { + if cmd.page[i], err = rd.ReadString(); err != nil { + return err + } + } + return nil } // Iterator creates a new ScanIterator. @@ -1987,69 +1969,70 @@ func (cmd *ClusterSlotsCmd) String() string { } func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]ClusterSlot, n) - for i := 0; i < len(cmd.val); i++ { - n, err := rd.ReadArrayLen() + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]ClusterSlot, n) + + for i := 0; i < len(cmd.val); i++ { + n, err = rd.ReadArrayLen() + if err != nil { + return err + } + if n < 2 { + return fmt.Errorf("redis: got %d elements in cluster info, expected at least 2", n) + } + + start, err := rd.ReadInt() + if err != nil { + return err + } + + end, err := rd.ReadInt() + if err != nil { + return err + } + + // subtract start and end. + nodes := make([]ClusterNode, n-2) + for j := 0; j < len(nodes); j++ { + nn, err := rd.ReadArrayLen() if err != nil { - return nil, err + return err } - if n < 2 { - err := fmt.Errorf("redis: got %d elements in cluster info, expected at least 2", n) - return nil, err + if nn != 2 && nn != 3 { + return fmt.Errorf("got %d elements in cluster info address, expected 2 or 3", nn) } - start, err := rd.ReadIntReply() + ip, err := rd.ReadString() if err != nil { - return nil, err + return err } - end, err := rd.ReadIntReply() + port, err := rd.ReadString() if err != nil { - return nil, err + return err } - nodes := make([]ClusterNode, n-2) - for j := 0; j < len(nodes); j++ { - n, err := rd.ReadArrayLen() + nodes[j].Addr = net.JoinHostPort(ip, port) + + if nn == 3 { + id, err := rd.ReadString() if err != nil { - return nil, err + return err } - if n != 2 && n != 3 { - err := fmt.Errorf("got %d elements in cluster info address, expected 2 or 3", n) - return nil, err - } - - ip, err := rd.ReadString() - if err != nil { - return nil, err - } - - port, err := rd.ReadString() - if err != nil { - return nil, err - } - - nodes[j].Addr = net.JoinHostPort(ip, port) - - if n == 3 { - id, err := rd.ReadString() - if err != nil { - return nil, err - } - nodes[j].ID = id - } - } - - cmd.val[i] = ClusterSlot{ - Start: int(start), - End: int(end), - Nodes: nodes, + nodes[j].ID = id } } - return nil, nil - }) - return err + cmd.val[i] = ClusterSlot{ + Start: int(start), + End: int(end), + Nodes: nodes, + } + } + + return nil } //------------------------------------------------------------------------------ @@ -2074,6 +2057,9 @@ type GeoRadiusQuery struct { Sort string Store string StoreDist string + + // WithCoord+WithDist+WithGeoHash + withLen int } type GeoLocationCmd struct { @@ -2104,12 +2090,15 @@ func geoLocationArgs(q *GeoRadiusQuery, args ...interface{}) []interface{} { } if q.WithCoord { args = append(args, "withcoord") + q.withLen++ } if q.WithDist { args = append(args, "withdist") + q.withLen++ } if q.WithGeoHash { args = append(args, "withhash") + q.withLen++ } if q.Count > 0 { args = append(args, "count", q.Count) @@ -2141,82 +2130,55 @@ func (cmd *GeoLocationCmd) String() string { } func (cmd *GeoLocationCmd) readReply(rd *proto.Reader) error { - v, err := rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) + n, err := rd.ReadArrayLen() if err != nil { return err } - cmd.locations = v.([]GeoLocation) + cmd.locations = make([]GeoLocation, n) + + for i := 0; i < len(cmd.locations); i++ { + // only name + if cmd.q.withLen == 0 { + if cmd.locations[i].Name, err = rd.ReadString(); err != nil { + return err + } + continue + } + + // +name + if err = rd.ReadFixedArrayLen(cmd.q.withLen + 1); err != nil { + return err + } + + if cmd.locations[i].Name, err = rd.ReadString(); err != nil { + return err + } + if cmd.q.WithDist { + if cmd.locations[i].Dist, err = rd.ReadFloat(); err != nil { + return err + } + } + if cmd.q.WithGeoHash { + if cmd.locations[i].GeoHash, err = rd.ReadInt(); err != nil { + return err + } + } + if cmd.q.WithCoord { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err + } + if cmd.locations[i].Longitude, err = rd.ReadFloat(); err != nil { + return err + } + if cmd.locations[i].Latitude, err = rd.ReadFloat(); err != nil { + return err + } + } + } + return nil } -func newGeoLocationSliceParser(q *GeoRadiusQuery) proto.MultiBulkParse { - return func(rd *proto.Reader, n int64) (interface{}, error) { - locs := make([]GeoLocation, 0, n) - for i := int64(0); i < n; i++ { - v, err := rd.ReadReply(newGeoLocationParser(q)) - if err != nil { - return nil, err - } - switch vv := v.(type) { - case string: - locs = append(locs, GeoLocation{ - Name: vv, - }) - case *GeoLocation: - // TODO: avoid copying - locs = append(locs, *vv) - default: - return nil, fmt.Errorf("got %T, expected string or *GeoLocation", v) - } - } - return locs, nil - } -} - -func newGeoLocationParser(q *GeoRadiusQuery) proto.MultiBulkParse { - return func(rd *proto.Reader, n int64) (interface{}, error) { - var loc GeoLocation - var err error - - loc.Name, err = rd.ReadString() - if err != nil { - return nil, err - } - if q.WithDist { - loc.Dist, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - } - if q.WithGeoHash { - loc.GeoHash, err = rd.ReadIntReply() - if err != nil { - return nil, err - } - } - if q.WithCoord { - n, err := rd.ReadArrayLen() - if err != nil { - return nil, err - } - if n != 2 { - return nil, fmt.Errorf("got %d coordinates, expected 2", n) - } - - loc.Longitude, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - loc.Latitude, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - } - - return &loc, nil - } -} - //------------------------------------------------------------------------------ type GeoPos struct { @@ -2253,38 +2215,38 @@ func (cmd *GeoPosCmd) String() string { } func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]*GeoPos, n) - for i := 0; i < len(cmd.val); i++ { - i := i - _, err := rd.ReadReply(func(rd *proto.Reader, n int64) (interface{}, error) { - longitude, err := rd.ReadFloatReply() - if err != nil { - return nil, err - } + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]*GeoPos, n) - latitude, err := rd.ReadFloatReply() - if err != nil { - return nil, err - } - - cmd.val[i] = &GeoPos{ - Longitude: longitude, - Latitude: latitude, - } - return nil, nil - }) - if err != nil { - if err == Nil { - cmd.val[i] = nil - continue - } - return nil, err + for i := 0; i < len(cmd.val); i++ { + err = rd.ReadFixedArrayLen(2) + if err != nil { + if err == Nil { + cmd.val[i] = nil + continue } + return err } - return nil, nil - }) - return err + + longitude, err := rd.ReadFloat() + if err != nil { + return err + } + latitude, err := rd.ReadFloat() + if err != nil { + return err + } + + cmd.val[i] = &GeoPos{ + Longitude: longitude, + Latitude: latitude, + } + } + + return nil } //------------------------------------------------------------------------------ @@ -2330,112 +2292,94 @@ func (cmd *CommandsInfoCmd) String() string { } func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make(map[string]*CommandInfo, n) - for i := int64(0); i < n; i++ { - v, err := rd.ReadReply(commandInfoParser) - if err != nil { - return nil, err - } - vv := v.(*CommandInfo) - cmd.val[vv.Name] = vv - } - return nil, nil - }) - return err -} - -func commandInfoParser(rd *proto.Reader, n int64) (interface{}, error) { const numArgRedis5 = 6 const numArgRedis6 = 7 - switch n { - case numArgRedis5, numArgRedis6: - // continue - default: - return nil, fmt.Errorf("redis: got %d elements in COMMAND reply, wanted 7", n) - } - - var cmd CommandInfo - var err error - - cmd.Name, err = rd.ReadString() + n, err := rd.ReadArrayLen() if err != nil { - return nil, err + return err } + cmd.val = make(map[string]*CommandInfo, n) - arity, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.Arity = int8(arity) + for i := 0; i < n; i++ { + nn, err := rd.ReadArrayLen() + if err != nil { + return err + } + if nn != numArgRedis5 && nn != numArgRedis6 { + return fmt.Errorf("redis: got %d elements in COMMAND reply, wanted 6/7", nn) + } - _, err = rd.ReadReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.Flags = make([]string, n) - for i := 0; i < len(cmd.Flags); i++ { + cmdInfo := &CommandInfo{} + if cmdInfo.Name, err = rd.ReadString(); err != nil { + return err + } + + arity, err := rd.ReadInt() + if err != nil { + return err + } + cmdInfo.Arity = int8(arity) + + flagLen, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmdInfo.Flags = make([]string, flagLen) + for f := 0; f < len(cmdInfo.Flags); f++ { switch s, err := rd.ReadString(); { case err == Nil: - cmd.Flags[i] = "" + cmdInfo.Flags[f] = "" case err != nil: - return nil, err + return err default: - cmd.Flags[i] = s + if !cmdInfo.ReadOnly && s == "readonly" { + cmdInfo.ReadOnly = true + } + cmdInfo.Flags[f] = s } } - return nil, nil - }) - if err != nil { - return nil, err - } - firstKeyPos, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.FirstKeyPos = int8(firstKeyPos) - - lastKeyPos, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.LastKeyPos = int8(lastKeyPos) - - stepCount, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.StepCount = int8(stepCount) - - for _, flag := range cmd.Flags { - if flag == "readonly" { - cmd.ReadOnly = true - break + firstKeyPos, err := rd.ReadInt() + if err != nil { + return err } - } + cmdInfo.FirstKeyPos = int8(firstKeyPos) - if n == numArgRedis5 { - return &cmd, nil - } + lastKeyPos, err := rd.ReadInt() + if err != nil { + return err + } + cmdInfo.LastKeyPos = int8(lastKeyPos) - _, err = rd.ReadReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.ACLFlags = make([]string, n) - for i := 0; i < len(cmd.ACLFlags); i++ { - switch s, err := rd.ReadString(); { - case err == Nil: - cmd.ACLFlags[i] = "" - case err != nil: - return nil, err - default: - cmd.ACLFlags[i] = s + stepCount, err := rd.ReadInt() + if err != nil { + return err + } + cmdInfo.StepCount = int8(stepCount) + + if nn == numArgRedis6 { + aclFlagLen, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmdInfo.ACLFlags = make([]string, aclFlagLen) + for f := 0; f < len(cmdInfo.ACLFlags); f++ { + switch s, err := rd.ReadString(); { + case err == Nil: + cmdInfo.ACLFlags[f] = "" + case err != nil: + return err + default: + cmdInfo.ACLFlags[f] = s + } } } - return nil, nil - }) - if err != nil { - return nil, err + + cmd.val[cmdInfo.Name] = cmdInfo } - return &cmd, nil + return nil } //------------------------------------------------------------------------------ @@ -2517,75 +2461,185 @@ func (cmd *SlowLogCmd) String() string { } func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]SlowLog, n) - for i := 0; i < len(cmd.val); i++ { - n, err := rd.ReadArrayLen() + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]SlowLog, n) + + for i := 0; i < len(cmd.val); i++ { + nn, err := rd.ReadArrayLen() + if err != nil { + return err + } + if nn < 4 { + return fmt.Errorf("redis: got %d elements in slowlog get, expected at least 4", nn) + } + + if cmd.val[i].ID, err = rd.ReadInt(); err != nil { + return err + } + + createdAt, err := rd.ReadInt() + if err != nil { + return err + } + cmd.val[i].Time = time.Unix(createdAt, 0) + + costs, err := rd.ReadInt() + if err != nil { + return err + } + cmd.val[i].Duration = time.Duration(costs) * time.Microsecond + + cmdLen, err := rd.ReadArrayLen() + if err != nil { + return err + } + if cmdLen < 1 { + return fmt.Errorf("redis: got %d elements commands reply in slowlog get, expected at least 1", cmdLen) + } + + cmd.val[i].Args = make([]string, cmdLen) + for f := 0; f < len(cmd.val[i].Args); f++ { + cmd.val[i].Args[f], err = rd.ReadString() if err != nil { - return nil, err - } - if n < 4 { - err := fmt.Errorf("redis: got %d elements in slowlog get, expected at least 4", n) - return nil, err - } - - id, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - - createdAt, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - createdAtTime := time.Unix(createdAt, 0) - - costs, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - costsDuration := time.Duration(costs) * time.Microsecond - - cmdLen, err := rd.ReadArrayLen() - if err != nil { - return nil, err - } - if cmdLen < 1 { - err := fmt.Errorf("redis: got %d elements commands reply in slowlog get, expected at least 1", cmdLen) - return nil, err - } - - cmdString := make([]string, cmdLen) - for i := 0; i < cmdLen; i++ { - cmdString[i], err = rd.ReadString() - if err != nil { - return nil, err - } - } - - var address, name string - for i := 4; i < n; i++ { - str, err := rd.ReadString() - if err != nil { - return nil, err - } - if i == 4 { - address = str - } else if i == 5 { - name = str - } - } - - cmd.val[i] = SlowLog{ - ID: id, - Time: createdAtTime, - Duration: costsDuration, - Args: cmdString, - ClientAddr: address, - ClientName: name, + return err } } - return nil, nil - }) - return err + + if nn >= 5 { + if cmd.val[i].ClientAddr, err = rd.ReadString(); err != nil { + return err + } + } + + if nn >= 6 { + if cmd.val[i].ClientName, err = rd.ReadString(); err != nil { + return err + } + } + } + + return nil +} + +//----------------------------------------------------------------------- + +type MapStringInterfaceCmd struct { + baseCmd + + val map[string]interface{} +} + +var _ Cmder = (*MapStringInterfaceCmd)(nil) + +func NewMapStringInterfaceCmd(ctx context.Context, args ...interface{}) *MapStringInterfaceCmd { + return &MapStringInterfaceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *MapStringInterfaceCmd) Val() map[string]interface{} { + return cmd.val +} + +func (cmd *MapStringInterfaceCmd) Result() (map[string]interface{}, error) { + return cmd.Val(), cmd.Err() +} + +func (cmd *MapStringInterfaceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *MapStringInterfaceCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadMapLen() + if err != nil { + return err + } + + cmd.val = make(map[string]interface{}, n) + for i := 0; i < n; i++ { + k, err := rd.ReadString() + if err != nil { + return err + } + v, err := rd.ReadReply() + if err != nil { + if err == Nil { + cmd.val[k] = Nil + continue + } + if err, ok := err.(proto.RedisError); ok { + cmd.val[k] = err + continue + } + return err + } + cmd.val[k] = v + } + return nil +} + +//----------------------------------------------------------------------- + +type MapStringStringSliceCmd struct { + baseCmd + + val []map[string]string +} + +var _ Cmder = (*MapStringStringSliceCmd)(nil) + +func NewMapStringStringSliceCmd(ctx context.Context, args ...interface{}) *MapStringStringSliceCmd { + return &MapStringStringSliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *MapStringStringSliceCmd) Val() []map[string]string { + return cmd.val +} + +func (cmd *MapStringStringSliceCmd) Result() ([]map[string]string, error) { + return cmd.Val(), cmd.Err() +} + +func (cmd *MapStringStringSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *MapStringStringSliceCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + cmd.val = make([]map[string]string, n) + for i := 0; i < n; i++ { + nn, err := rd.ReadMapLen() + if err != nil { + return err + } + cmd.val[i] = make(map[string]string, nn) + for f := 0; f < nn; f++ { + k, err := rd.ReadString() + if err != nil { + return err + } + + v, err := rd.ReadString() + if err != nil { + return err + } + cmd.val[i][k] = v + } + } + return nil } diff --git a/commands.go b/commands.go index 3774633..65ca464 100644 --- a/commands.go +++ b/commands.go @@ -158,7 +158,7 @@ type Cmdable interface { HDel(ctx context.Context, key string, fields ...string) *IntCmd HExists(ctx context.Context, key, field string) *BoolCmd HGet(ctx context.Context, key, field string) *StringCmd - HGetAll(ctx context.Context, key string) *StringStringMapCmd + HGetAll(ctx context.Context, key string) *MapStringStringCmd HIncrBy(ctx context.Context, key, field string, incr int64) *IntCmd HIncrByFloat(ctx context.Context, key, field string, incr float64) *FloatCmd HKeys(ctx context.Context, key string) *StringSliceCmd @@ -168,7 +168,8 @@ type Cmdable interface { HMSet(ctx context.Context, key string, values ...interface{}) *BoolCmd HSetNX(ctx context.Context, key, field string, value interface{}) *BoolCmd HVals(ctx context.Context, key string) *StringSliceCmd - HRandField(ctx context.Context, key string, count int, withValues bool) *StringSliceCmd + HRandField(ctx context.Context, key string, count int) *StringSliceCmd + HRandFieldWithValues(ctx context.Context, key string, count int) *KeyValueSliceCmd BLPop(ctx context.Context, timeout time.Duration, keys ...string) *StringSliceCmd BRPop(ctx context.Context, timeout time.Duration, keys ...string) *StringSliceCmd @@ -274,7 +275,8 @@ type Cmdable interface { ZRevRank(ctx context.Context, key, member string) *IntCmd ZScore(ctx context.Context, key, member string) *FloatCmd ZUnionStore(ctx context.Context, dest string, store *ZStore) *IntCmd - ZRandMember(ctx context.Context, key string, count int, withScores bool) *StringSliceCmd + ZRandMember(ctx context.Context, key string, count int) *StringSliceCmd + ZRandMemberWithScores(ctx context.Context, key string, count int) *ZSliceCmd PFAdd(ctx context.Context, key string, els ...interface{}) *IntCmd PFCount(ctx context.Context, keys ...string) *IntCmd @@ -287,7 +289,7 @@ type Cmdable interface { ClientList(ctx context.Context) *StringCmd ClientPause(ctx context.Context, dur time.Duration) *BoolCmd ClientID(ctx context.Context) *IntCmd - ConfigGet(ctx context.Context, parameter string) *SliceCmd + ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd ConfigResetStat(ctx context.Context) *StatusCmd ConfigSet(ctx context.Context, parameter, value string) *StatusCmd ConfigRewrite(ctx context.Context) *StatusCmd @@ -358,6 +360,7 @@ type StatefulCmdable interface { Select(ctx context.Context, index int) *StatusCmd SwapDB(ctx context.Context, index1, index2 int) *StatusCmd ClientSetName(ctx context.Context, name string) *BoolCmd + Hello(ctx context.Context, ver int, username, password, clientName string) *MapStringInterfaceCmd } var ( @@ -413,6 +416,26 @@ func (c statefulCmdable) ClientSetName(ctx context.Context, name string) *BoolCm return cmd } +// Set the resp protocol used. +func (c statefulCmdable) Hello(ctx context.Context, + ver int, username, password, clientName string) *MapStringInterfaceCmd { + args := make([]interface{}, 0, 7) + args = append(args, "hello", ver) + if password != "" { + if username != "" { + args = append(args, "auth", username, password) + } else { + args = append(args, "auth", "default", password) + } + } + if clientName != "" { + args = append(args, "setname", clientName) + } + cmd := NewMapStringInterfaceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + //------------------------------------------------------------------------------ func (c cmdable) Command(ctx context.Context) *CommandsInfoCmd { @@ -440,7 +463,7 @@ func (c cmdable) Ping(ctx context.Context) *StatusCmd { return cmd } -func (c cmdable) Quit(ctx context.Context) *StatusCmd { +func (c cmdable) Quit(_ context.Context) *StatusCmd { panic("not implemented") } @@ -1138,8 +1161,8 @@ func (c cmdable) HGet(ctx context.Context, key, field string) *StringCmd { return cmd } -func (c cmdable) HGetAll(ctx context.Context, key string) *StringStringMapCmd { - cmd := NewStringStringMapCmd(ctx, "hgetall", key) +func (c cmdable) HGetAll(ctx context.Context, key string) *MapStringStringCmd { + cmd := NewMapStringStringCmd(ctx, "hgetall", key) _ = c(ctx, cmd) return cmd } @@ -1222,16 +1245,15 @@ func (c cmdable) HVals(ctx context.Context, key string) *StringSliceCmd { } // redis-server version >= 6.2.0. -func (c cmdable) HRandField(ctx context.Context, key string, count int, withValues bool) *StringSliceCmd { - args := make([]interface{}, 0, 4) +func (c cmdable) HRandField(ctx context.Context, key string, count int) *StringSliceCmd { + cmd := NewStringSliceCmd(ctx, "hrandfield", key, count) + _ = c(ctx, cmd) + return cmd +} - // Although count=0 is meaningless, redis accepts count=0. - args = append(args, "hrandfield", key, count) - if withValues { - args = append(args, "withvalues") - } - - cmd := NewStringSliceCmd(ctx, args...) +// redis-server version >= 6.2.0. +func (c cmdable) HRandFieldWithValues(ctx context.Context, key string, count int) *KeyValueSliceCmd { + cmd := NewKeyValueSliceCmd(ctx, "hrandfield", key, count, "withvalues") _ = c(ctx, cmd) return cmd } @@ -2316,17 +2338,16 @@ func (c cmdable) ZUnionStore(ctx context.Context, dest string, store *ZStore) *I return cmd } -// redis-server version >= 6.2.0. -func (c cmdable) ZRandMember(ctx context.Context, key string, count int, withScores bool) *StringSliceCmd { - args := make([]interface{}, 0, 4) +// ZRandMember redis-server version >= 6.2.0. +func (c cmdable) ZRandMember(ctx context.Context, key string, count int) *StringSliceCmd { + cmd := NewStringSliceCmd(ctx, "zrandmember", key, count) + _ = c(ctx, cmd) + return cmd +} - // Although count=0 is meaningless, redis accepts count=0. - args = append(args, "zrandmember", key, count) - if withScores { - args = append(args, "withscores") - } - - cmd := NewStringSliceCmd(ctx, args...) +// ZRandMemberWithScores redis-server version >= 6.2.0. +func (c cmdable) ZRandMemberWithScores(ctx context.Context, key string, count int) *ZSliceCmd { + cmd := NewZSliceCmd(ctx, "zrandmember", key, count, "withscores") _ = c(ctx, cmd) return cmd } @@ -2431,8 +2452,8 @@ func (c cmdable) ClientUnblockWithError(ctx context.Context, id int64) *IntCmd { return cmd } -func (c cmdable) ConfigGet(ctx context.Context, parameter string) *SliceCmd { - cmd := NewSliceCmd(ctx, "config", "get", parameter) +func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd { + cmd := NewMapStringStringCmd(ctx, "config", "get", parameter) _ = c(ctx, cmd) return cmd } @@ -2553,7 +2574,7 @@ func (c cmdable) SlowLogGet(ctx context.Context, num int64) *SlowLogCmd { return cmd } -func (c cmdable) Sync(ctx context.Context) { +func (c cmdable) Sync(_ context.Context) { panic("not implemented") } diff --git a/commands_test.go b/commands_test.go index e927c58..1449587 100644 --- a/commands_test.go +++ b/commands_test.go @@ -47,6 +47,17 @@ var _ = Describe("Commands", func() { Expect(stats.IdleConns).To(Equal(uint32(1))) }) + It("should hello", func() { + cmds, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Hello(ctx, 3, "", "", "") + return nil + }) + Expect(err).NotTo(HaveOccurred()) + m, err := cmds[0].(*redis.MapStringInterfaceCmd).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(m["proto"]).To(Equal(int64(3))) + }) + It("should Echo", func() { pipe := client.Pipeline() echo := pipe.Echo(ctx, "hello") @@ -182,10 +193,11 @@ var _ = Describe("Commands", func() { It("should ConfigSet", func() { configGet := client.ConfigGet(ctx, "maxmemory") Expect(configGet.Err()).NotTo(HaveOccurred()) - Expect(configGet.Val()).To(HaveLen(2)) - Expect(configGet.Val()[0]).To(Equal("maxmemory")) + Expect(configGet.Val()).To(HaveLen(1)) + _, ok := configGet.Val()["maxmemory"] + Expect(ok).To(BeTrue()) - configSet := client.ConfigSet(ctx, "maxmemory", configGet.Val()[1].(string)) + configSet := client.ConfigSet(ctx, "maxmemory", configGet.Val()["maxmemory"]) Expect(configSet.Err()).NotTo(HaveOccurred()) Expect(configSet.Val()).To(Equal("OK")) }) @@ -1839,18 +1851,20 @@ var _ = Describe("Commands", func() { err = client.HSet(ctx, "hash", "key2", "hello2").Err() Expect(err).NotTo(HaveOccurred()) - v := client.HRandField(ctx, "hash", 1, false) + v := client.HRandField(ctx, "hash", 1) Expect(v.Err()).NotTo(HaveOccurred()) Expect(v.Val()).To(Or(Equal([]string{"key1"}), Equal([]string{"key2"}))) - v = client.HRandField(ctx, "hash", 0, false) + v = client.HRandField(ctx, "hash", 0) Expect(v.Err()).NotTo(HaveOccurred()) Expect(v.Val()).To(HaveLen(0)) - var slice []string - err = client.HRandField(ctx, "hash", 1, true).ScanSlice(&slice) + kv, err := client.HRandFieldWithValues(ctx, "hash", 1).Result() Expect(err).NotTo(HaveOccurred()) - Expect(slice).To(Or(Equal([]string{"key1", "hello1"}), Equal([]string{"key2", "hello2"}))) + Expect(kv).To(Or( + Equal([]redis.KeyValue{{Key: "key1", Value: "hello1"}}), + Equal([]redis.KeyValue{{Key: "key2", Value: "hello2"}}), + )) }) }) @@ -3919,18 +3933,20 @@ var _ = Describe("Commands", func() { err = client.ZAdd(ctx, "zset", &redis.Z{Score: 2, Member: "two"}).Err() Expect(err).NotTo(HaveOccurred()) - v := client.ZRandMember(ctx, "zset", 1, false) + v := client.ZRandMember(ctx, "zset", 1) Expect(v.Err()).NotTo(HaveOccurred()) Expect(v.Val()).To(Or(Equal([]string{"one"}), Equal([]string{"two"}))) - v = client.ZRandMember(ctx, "zset", 0, false) + v = client.ZRandMember(ctx, "zset", 0) Expect(v.Err()).NotTo(HaveOccurred()) Expect(v.Val()).To(HaveLen(0)) - var slice []string - err = client.ZRandMember(ctx, "zset", 1, true).ScanSlice(&slice) + kv, err := client.ZRandMemberWithScores(ctx, "zset", 1).Result() Expect(err).NotTo(HaveOccurred()) - Expect(slice).To(Or(Equal([]string{"one", "1"}), Equal([]string{"two", "2"}))) + Expect(kv).To(Or( + Equal([]redis.Z{{Member: "one", Score: 1}}), + Equal([]redis.Z{{Member: "two", Score: 2}}), + )) }) }) @@ -4675,7 +4691,7 @@ var _ = Describe("Commands", func() { old := client.ConfigGet(ctx, key).Val() client.ConfigSet(ctx, key, "0") - defer client.ConfigSet(ctx, key, old[1].(string)) + defer client.ConfigSet(ctx, key, old[key]) err := rdb.Do(ctx, "slowlog", "reset").Err() Expect(err).NotTo(HaveOccurred()) diff --git a/example_test.go b/example_test.go index 7d9f740..73b2f2d 100644 --- a/example_test.go +++ b/example_test.go @@ -276,9 +276,9 @@ func ExampleClient_ScanType() { // Output: found 33 keys } -// ExampleStringStringMapCmd_Scan shows how to scan the results of a map fetch +// ExampleMapStringStringCmd_Scan shows how to scan the results of a map fetch // into a struct. -func ExampleStringStringMapCmd_Scan() { +func ExampleMapStringStringCmd_Scan() { rdb.FlushDB(ctx) err := rdb.HMSet(ctx, "map", "name", "hello", @@ -615,7 +615,7 @@ func ExampleClient_SlowLogGet() { old := rdb.ConfigGet(ctx, key).Val() rdb.ConfigSet(ctx, key, "0") - defer rdb.ConfigSet(ctx, key, old[1].(string)) + defer rdb.ConfigSet(ctx, key, old[key]) if err := rdb.Do(ctx, "slowlog", "reset").Err(); err != nil { panic(err) diff --git a/go.mod b/go.mod index 938768f..1d3dfe1 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/go-redis/redis/v8 -go 1.13 +go 1.14 require ( github.com/cespare/xxhash/v2 v2.1.1 diff --git a/internal/proto/reader.go b/internal/proto/reader.go index 0ab8c9d..410e466 100644 --- a/internal/proto/reader.go +++ b/internal/proto/reader.go @@ -2,20 +2,39 @@ package proto import ( "bufio" + "errors" "fmt" "io" + "math" + "math/big" + "strconv" "github.com/go-redis/redis/v8/internal/util" ) const ( - ErrorReply = '-' - StatusReply = '+' - IntReply = ':' - StringReply = '$' - ArrayReply = '*' + RespStatus = '+' // +\r\n + RespError = '-' // -\r\n + RespString = '$' // $\r\n\r\n + RespInt = ':' // :\r\n + RespNil = '_' // _\r\n + RespFloat = ',' // ,\r\n (golang float) + RespBool = '#' // true: #t\r\n false: #f\r\n + RespBlobError = '!' // !\r\n\r\n + RespVerbatim = '=' // =\r\nFORMAT:\r\n + RespBigInt = '(' // (\r\n + RespArray = '*' // *\r\n... (same as resp2) + RespMap = '%' // %\r\n(key)\r\n(value)\r\n... (golang map) + RespSet = '~' // ~\r\n... (same as Array) + RespAttr = '|' // |\r\n(key)\r\n(value)\r\n... + command reply + RespPush = '>' // >\r\n... (same as Array) ) +// Not used temporarily. +// Redis has not used these two data types for the time being, and will implement them later. +// Streamed = "EOF:" +// StreamedAggregated = '?' + //------------------------------------------------------------------------------ const Nil = RedisError("redis: nil") @@ -26,19 +45,19 @@ func (e RedisError) Error() string { return string(e) } func (RedisError) RedisError() {} +func ParseErrorReply(line []byte) error { + return RedisError(line[1:]) +} + //------------------------------------------------------------------------------ -type MultiBulkParse func(*Reader, int64) (interface{}, error) - type Reader struct { - rd *bufio.Reader - _buf []byte + rd *bufio.Reader } func NewReader(rd io.Reader) *Reader { return &Reader{ - rd: bufio.NewReader(rd), - _buf: make([]byte, 64), + rd: bufio.NewReader(rd), } } @@ -54,14 +73,53 @@ func (r *Reader) Reset(rd io.Reader) { r.rd.Reset(rd) } +// PeekReplyType returns the data type of the next response without advancing the Reader, +// and discard the attribute type. +func (r *Reader) PeekReplyType() (byte, error) { + b, err := r.rd.Peek(1) + if err != nil { + return 0, err + } + if b[0] == RespAttr { + if err = r.DiscardNext(); err != nil { + return 0, err + } + return r.PeekReplyType() + } + return b[0], nil +} + +// ReadLine Return a valid reply, it will check the protocol or redis error, +// and discard the attribute type. func (r *Reader) ReadLine() ([]byte, error) { line, err := r.readLine() if err != nil { return nil, err } - if isNilReply(line) { + switch line[0] { + case RespError: + return nil, ParseErrorReply(line) + case RespNil: + return nil, Nil + case RespBlobError: + var blobErr string + blobErr, err = r.readStringReply(line) + if err == nil { + err = RedisError(blobErr) + } + return nil, err + case RespAttr: + if err = r.Discard(line); err != nil { + return nil, err + } + return r.ReadLine() + } + + // Compatible with RESP2 + if IsNilReply(line) { return nil, Nil } + return line, nil } @@ -92,48 +150,192 @@ func (r *Reader) readLine() ([]byte, error) { return b[:len(b)-2], nil } -func (r *Reader) ReadReply(m MultiBulkParse) (interface{}, error) { +func (r *Reader) ReadReply() (interface{}, error) { line, err := r.ReadLine() if err != nil { return nil, err } switch line[0] { - case ErrorReply: - return nil, ParseErrorReply(line) - case StatusReply: + case RespStatus: return string(line[1:]), nil - case IntReply: + case RespInt: return util.ParseInt(line[1:], 10, 64) - case StringReply: + case RespFloat: + return r.readFloat(line) + case RespBool: + return r.readBool(line) + case RespBigInt: + return r.readBigInt(line) + + case RespString: return r.readStringReply(line) - case ArrayReply: - n, err := parseArrayLen(line) - if err != nil { - return nil, err - } - if m == nil { - err := fmt.Errorf("redis: got %.100q, but multi bulk parser is nil", line) - return nil, err - } - return m(r, n) + case RespVerbatim: + return r.readVerb(line) + + case RespArray, RespSet, RespPush: + return r.readSlice(line) + case RespMap: + return r.readMap(line) } return nil, fmt.Errorf("redis: can't parse %.100q", line) } -func (r *Reader) ReadIntReply() (int64, error) { +func (r *Reader) readFloat(line []byte) (float64, error) { + v := string(line[1:]) + switch string(line[1:]) { + case "inf": + return math.Inf(1), nil + case "-inf": + return math.Inf(-1), nil + } + return strconv.ParseFloat(v, 64) +} + +func (r *Reader) readBool(line []byte) (bool, error) { + switch string(line[1:]) { + case "t": + return true, nil + case "f": + return false, nil + } + return false, fmt.Errorf("redis: can't parse bool reply: %q", line) +} + +func (r *Reader) readBigInt(line []byte) (*big.Int, error) { + i := new(big.Int) + if i, ok := i.SetString(string(line[1:]), 10); ok { + return i, nil + } + return nil, fmt.Errorf("redis: can't parse bigInt reply: %q", line) +} + +func (r *Reader) readStringReply(line []byte) (string, error) { + n, err := replyLen(line) + if err != nil { + return "", err + } + + b := make([]byte, n+2) + _, err = io.ReadFull(r.rd, b) + if err != nil { + return "", err + } + + return util.BytesToString(b[:n]), nil +} + +func (r *Reader) readVerb(line []byte) (string, error) { + s, err := r.readStringReply(line) + if err != nil { + return "", err + } + if len(s) < 4 || s[3] != ':' { + return "", fmt.Errorf("redis: can't parse verbatim string reply: %q", line) + } + return s[4:], nil +} + +func (r *Reader) readSlice(line []byte) ([]interface{}, error) { + n, err := replyLen(line) + if err != nil { + return nil, err + } + + val := make([]interface{}, n) + for i := 0; i < len(val); i++ { + v, err := r.ReadReply() + if err != nil { + if err == Nil { + val[i] = nil + continue + } + if err, ok := err.(RedisError); ok { + val[i] = err + continue + } + return nil, err + } + val[i] = v + } + return val, nil +} + +func (r *Reader) readMap(line []byte) (map[interface{}]interface{}, error) { + n, err := replyLen(line) + if err != nil { + return nil, err + } + m := make(map[interface{}]interface{}, n) + for i := 0; i < n; i++ { + k, err := r.ReadReply() + if err != nil { + return nil, err + } + v, err := r.ReadReply() + if err != nil { + if err == Nil { + m[k] = nil + continue + } + if err, ok := err.(RedisError); ok { + m[k] = err + continue + } + return nil, err + } + m[k] = v + } + return m, nil +} + +// ------------------------------- + +func (r *Reader) ReadInt() (int64, error) { line, err := r.ReadLine() if err != nil { return 0, err } switch line[0] { - case ErrorReply: - return 0, ParseErrorReply(line) - case IntReply: + case RespInt, RespStatus: return util.ParseInt(line[1:], 10, 64) - default: - return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line) + case RespString: + s, err := r.readStringReply(line) + if err != nil { + return 0, err + } + return util.ParseInt([]byte(s), 10, 64) + case RespBigInt: + b, err := r.readBigInt(line) + if err != nil { + return 0, err + } + if !b.IsInt64() { + return 0, fmt.Errorf("bigInt(%s) value out of range", b.String()) + } + return b.Int64(), nil } + return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line) +} + +func (r *Reader) ReadFloat() (float64, error) { + line, err := r.ReadLine() + if err != nil { + return 0, err + } + switch line[0] { + case RespFloat: + return r.readFloat(line) + case RespStatus: + return strconv.ParseFloat(string(line[1:]), 64) + case RespString: + s, err := r.readStringReply(line) + if err != nil { + return 0, err + } + return strconv.ParseFloat(s, 64) + } + return 0, fmt.Errorf("redis: can't parse float reply: %.100q", line) } func (r *Reader) ReadString() (string, error) { @@ -141,191 +343,180 @@ func (r *Reader) ReadString() (string, error) { if err != nil { return "", err } + switch line[0] { - case ErrorReply: - return "", ParseErrorReply(line) - case StringReply: + case RespStatus, RespInt, RespFloat: + return string(line[1:]), nil + case RespString: return r.readStringReply(line) - case StatusReply: - return string(line[1:]), nil - case IntReply: - return string(line[1:]), nil - default: - return "", fmt.Errorf("redis: can't parse reply=%.100q reading string", line) + case RespBool: + b, err := r.readBool(line) + return strconv.FormatBool(b), err + case RespVerbatim: + return r.readVerb(line) + case RespBigInt: + b, err := r.readBigInt(line) + if err != nil { + return "", err + } + return b.String(), nil } + return "", fmt.Errorf("redis: can't parse reply=%.100q reading string", line) } -func (r *Reader) readStringReply(line []byte) (string, error) { - if isNilReply(line) { - return "", Nil - } - - replyLen, err := util.Atoi(line[1:]) +func (r *Reader) ReadBool() (bool, error) { + s, err := r.ReadString() if err != nil { - return "", err + return false, err } - - b := make([]byte, replyLen+2) - _, err = io.ReadFull(r.rd, b) - if err != nil { - return "", err - } - - return util.BytesToString(b[:replyLen]), nil + return s == "OK" || s == "1" || s == "true", nil } -func (r *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { +func (r *Reader) ReadSlice() ([]interface{}, error) { line, err := r.ReadLine() if err != nil { return nil, err } - switch line[0] { - case ErrorReply: - return nil, ParseErrorReply(line) - case ArrayReply: - n, err := parseArrayLen(line) - if err != nil { - return nil, err - } - return m(r, n) - default: - return nil, fmt.Errorf("redis: can't parse array reply: %.100q", line) - } + return r.readSlice(line) } +// ReadFixedArrayLen read fixed array length. +func (r *Reader) ReadFixedArrayLen(fixedLen int) error { + n, err := r.ReadArrayLen() + if err != nil { + return err + } + if n != fixedLen { + return fmt.Errorf("redis: got %d elements of array length, wanted %d", n, fixedLen) + } + return nil +} + +// ReadArrayLen Read and return the length of the array. func (r *Reader) ReadArrayLen() (int, error) { line, err := r.ReadLine() if err != nil { return 0, err } switch line[0] { - case ErrorReply: - return 0, ParseErrorReply(line) - case ArrayReply: - n, err := parseArrayLen(line) + case RespArray, RespSet, RespPush: + return replyLen(line) + default: + return 0, fmt.Errorf("redis: can't parse array(array/set/push) reply: %.100q", line) + } +} + +// ReadFixedMapLen read fixed map length. +func (r *Reader) ReadFixedMapLen(fixedLen int) error { + n, err := r.ReadMapLen() + if err != nil { + return err + } + if n != fixedLen { + return fmt.Errorf("redis: got %d elements of map length, wanted %d", n, fixedLen) + } + return nil +} + +// ReadMapLen read the length of the map type. +// If responding to the array type (RespArray/RespSet/RespPush), +// it must be a multiple of 2 and return n/2. +// Other types will return an error. +func (r *Reader) ReadMapLen() (int, error) { + line, err := r.ReadLine() + if err != nil { + return 0, err + } + switch line[0] { + case RespMap: + return replyLen(line) + case RespArray, RespSet, RespPush: + // Some commands and RESP2 protocol may respond to array types. + n, err := replyLen(line) if err != nil { return 0, err } - return int(n), nil - default: - return 0, fmt.Errorf("redis: can't parse array reply: %.100q", line) - } -} - -func (r *Reader) ReadScanReply() ([]string, uint64, error) { - n, err := r.ReadArrayLen() - if err != nil { - return nil, 0, err - } - if n != 2 { - return nil, 0, fmt.Errorf("redis: got %d elements in scan reply, expected 2", n) - } - - cursor, err := r.ReadUint() - if err != nil { - return nil, 0, err - } - - n, err = r.ReadArrayLen() - if err != nil { - return nil, 0, err - } - - keys := make([]string, n) - - for i := 0; i < n; i++ { - key, err := r.ReadString() - if err != nil { - return nil, 0, err + if n%2 != 0 { + return 0, fmt.Errorf("redis: the length of the array must be a multiple of 2, got: %d", n) } - keys[i] = key + return n / 2, nil + default: + return 0, fmt.Errorf("redis: can't parse map reply: %.100q", line) } - - return keys, cursor, err } -func (r *Reader) ReadInt() (int64, error) { - b, err := r.readTmpBytesReply() - if err != nil { - return 0, err - } - return util.ParseInt(b, 10, 64) -} - -func (r *Reader) ReadUint() (uint64, error) { - b, err := r.readTmpBytesReply() - if err != nil { - return 0, err - } - return util.ParseUint(b, 10, 64) -} - -func (r *Reader) ReadFloatReply() (float64, error) { - b, err := r.readTmpBytesReply() - if err != nil { - return 0, err - } - return util.ParseFloat(b, 64) -} - -func (r *Reader) readTmpBytesReply() ([]byte, error) { - line, err := r.ReadLine() - if err != nil { - return nil, err +// Discard the data represented by line. +func (r *Reader) Discard(line []byte) (err error) { + if len(line) == 0 { + return errors.New("redis: invalid line") } switch line[0] { - case ErrorReply: - return nil, ParseErrorReply(line) - case StringReply: - return r._readTmpBytesReply(line) - case StatusReply: - return line[1:], nil - default: - return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) + case RespStatus, RespError, RespInt, RespNil, RespFloat, RespBool, RespBigInt: + return nil } + + n, err := replyLen(line) + if err != nil && err != Nil { + return err + } + + switch line[0] { + case RespBlobError, RespString, RespVerbatim: + // +\r\n + _, err = r.rd.Discard(n + 2) + return err + case RespArray, RespSet, RespPush: + for i := 0; i < n; i++ { + if err = r.DiscardNext(); err != nil { + return err + } + } + return nil + case RespMap, RespAttr: + // Read key & value. + for i := 0; i < n*2; i++ { + if err = r.DiscardNext(); err != nil { + return err + } + } + return nil + } + + return fmt.Errorf("redis: can't parse %.100q", line) } -func (r *Reader) _readTmpBytesReply(line []byte) ([]byte, error) { - if isNilReply(line) { - return nil, Nil - } - - replyLen, err := util.Atoi(line[1:]) +// DiscardNext read and discard the data represented by the next line. +func (r *Reader) DiscardNext() error { + line, err := r.readLine() if err != nil { - return nil, err + return err } + return r.Discard(line) +} - buf := r.buf(replyLen + 2) - _, err = io.ReadFull(r.rd, buf) +func replyLen(line []byte) (n int, err error) { + n, err = util.Atoi(line[1:]) if err != nil { - return nil, err + return 0, err } - return buf[:replyLen], nil -} - -func (r *Reader) buf(n int) []byte { - if n <= cap(r._buf) { - return r._buf[:n] + if n < -1 { + return 0, fmt.Errorf("redis: invalid reply: %q", line) } - d := n - cap(r._buf) - r._buf = append(r._buf, make([]byte, d)...) - return r._buf -} -func isNilReply(b []byte) bool { - return len(b) == 3 && - (b[0] == StringReply || b[0] == ArrayReply) && - b[1] == '-' && b[2] == '1' -} - -func ParseErrorReply(line []byte) error { - return RedisError(string(line[1:])) -} - -func parseArrayLen(line []byte) (int64, error) { - if isNilReply(line) { - return 0, Nil + switch line[0] { + case RespString, RespVerbatim, RespBlobError, + RespArray, RespSet, RespPush, RespMap, RespAttr: + if n == -1 { + return 0, Nil + } } - return util.ParseInt(line[1:], 10, 64) + return n, nil +} + +// IsNilReply detect redis.Nil of RESP2. +func IsNilReply(line []byte) bool { + return len(line) == 3 && + (line[0] == RespString || line[0] == RespArray) && + line[1] == '-' && line[2] == '1' } diff --git a/internal/proto/reader_test.go b/internal/proto/reader_test.go index b8c99dd..9881047 100644 --- a/internal/proto/reader_test.go +++ b/internal/proto/reader_test.go @@ -9,23 +9,63 @@ import ( ) func BenchmarkReader_ParseReply_Status(b *testing.B) { - benchmarkParseReply(b, "+OK\r\n", nil, false) + benchmarkParseReply(b, "+OK\r\n", false) } func BenchmarkReader_ParseReply_Int(b *testing.B) { - benchmarkParseReply(b, ":1\r\n", nil, false) + benchmarkParseReply(b, ":1\r\n", false) +} + +func BenchmarkReader_ParseReply_Float(b *testing.B) { + benchmarkParseReply(b, ",123.456\r\n", false) +} + +func BenchmarkReader_ParseReply_Bool(b *testing.B) { + benchmarkParseReply(b, "#t\r\n", false) +} + +func BenchmarkReader_ParseReply_BigInt(b *testing.B) { + benchmarkParseReply(b, "(3492890328409238509324850943850943825024385\r\n", false) } func BenchmarkReader_ParseReply_Error(b *testing.B) { - benchmarkParseReply(b, "-Error message\r\n", nil, true) + benchmarkParseReply(b, "-Error message\r\n", true) +} + +func BenchmarkReader_ParseReply_Nil(b *testing.B) { + benchmarkParseReply(b, "_\r\n", true) +} + +func BenchmarkReader_ParseReply_BlobError(b *testing.B) { + benchmarkParseReply(b, "!21\r\nSYNTAX invalid syntax", true) } func BenchmarkReader_ParseReply_String(b *testing.B) { - benchmarkParseReply(b, "$5\r\nhello\r\n", nil, false) + benchmarkParseReply(b, "$5\r\nhello\r\n", false) +} + +func BenchmarkReader_ParseReply_Verb(b *testing.B) { + benchmarkParseReply(b, "$9\r\ntxt:hello\r\n", false) } func BenchmarkReader_ParseReply_Slice(b *testing.B) { - benchmarkParseReply(b, "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", multiBulkParse, false) + benchmarkParseReply(b, "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", false) +} + +func BenchmarkReader_ParseReply_Set(b *testing.B) { + benchmarkParseReply(b, "~2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", false) +} + +func BenchmarkReader_ParseReply_Push(b *testing.B) { + benchmarkParseReply(b, ">2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", false) +} + +func BenchmarkReader_ParseReply_Map(b *testing.B) { + benchmarkParseReply(b, "%2\r\n$5\r\nhello\r\n$5\r\nworld\r\n+key\r\n+value\r\n", false) +} + +func BenchmarkReader_ParseReply_Attr(b *testing.B) { + benchmarkParseReply(b, "%1\r\n+key\r\n+value\r\n+hello\r\n", false) } func TestReader_ReadLine(t *testing.T) { @@ -43,7 +83,7 @@ func TestReader_ReadLine(t *testing.T) { } } -func benchmarkParseReply(b *testing.B, reply string, m proto.MultiBulkParse, wanterr bool) { +func benchmarkParseReply(b *testing.B, reply string, wanterr bool) { buf := new(bytes.Buffer) for i := 0; i < b.N; i++ { buf.WriteString(reply) @@ -52,21 +92,9 @@ func benchmarkParseReply(b *testing.B, reply string, m proto.MultiBulkParse, wan b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := p.ReadReply(m) + _, err := p.ReadReply() if !wanterr && err != nil { b.Fatal(err) } } } - -func multiBulkParse(p *proto.Reader, n int64) (interface{}, error) { - vv := make([]interface{}, 0, n) - for i := int64(0); i < n; i++ { - v, err := p.ReadReply(multiBulkParse) - if err != nil { - return nil, err - } - vv = append(vv, v) - } - return vv, nil -} diff --git a/internal/proto/writer.go b/internal/proto/writer.go index 81b09b8..72b3044 100644 --- a/internal/proto/writer.go +++ b/internal/proto/writer.go @@ -34,7 +34,7 @@ func NewWriter(wr writer) *Writer { } func (w *Writer) WriteArgs(args []interface{}) error { - if err := w.WriteByte(ArrayReply); err != nil { + if err := w.WriteByte(RespArray); err != nil { return err } @@ -111,7 +111,7 @@ func (w *Writer) WriteArg(v interface{}) error { } func (w *Writer) bytes(b []byte) error { - if err := w.WriteByte(StringReply); err != nil { + if err := w.WriteByte(RespString); err != nil { return err } diff --git a/pubsub_test.go b/pubsub_test.go index d32d5e0..b9633b2 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -1,7 +1,6 @@ package redis_test import ( - "context" "io" "net" "sync" @@ -15,16 +14,11 @@ import ( var _ = Describe("PubSub", func() { var client *redis.Client - var clientID int64 BeforeEach(func() { opt := redisOptions() opt.MinIdleConns = 0 opt.MaxConnAge = 0 - opt.OnConnect = func(ctx context.Context, cn *redis.Conn) (err error) { - clientID, err = cn.ClientID(ctx).Result() - return err - } client = redis.NewClient(opt) Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred()) }) @@ -421,30 +415,6 @@ var _ = Describe("PubSub", func() { Expect(msg.Payload).To(Equal(string(bigVal))) }) - It("handles message payload slice with server-assisted client-size caching", func() { - pubsub := client.Subscribe(ctx, "__redis__:invalidate") - defer pubsub.Close() - - client2 := redis.NewClient(redisOptions()) - defer client2.Close() - - err := client2.Do(ctx, "CLIENT", "TRACKING", "on", "REDIRECT", clientID).Err() - Expect(err).NotTo(HaveOccurred()) - - err = client2.Do(ctx, "GET", "mykey").Err() - Expect(err).To(Equal(redis.Nil)) - - err = client2.Do(ctx, "SET", "mykey", "myvalue").Err() - Expect(err).NotTo(HaveOccurred()) - - ch := pubsub.Channel() - - var msg *redis.Message - Eventually(ch).Should(Receive(&msg)) - Expect(msg.Channel).To(Equal("__redis__:invalidate")) - Expect(msg.PayloadSlice).To(Equal([]string{"mykey"})) - }) - It("supports concurrent Ping and Receive", func() { const N = 100 diff --git a/redis.go b/redis.go index 7995c43..8bf4403 100644 --- a/redis.go +++ b/redis.go @@ -230,21 +230,21 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } cn.Inited = true - if c.opt.Password == "" && - c.opt.DB == 0 && - !c.opt.readOnly && - c.opt.OnConnect == nil { - return nil - } - ctx, span := internal.StartSpan(ctx, "redis.init_conn") defer span.End() connPool := pool.NewSingleConnPool(c.connPool, cn) conn := newConn(ctx, c.opt, connPool) + var auth bool + + // The low version of redis-server does not support the hello command. + if conn.Hello(ctx, 3, c.opt.Username, c.opt.Password, "").Err() == nil { + auth = true + } + _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error { - if c.opt.Password != "" { + if !auth && c.opt.Password != "" { if c.opt.Username != "" { pipe.AuthACL(ctx, c.opt.Username, c.opt.Password) } else { @@ -542,14 +542,8 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) return err } - switch line[0] { - case proto.ErrorReply: - return proto.ParseErrorReply(line) - case proto.ArrayReply: - // ok - default: - err := fmt.Errorf("redis: expected '*', but got line %q", line) - return err + if line[0] != proto.RespArray { + return fmt.Errorf("redis: expected '*', but got line %q", line) } return nil diff --git a/result.go b/result.go index 24cfd49..5043bf9 100644 --- a/result.go +++ b/result.go @@ -83,8 +83,8 @@ func NewBoolSliceResult(val []bool, err error) *BoolSliceCmd { } // NewStringStringMapResult returns a StringStringMapCmd initialised with val and err for testing. -func NewStringStringMapResult(val map[string]string, err error) *StringStringMapCmd { - var cmd StringStringMapCmd +func NewStringStringMapResult(val map[string]string, err error) *MapStringStringCmd { + var cmd MapStringStringCmd cmd.val = val cmd.SetErr(err) return &cmd diff --git a/ring_test.go b/ring_test.go index 2189cd6..4a434a5 100644 --- a/ring_test.go +++ b/ring_test.go @@ -177,6 +177,7 @@ var _ = Describe("Redis Ring", func() { It("can be initialized with a new client callback", func() { opts := redisRingOptions() opts.NewClient = func(name string, opt *redis.Options) *redis.Client { + opt.Username = "username1" opt.Password = "password1" return redis.NewClient(opt) } @@ -184,7 +185,7 @@ var _ = Describe("Redis Ring", func() { err := ring.Ping(ctx).Err() Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("ERR AUTH")) + Expect(err.Error()).To(ContainSubstring("WRONGPASS")) }) }) diff --git a/sentinel.go b/sentinel.go index ca2e088..15f3366 100644 --- a/sentinel.go +++ b/sentinel.go @@ -322,8 +322,8 @@ func (c *SentinelClient) GetMasterAddrByName(ctx context.Context, name string) * return cmd } -func (c *SentinelClient) Sentinels(ctx context.Context, name string) *SliceCmd { - cmd := NewSliceCmd(ctx, "sentinel", "sentinels", name) +func (c *SentinelClient) Sentinels(ctx context.Context, name string) *MapStringStringSliceCmd { + cmd := NewMapStringStringSliceCmd(ctx, "sentinel", "sentinels", name) _ = c.Process(ctx, cmd) return cmd } @@ -355,8 +355,8 @@ func (c *SentinelClient) FlushConfig(ctx context.Context) *StatusCmd { } // Master shows the state and info of the specified master. -func (c *SentinelClient) Master(ctx context.Context, name string) *StringStringMapCmd { - cmd := NewStringStringMapCmd(ctx, "sentinel", "master", name) +func (c *SentinelClient) Master(ctx context.Context, name string) *MapStringStringCmd { + cmd := NewMapStringStringCmd(ctx, "sentinel", "master", name) _ = c.Process(ctx, cmd) return cmd } @@ -369,8 +369,8 @@ func (c *SentinelClient) Masters(ctx context.Context) *SliceCmd { } // Slaves shows a list of slaves for the specified master and their state. -func (c *SentinelClient) Slaves(ctx context.Context, name string) *SliceCmd { - cmd := NewSliceCmd(ctx, "sentinel", "slaves", name) +func (c *SentinelClient) Slaves(ctx context.Context, name string) *MapStringStringSliceCmd { + cmd := NewMapStringStringSliceCmd(ctx, "sentinel", "slaves", name) _ = c.Process(ctx, cmd) return cmd } @@ -588,40 +588,24 @@ func (c *sentinelFailover) getSlaveAddrs(ctx context.Context, sentinel *Sentinel return parseSlaveAddrs(addrs, false) } -func parseSlaveAddrs(addrs []interface{}, keepDisconnected bool) []string { +func parseSlaveAddrs(addrs []map[string]string, keepDisconnected bool) []string { nodes := make([]string, 0, len(addrs)) for _, node := range addrs { - ip := "" - port := "" - flags := []string{} - lastkey := "" isDown := false - - for _, key := range node.([]interface{}) { - switch lastkey { - case "ip": - ip = key.(string) - case "port": - port = key.(string) - case "flags": - flags = strings.Split(key.(string), ",") - } - lastkey = key.(string) - } - - for _, flag := range flags { - switch flag { - case "s_down", "o_down": - isDown = true - case "disconnected": - if !keepDisconnected { + if flags, ok := node["flags"]; ok { + for _, flag := range strings.Split(flags, ",") { + switch flag { + case "s_down", "o_down": isDown = true + case "disconnected": + if !keepDisconnected { + isDown = true + } } } } - - if !isDown { - nodes = append(nodes, net.JoinHostPort(ip, port)) + if !isDown && node["ip"] != "" && node["port"] != "" { + nodes = append(nodes, net.JoinHostPort(node["ip"], node["port"])) } } @@ -670,16 +654,13 @@ func (c *sentinelFailover) discoverSentinels(ctx context.Context) { return } for _, sentinel := range sentinels { - vals := sentinel.([]interface{}) - var ip, port string - for i := 0; i < len(vals); i += 2 { - key := vals[i].(string) - switch key { - case "ip": - ip = vals[i+1].(string) - case "port": - port = vals[i+1].(string) - } + ip, ok := sentinel["ip"] + if !ok { + continue + } + port, ok := sentinel["port"] + if !ok { + continue } if ip != "" && port != "" { sentinelAddr := net.JoinHostPort(ip, port) diff --git a/sentinel_test.go b/sentinel_test.go index 7b4aabd..f5cfa3d 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -185,7 +185,8 @@ var _ = Describe("NewFailoverClusterClient", func() { } // Create subscription. - ch := client.Subscribe(ctx, "foo").Channel() + sub := client.Subscribe(ctx, "foo") + ch := sub.Channel() // Kill master. err = master.Shutdown(ctx).Err() @@ -207,6 +208,7 @@ var _ = Describe("NewFailoverClusterClient", func() { }, "15s", "100ms").Should(Receive(&msg)) Expect(msg.Channel).To(Equal("foo")) Expect(msg.Payload).To(Equal("hello")) + Expect(sub.Close()).NotTo(HaveOccurred()) _, err = startRedis(masterPort) Expect(err).NotTo(HaveOccurred())