diff --git a/cluster.go b/cluster.go index ca8ca5b..6b05f1a 100644 --- a/cluster.go +++ b/cluster.go @@ -1333,8 +1333,8 @@ func (c *ClusterClient) remapCmds(cmds []Cmder, failedCmds map[*clusterNode][]Cm func (c *ClusterClient) pipelineProcessCmds( node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) error { - err := cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { - return writeCmd(wb, cmds...) + err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmds...) }) if err != nil { setCmdsErr(cmds, err) @@ -1342,14 +1342,14 @@ func (c *ClusterClient) pipelineProcessCmds( return err } - err = cn.WithReader(c.opt.ReadTimeout, func(rd proto.Reader) error { + err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { return c.pipelineReadCmds(rd, cmds, failedCmds) }) return err } func (c *ClusterClient) pipelineReadCmds( - rd proto.Reader, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, + rd *proto.Reader, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) error { for _, cmd := range cmds { err := cmd.readReply(rd) @@ -1476,8 +1476,8 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { func (c *ClusterClient) txPipelineProcessCmds( node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) error { - err := cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { - return txPipelineWriteMulti(wb, cmds) + err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return txPipelineWriteMulti(wr, cmds) }) if err != nil { setCmdsErr(cmds, err) @@ -1485,7 +1485,7 @@ func (c *ClusterClient) txPipelineProcessCmds( return err } - err = cn.WithReader(c.opt.ReadTimeout, func(rd proto.Reader) error { + err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { err := c.txPipelineReadQueued(rd, cmds, failedCmds) if err != nil { setCmdsErr(cmds, err) @@ -1497,7 +1497,7 @@ func (c *ClusterClient) txPipelineProcessCmds( } func (c *ClusterClient) txPipelineReadQueued( - rd proto.Reader, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, + rd *proto.Reader, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) error { // Parse queued replies. var statusCmd StatusCmd diff --git a/cluster_test.go b/cluster_test.go index a37f056..f0b3d74 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -985,7 +985,7 @@ func newClusterScenario() *clusterScenario { } } -func BenchmarkRedisClusterPing(b *testing.B) { +func BenchmarkClusterPing(b *testing.B) { if testing.Short() { b.Skip("skipping in short mode") } @@ -1011,7 +1011,7 @@ func BenchmarkRedisClusterPing(b *testing.B) { }) } -func BenchmarkRedisClusterSetString(b *testing.B) { +func BenchmarkClusterSetString(b *testing.B) { if testing.Short() { b.Skip("skipping in short mode") } @@ -1039,7 +1039,7 @@ func BenchmarkRedisClusterSetString(b *testing.B) { }) } -func BenchmarkRedisClusterReloadState(b *testing.B) { +func BenchmarkClusterReloadState(b *testing.B) { if testing.Short() { b.Skip("skipping in short mode") } diff --git a/command.go b/command.go index f20d5ea..44a4987 100644 --- a/command.go +++ b/command.go @@ -16,7 +16,7 @@ type Cmder interface { Args() []interface{} stringArg(int) string - readReply(rd proto.Reader) error + readReply(rd *proto.Reader) error setErr(error) readTimeout() *time.Duration @@ -41,9 +41,9 @@ func cmdsFirstErr(cmds []Cmder) error { return nil } -func writeCmd(wb *proto.WriteBuffer, cmds ...Cmder) error { +func writeCmd(wr *proto.Writer, cmds ...Cmder) error { for _, cmd := range cmds { - err := wb.Append(cmd.Args()) + err := wr.WriteArgs(cmd.Args()) if err != nil { return err } @@ -233,13 +233,13 @@ func (cmd *Cmd) Bool() (bool, error) { } } -func (cmd *Cmd) readReply(rd proto.Reader) error { +func (cmd *Cmd) readReply(rd *proto.Reader) error { cmd.val, cmd.err = rd.ReadReply(sliceParser) return cmd.err } // Implements proto.MultiBulkParse -func sliceParser(rd proto.Reader, n int64) (interface{}, error) { +func sliceParser(rd *proto.Reader, n int64) (interface{}, error) { vals := make([]interface{}, 0, n) for i := int64(0); i < n; i++ { v, err := rd.ReadReply(sliceParser) @@ -293,7 +293,7 @@ func (cmd *SliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *SliceCmd) readReply(rd proto.Reader) error { +func (cmd *SliceCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(sliceParser) if cmd.err != nil { @@ -331,7 +331,7 @@ func (cmd *StatusCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StatusCmd) readReply(rd proto.Reader) error { +func (cmd *StatusCmd) readReply(rd *proto.Reader) error { cmd.val, cmd.err = rd.ReadString() return cmd.err } @@ -364,7 +364,7 @@ func (cmd *IntCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *IntCmd) readReply(rd proto.Reader) error { +func (cmd *IntCmd) readReply(rd *proto.Reader) error { cmd.val, cmd.err = rd.ReadIntReply() return cmd.err } @@ -399,7 +399,7 @@ func (cmd *DurationCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *DurationCmd) readReply(rd proto.Reader) error { +func (cmd *DurationCmd) readReply(rd *proto.Reader) error { var n int64 n, cmd.err = rd.ReadIntReply() if cmd.err != nil { @@ -437,7 +437,7 @@ func (cmd *TimeCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *TimeCmd) readReply(rd proto.Reader) error { +func (cmd *TimeCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(timeParser) if cmd.err != nil { @@ -448,7 +448,7 @@ func (cmd *TimeCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func timeParser(rd proto.Reader, n int64) (interface{}, error) { +func timeParser(rd *proto.Reader, n int64) (interface{}, error) { if n != 2 { return nil, fmt.Errorf("got %d elements, expected 2", n) } @@ -494,7 +494,7 @@ func (cmd *BoolCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *BoolCmd) readReply(rd proto.Reader) error { +func (cmd *BoolCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadReply(nil) // `SET key value NX` returns nil when key already exists. But @@ -581,7 +581,7 @@ func (cmd *StringCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringCmd) readReply(rd proto.Reader) error { +func (cmd *StringCmd) readReply(rd *proto.Reader) error { cmd.val, cmd.err = rd.ReadString() return cmd.err } @@ -614,7 +614,7 @@ func (cmd *FloatCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *FloatCmd) readReply(rd proto.Reader) error { +func (cmd *FloatCmd) readReply(rd *proto.Reader) error { cmd.val, cmd.err = rd.ReadFloatReply() return cmd.err } @@ -651,7 +651,7 @@ func (cmd *StringSliceCmd) ScanSlice(container interface{}) error { return proto.ScanSlice(cmd.Val(), container) } -func (cmd *StringSliceCmd) readReply(rd proto.Reader) error { +func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(stringSliceParser) if cmd.err != nil { @@ -662,7 +662,7 @@ func (cmd *StringSliceCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func stringSliceParser(rd proto.Reader, n int64) (interface{}, error) { +func stringSliceParser(rd *proto.Reader, n int64) (interface{}, error) { ss := make([]string, 0, n) for i := int64(0); i < n; i++ { s, err := rd.ReadString() @@ -705,7 +705,7 @@ func (cmd *BoolSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *BoolSliceCmd) readReply(rd proto.Reader) error { +func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(boolSliceParser) if cmd.err != nil { @@ -716,7 +716,7 @@ func (cmd *BoolSliceCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func boolSliceParser(rd proto.Reader, n int64) (interface{}, error) { +func boolSliceParser(rd *proto.Reader, n int64) (interface{}, error) { bools := make([]bool, 0, n) for i := int64(0); i < n; i++ { n, err := rd.ReadIntReply() @@ -756,7 +756,7 @@ func (cmd *StringStringMapCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringStringMapCmd) readReply(rd proto.Reader) error { +func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(stringStringMapParser) if cmd.err != nil { @@ -767,7 +767,7 @@ func (cmd *StringStringMapCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func stringStringMapParser(rd proto.Reader, n int64) (interface{}, error) { +func stringStringMapParser(rd *proto.Reader, n int64) (interface{}, error) { m := make(map[string]string, n/2) for i := int64(0); i < n; i += 2 { key, err := rd.ReadString() @@ -813,7 +813,7 @@ func (cmd *StringIntMapCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringIntMapCmd) readReply(rd proto.Reader) error { +func (cmd *StringIntMapCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(stringIntMapParser) if cmd.err != nil { @@ -824,7 +824,7 @@ func (cmd *StringIntMapCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func stringIntMapParser(rd proto.Reader, n int64) (interface{}, error) { +func stringIntMapParser(rd *proto.Reader, n int64) (interface{}, error) { m := make(map[string]int64, n/2) for i := int64(0); i < n; i += 2 { key, err := rd.ReadString() @@ -870,7 +870,7 @@ func (cmd *StringStructMapCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringStructMapCmd) readReply(rd proto.Reader) error { +func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(stringStructMapParser) if cmd.err != nil { @@ -881,7 +881,7 @@ func (cmd *StringStructMapCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func stringStructMapParser(rd proto.Reader, n int64) (interface{}, error) { +func stringStructMapParser(rd *proto.Reader, n int64) (interface{}, error) { m := make(map[string]struct{}, n) for i := int64(0); i < n; i++ { key, err := rd.ReadString() @@ -927,7 +927,7 @@ func (cmd *XMessageSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *XMessageSliceCmd) readReply(rd proto.Reader) error { +func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(xMessageSliceParser) if cmd.err != nil { @@ -938,10 +938,10 @@ func (cmd *XMessageSliceCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func xMessageSliceParser(rd proto.Reader, n int64) (interface{}, error) { +func xMessageSliceParser(rd *proto.Reader, n int64) (interface{}, error) { msgs := make([]XMessage, 0, n) for i := int64(0); i < n; i++ { - _, err := rd.ReadArrayReply(func(rd proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { id, err := rd.ReadString() if err != nil { return nil, err @@ -966,7 +966,7 @@ func xMessageSliceParser(rd proto.Reader, n int64) (interface{}, error) { } // Implements proto.MultiBulkParse -func stringInterfaceMapParser(rd proto.Reader, n int64) (interface{}, error) { +func stringInterfaceMapParser(rd *proto.Reader, n int64) (interface{}, error) { m := make(map[string]interface{}, n/2) for i := int64(0); i < n; i += 2 { key, err := rd.ReadString() @@ -1017,7 +1017,7 @@ func (cmd *XStreamSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *XStreamSliceCmd) readReply(rd proto.Reader) error { +func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(xStreamSliceParser) if cmd.err != nil { @@ -1028,10 +1028,10 @@ func (cmd *XStreamSliceCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func xStreamSliceParser(rd proto.Reader, n int64) (interface{}, error) { +func xStreamSliceParser(rd *proto.Reader, n int64) (interface{}, error) { ret := make([]XStream, 0, n) for i := int64(0); i < n; i++ { - _, err := rd.ReadArrayReply(func(rd proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { if n != 2 { return nil, fmt.Errorf("got %d, wanted 2", n) } @@ -1093,7 +1093,7 @@ func (cmd *XPendingCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *XPendingCmd) readReply(rd proto.Reader) error { +func (cmd *XPendingCmd) readReply(rd *proto.Reader) error { var info interface{} info, cmd.err = rd.ReadArrayReply(xPendingParser) if cmd.err != nil { @@ -1103,7 +1103,7 @@ func (cmd *XPendingCmd) readReply(rd proto.Reader) error { return nil } -func xPendingParser(rd proto.Reader, n int64) (interface{}, error) { +func xPendingParser(rd *proto.Reader, n int64) (interface{}, error) { if n != 4 { return nil, fmt.Errorf("got %d, wanted 4", n) } @@ -1128,9 +1128,9 @@ func xPendingParser(rd proto.Reader, n int64) (interface{}, error) { Lower: lower, Higher: higher, } - _, err = rd.ReadArrayReply(func(rd proto.Reader, n int64) (interface{}, error) { + _, err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { for i := int64(0); i < n; i++ { - _, err = rd.ReadArrayReply(func(rd proto.Reader, n int64) (interface{}, error) { + _, err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { if n != 2 { return nil, fmt.Errorf("got %d, wanted 2", n) } @@ -1199,7 +1199,7 @@ func (cmd *XPendingExtCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *XPendingExtCmd) readReply(rd proto.Reader) error { +func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error { var info interface{} info, cmd.err = rd.ReadArrayReply(xPendingExtSliceParser) if cmd.err != nil { @@ -1209,10 +1209,10 @@ func (cmd *XPendingExtCmd) readReply(rd proto.Reader) error { return nil } -func xPendingExtSliceParser(rd proto.Reader, n int64) (interface{}, error) { +func xPendingExtSliceParser(rd *proto.Reader, n int64) (interface{}, error) { ret := make([]XPendingExt, 0, n) for i := int64(0); i < n; i++ { - _, err := rd.ReadArrayReply(func(rd proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { if n != 4 { return nil, fmt.Errorf("got %d, wanted 4", n) } @@ -1282,7 +1282,7 @@ func (cmd *ZSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ZSliceCmd) readReply(rd proto.Reader) error { +func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(zSliceParser) if cmd.err != nil { @@ -1293,7 +1293,7 @@ func (cmd *ZSliceCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func zSliceParser(rd proto.Reader, n int64) (interface{}, error) { +func zSliceParser(rd *proto.Reader, n int64) (interface{}, error) { zz := make([]Z, n/2) for i := int64(0); i < n; i += 2 { var err error @@ -1345,7 +1345,7 @@ func (cmd *ScanCmd) String() string { return cmdString(cmd, cmd.page) } -func (cmd *ScanCmd) readReply(rd proto.Reader) error { +func (cmd *ScanCmd) readReply(rd *proto.Reader) error { cmd.page, cmd.cursor, cmd.err = rd.ReadScanReply() return cmd.err } @@ -1396,7 +1396,7 @@ func (cmd *ClusterSlotsCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ClusterSlotsCmd) readReply(rd proto.Reader) error { +func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(clusterSlotsParser) if cmd.err != nil { @@ -1407,7 +1407,7 @@ func (cmd *ClusterSlotsCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func clusterSlotsParser(rd proto.Reader, n int64) (interface{}, error) { +func clusterSlotsParser(rd *proto.Reader, n int64) (interface{}, error) { slots := make([]ClusterSlot, n) for i := 0; i < len(slots); i++ { n, err := rd.ReadArrayLen() @@ -1551,7 +1551,7 @@ func (cmd *GeoLocationCmd) String() string { return cmdString(cmd, cmd.locations) } -func (cmd *GeoLocationCmd) readReply(rd proto.Reader) error { +func (cmd *GeoLocationCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) if cmd.err != nil { @@ -1562,7 +1562,7 @@ func (cmd *GeoLocationCmd) readReply(rd proto.Reader) error { } func newGeoLocationParser(q *GeoRadiusQuery) proto.MultiBulkParse { - return func(rd proto.Reader, n int64) (interface{}, error) { + return func(rd *proto.Reader, n int64) (interface{}, error) { var loc GeoLocation var err error @@ -1606,7 +1606,7 @@ func newGeoLocationParser(q *GeoRadiusQuery) proto.MultiBulkParse { } func newGeoLocationSliceParser(q *GeoRadiusQuery) proto.MultiBulkParse { - return func(rd proto.Reader, n int64) (interface{}, error) { + return func(rd *proto.Reader, n int64) (interface{}, error) { locs := make([]GeoLocation, 0, n) for i := int64(0); i < n; i++ { v, err := rd.ReadReply(newGeoLocationParser(q)) @@ -1660,7 +1660,7 @@ func (cmd *GeoPosCmd) String() string { return cmdString(cmd, cmd.positions) } -func (cmd *GeoPosCmd) readReply(rd proto.Reader) error { +func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(geoPosSliceParser) if cmd.err != nil { @@ -1670,7 +1670,7 @@ func (cmd *GeoPosCmd) readReply(rd proto.Reader) error { return nil } -func geoPosSliceParser(rd proto.Reader, n int64) (interface{}, error) { +func geoPosSliceParser(rd *proto.Reader, n int64) (interface{}, error) { positions := make([]*GeoPos, 0, n) for i := int64(0); i < n; i++ { v, err := rd.ReadReply(geoPosParser) @@ -1691,7 +1691,7 @@ func geoPosSliceParser(rd proto.Reader, n int64) (interface{}, error) { return positions, nil } -func geoPosParser(rd proto.Reader, n int64) (interface{}, error) { +func geoPosParser(rd *proto.Reader, n int64) (interface{}, error) { var pos GeoPos var err error @@ -1746,7 +1746,7 @@ func (cmd *CommandsInfoCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *CommandsInfoCmd) readReply(rd proto.Reader) error { +func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { var v interface{} v, cmd.err = rd.ReadArrayReply(commandInfoSliceParser) if cmd.err != nil { @@ -1757,7 +1757,7 @@ func (cmd *CommandsInfoCmd) readReply(rd proto.Reader) error { } // Implements proto.MultiBulkParse -func commandInfoSliceParser(rd proto.Reader, n int64) (interface{}, error) { +func commandInfoSliceParser(rd *proto.Reader, n int64) (interface{}, error) { m := make(map[string]*CommandInfo, n) for i := int64(0); i < n; i++ { v, err := rd.ReadReply(commandInfoParser) @@ -1771,7 +1771,7 @@ func commandInfoSliceParser(rd proto.Reader, n int64) (interface{}, error) { return m, nil } -func commandInfoParser(rd proto.Reader, n int64) (interface{}, error) { +func commandInfoParser(rd *proto.Reader, n int64) (interface{}, error) { var cmd CommandInfo var err error diff --git a/example_test.go b/example_test.go index 2a8a22b..6708e0c 100644 --- a/example_test.go +++ b/example_test.go @@ -20,7 +20,7 @@ func init() { PoolSize: 10, PoolTimeout: 30 * time.Second, }) - redisdb.FlushDB() + // redisdb.FlushDB() } func ExampleNewClient() { @@ -206,7 +206,7 @@ func ExampleClient_Scan() { for { var keys []string var err error - keys, cursor, err = redisdb.Scan(cursor, "", 10).Result() + keys, cursor, err = redisdb.Scan(cursor, "key*", 10).Result() if err != nil { panic(err) } diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 5d361d1..1095bfe 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -8,20 +8,14 @@ import ( "github.com/go-redis/redis/internal/proto" ) -func makeBuffer() []byte { - const defaulBufSize = 4096 - return make([]byte, defaulBufSize) -} - var noDeadline = time.Time{} type Conn struct { netConn net.Conn - buf []byte - rd proto.Reader + rd *proto.Reader rdLocked bool - wb *proto.WriteBuffer + wr *proto.Writer InitedAt time.Time pooled bool @@ -31,10 +25,9 @@ type Conn struct { func NewConn(netConn net.Conn) *Conn { cn := &Conn{ netConn: netConn, - buf: makeBuffer(), } - cn.rd = proto.NewReader(proto.NewElasticBufReader(netConn)) - cn.wb = proto.NewWriteBuffer() + cn.rd = proto.NewReader(netConn) + cn.wr = proto.NewWriter(netConn) cn.SetUsedAt(time.Now()) return cn } @@ -50,6 +43,7 @@ func (cn *Conn) SetUsedAt(tm time.Time) { func (cn *Conn) SetNetConn(netConn net.Conn) { cn.netConn = netConn cn.rd.Reset(netConn) + cn.wr.Reset(netConn) } func (cn *Conn) setReadTimeout(timeout time.Duration) error { @@ -78,40 +72,19 @@ func (cn *Conn) RemoteAddr() net.Addr { return cn.netConn.RemoteAddr() } -func (cn *Conn) LockReaderBuffer() { - cn.rdLocked = true - cn.rd.ResetBuffer(makeBuffer()) -} - -func (cn *Conn) WithReader(timeout time.Duration, fn func(rd proto.Reader) error) error { +func (cn *Conn) WithReader(timeout time.Duration, fn func(rd *proto.Reader) error) error { _ = cn.setReadTimeout(timeout) - - if !cn.rdLocked { - cn.rd.ResetBuffer(cn.buf) - } - - err := fn(cn.rd) - - if !cn.rdLocked { - cn.buf = cn.rd.Buffer() - } - - return err + return fn(cn.rd) } -func (cn *Conn) WithWriter(timeout time.Duration, fn func(wb *proto.WriteBuffer) error) error { +func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) error) error { _ = cn.setWriteTimeout(timeout) - cn.wb.ResetBuffer(cn.buf) - - firstErr := fn(cn.wb) - - _, err := cn.netConn.Write(cn.wb.Bytes()) - cn.buf = cn.wb.Buffer() + firstErr := fn(cn.wr) + err := cn.wr.Flush() if err != nil && firstErr == nil { firstErr = err } - return firstErr } diff --git a/internal/proto/elastic_reader.go b/internal/proto/elastic_reader.go deleted file mode 100644 index f075e86..0000000 --- a/internal/proto/elastic_reader.go +++ /dev/null @@ -1,205 +0,0 @@ -package proto - -import ( - "bytes" - "errors" - "io" -) - -const defaultBufSize = 4096 - -// ElasticBufReader is like bufio.Reader but instead of returning ErrBufferFull -// it automatically grows the buffer. -type ElasticBufReader struct { - buf []byte - rd io.Reader // reader provided by the client - r, w int // buf read and write positions - err error -} - -func NewElasticBufReader(rd io.Reader) *ElasticBufReader { - return &ElasticBufReader{ - rd: rd, - } -} - -func (b *ElasticBufReader) Reset(rd io.Reader) { - b.rd = rd - b.r, b.w = 0, 0 - b.err = nil -} - -func (b *ElasticBufReader) Buffer() []byte { - return b.buf -} - -func (b *ElasticBufReader) ResetBuffer(buf []byte) { - b.buf = buf - b.r, b.w = 0, 0 - b.err = nil -} - -// Buffered returns the number of bytes that can be read from the current buffer. -func (b *ElasticBufReader) Buffered() int { - return b.w - b.r -} - -func (b *ElasticBufReader) Bytes() []byte { - return b.buf[b.r:b.w] -} - -var errNegativeRead = errors.New("bufio: reader returned negative count from Read") - -// fill reads a new chunk into the buffer. -func (b *ElasticBufReader) fill() { - // Slide existing data to beginning. - if b.r > 0 { - copy(b.buf, b.buf[b.r:b.w]) - b.w -= b.r - b.r = 0 - } - - if b.w >= len(b.buf) { - panic("bufio: tried to fill full buffer") - } - - // Read new data: try a limited number of times. - const maxConsecutiveEmptyReads = 100 - for i := maxConsecutiveEmptyReads; i > 0; i-- { - n, err := b.rd.Read(b.buf[b.w:]) - if n < 0 { - panic(errNegativeRead) - } - b.w += n - if err != nil { - b.err = err - return - } - if n > 0 { - return - } - } - b.err = io.ErrNoProgress -} - -func (b *ElasticBufReader) readErr() error { - err := b.err - b.err = nil - return err -} - -func (b *ElasticBufReader) ReadSlice(delim byte) (line []byte, err error) { - for { - // Search buffer. - if i := bytes.IndexByte(b.buf[b.r:b.w], delim); i >= 0 { - line = b.buf[b.r : b.r+i+1] - b.r += i + 1 - break - } - - // Pending error? - if b.err != nil { - line = b.buf[b.r:b.w] - b.r = b.w - err = b.readErr() - break - } - - // Buffer full? - if b.Buffered() >= len(b.buf) { - b.grow(len(b.buf) + defaultBufSize) - } - - b.fill() // buffer is not full - } - - return -} - -func (b *ElasticBufReader) ReadLine() (line []byte, err error) { - line, err = b.ReadSlice('\n') - if len(line) == 0 { - if err != nil { - line = nil - } - return - } - err = nil - - if line[len(line)-1] == '\n' { - drop := 1 - if len(line) > 1 && line[len(line)-2] == '\r' { - drop = 2 - } - line = line[:len(line)-drop] - } - return -} - -func (b *ElasticBufReader) ReadByte() (byte, error) { - for b.r == b.w { - if b.err != nil { - return 0, b.readErr() - } - b.fill() // buffer is empty - } - c := b.buf[b.r] - b.r++ - return c, nil -} - -func (b *ElasticBufReader) ReadN(n int) ([]byte, error) { - b.grow(n) - for b.Buffered() < n { - // Pending error? - if b.err != nil { - buf := b.buf[b.r:b.w] - b.r = b.w - return buf, b.readErr() - } - - b.fill() - } - - buf := b.buf[b.r : b.r+n] - b.r += n - return buf, nil -} - -func (b *ElasticBufReader) grow(n int) { - if b.w-b.r >= n { - return - } - - // Slide existing data to beginning. - if b.r > 0 { - copy(b.buf, b.buf[b.r:b.w]) - b.w -= b.r - b.r = 0 - } - - // Extend buffer if needed. - if d := n - len(b.buf); d > 0 { - b.buf = append(b.buf, make([]byte, d)...) - } -} - -func (b *ElasticBufReader) Read(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, b.readErr() - } - - if b.r != b.w { - // copy as much as we can - n = copy(p, b.buf[b.r:b.w]) - b.r += n - return n, nil - } - - if b.err != nil { - return 0, b.readErr() - } - - n, b.err = b.rd.Read(p) - return n, b.readErr() -} diff --git a/internal/proto/reader.go b/internal/proto/reader.go index 43abed5..896b6f6 100644 --- a/internal/proto/reader.go +++ b/internal/proto/reader.go @@ -1,6 +1,7 @@ package proto import ( + "bufio" "fmt" "io" "strconv" @@ -26,39 +27,32 @@ func (e RedisError) Error() string { return string(e) } //------------------------------------------------------------------------------ -type MultiBulkParse func(Reader, int64) (interface{}, error) +type MultiBulkParse func(*Reader, int64) (interface{}, error) type Reader struct { - src *ElasticBufReader + rd *bufio.Reader + _buf []byte } -func NewReader(src *ElasticBufReader) Reader { - return Reader{ - src: src, +func NewReader(rd io.Reader) *Reader { + return &Reader{ + rd: bufio.NewReader(rd), + _buf: make([]byte, 64), } } -func (r Reader) Reset(rd io.Reader) { - r.src.Reset(rd) +func (r *Reader) Reset(rd io.Reader) { + r.rd.Reset(rd) } -func (r Reader) Buffer() []byte { - return r.src.Buffer() -} - -func (r Reader) ResetBuffer(buf []byte) { - r.src.ResetBuffer(buf) -} - -func (r Reader) Bytes() []byte { - return r.src.Bytes() -} - -func (r Reader) ReadLine() ([]byte, error) { - line, err := r.src.ReadLine() +func (r *Reader) ReadLine() ([]byte, error) { + line, isPrefix, err := r.rd.ReadLine() if err != nil { return nil, err } + if isPrefix { + return nil, bufio.ErrBufferFull + } if len(line) == 0 { return nil, fmt.Errorf("redis: reply is empty") } @@ -68,7 +62,7 @@ func (r Reader) ReadLine() ([]byte, error) { return line, nil } -func (r Reader) ReadReply(m MultiBulkParse) (interface{}, error) { +func (r *Reader) ReadReply(m MultiBulkParse) (interface{}, error) { line, err := r.ReadLine() if err != nil { return nil, err @@ -93,7 +87,7 @@ func (r Reader) ReadReply(m MultiBulkParse) (interface{}, error) { return nil, fmt.Errorf("redis: can't parse %.100q", line) } -func (r Reader) ReadIntReply() (int64, error) { +func (r *Reader) ReadIntReply() (int64, error) { line, err := r.ReadLine() if err != nil { return 0, err @@ -108,7 +102,7 @@ func (r Reader) ReadIntReply() (int64, error) { } } -func (r Reader) ReadString() (string, error) { +func (r *Reader) ReadString() (string, error) { line, err := r.ReadLine() if err != nil { return "", err @@ -127,7 +121,7 @@ func (r Reader) ReadString() (string, error) { } } -func (r Reader) readStringReply(line []byte) (string, error) { +func (r *Reader) readStringReply(line []byte) (string, error) { if isNilReply(line) { return "", Nil } @@ -138,7 +132,7 @@ func (r Reader) readStringReply(line []byte) (string, error) { } b := make([]byte, replyLen+2) - _, err = io.ReadFull(r.src, b) + _, err = io.ReadFull(r.rd, b) if err != nil { return "", err } @@ -146,7 +140,7 @@ func (r Reader) readStringReply(line []byte) (string, error) { return util.BytesToString(b[:replyLen]), nil } -func (r Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { +func (r *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { line, err := r.ReadLine() if err != nil { return nil, err @@ -165,7 +159,7 @@ func (r Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { } } -func (r Reader) ReadArrayLen() (int64, error) { +func (r *Reader) ReadArrayLen() (int64, error) { line, err := r.ReadLine() if err != nil { return 0, err @@ -180,7 +174,7 @@ func (r Reader) ReadArrayLen() (int64, error) { } } -func (r Reader) ReadScanReply() ([]string, uint64, error) { +func (r *Reader) ReadScanReply() ([]string, uint64, error) { n, err := r.ReadArrayLen() if err != nil { return nil, 0, err @@ -211,7 +205,7 @@ func (r Reader) ReadScanReply() ([]string, uint64, error) { return keys, cursor, err } -func (r Reader) ReadInt() (int64, error) { +func (r *Reader) ReadInt() (int64, error) { b, err := r.readTmpBytesReply() if err != nil { return 0, err @@ -219,7 +213,7 @@ func (r Reader) ReadInt() (int64, error) { return util.ParseInt(b, 10, 64) } -func (r Reader) ReadUint() (uint64, error) { +func (r *Reader) ReadUint() (uint64, error) { b, err := r.readTmpBytesReply() if err != nil { return 0, err @@ -227,7 +221,7 @@ func (r Reader) ReadUint() (uint64, error) { return util.ParseUint(b, 10, 64) } -func (r Reader) ReadFloatReply() (float64, error) { +func (r *Reader) ReadFloatReply() (float64, error) { b, err := r.readTmpBytesReply() if err != nil { return 0, err @@ -235,7 +229,7 @@ func (r Reader) ReadFloatReply() (float64, error) { return util.ParseFloat(b, 64) } -func (r Reader) readTmpBytesReply() ([]byte, error) { +func (r *Reader) readTmpBytesReply() ([]byte, error) { line, err := r.ReadLine() if err != nil { return nil, err @@ -252,7 +246,7 @@ func (r Reader) readTmpBytesReply() ([]byte, error) { } } -func (r Reader) _readTmpBytesReply(line []byte) ([]byte, error) { +func (r *Reader) _readTmpBytesReply(line []byte) ([]byte, error) { if isNilReply(line) { return nil, Nil } @@ -262,11 +256,20 @@ func (r Reader) _readTmpBytesReply(line []byte) ([]byte, error) { return nil, err } - b, err := r.src.ReadN(replyLen + 2) + buf := r.buf(replyLen + 2) + _, err = io.ReadFull(r.rd, buf) if err != nil { return nil, err } - return b[:replyLen], nil + + return buf[:replyLen], nil +} + +func (r *Reader) buf(n int) []byte { + if d := n - cap(r._buf); d > 0 { + r._buf = append(r._buf, make([]byte, d)...) + } + return r._buf[:n] } func isNilReply(b []byte) bool { diff --git a/internal/proto/reader_test.go b/internal/proto/reader_test.go index 687b20f..54e85af 100644 --- a/internal/proto/reader_test.go +++ b/internal/proto/reader_test.go @@ -8,8 +8,8 @@ import ( "github.com/go-redis/redis/internal/proto" ) -func newReader(s string) proto.Reader { - return proto.NewReader(proto.NewElasticBufReader(strings.NewReader(s))) +func newReader(s string) *proto.Reader { + return proto.NewReader(strings.NewReader(s)) } func BenchmarkReader_ParseReply_Status(b *testing.B) { @@ -37,7 +37,7 @@ func benchmarkParseReply(b *testing.B, reply string, m proto.MultiBulkParse, wan for i := 0; i < b.N; i++ { buf.WriteString(reply) } - p := proto.NewReader(proto.NewElasticBufReader(buf)) + p := proto.NewReader(buf) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -48,7 +48,7 @@ func benchmarkParseReply(b *testing.B, reply string, m proto.MultiBulkParse, wan } } -func multiBulkParse(p proto.Reader, n int64) (interface{}, error) { +func multiBulkParse(p *proto.Reader, n int64) (interface{}, error) { vv := make([]interface{}, 0, n) for i := int64(0); i < n; i++ { v, err := p.ReadReply(multiBulkParse) diff --git a/internal/proto/write_buffer.go b/internal/proto/write_buffer.go deleted file mode 100644 index 3e0e3ba..0000000 --- a/internal/proto/write_buffer.go +++ /dev/null @@ -1,127 +0,0 @@ -package proto - -import ( - "encoding" - "fmt" - "strconv" -) - -type WriteBuffer struct { - buf []byte -} - -func NewWriteBuffer() *WriteBuffer { - return &WriteBuffer{} -} - -func (w *WriteBuffer) Len() int { - return len(w.buf) -} - -func (w *WriteBuffer) Bytes() []byte { - return w.buf -} - -func (w *WriteBuffer) Reset() { - w.buf = w.buf[:0] -} - -func (w *WriteBuffer) Buffer() []byte { - return w.buf[:cap(w.buf)] -} - -func (w *WriteBuffer) ResetBuffer(buf []byte) { - w.buf = buf[:0] -} - -func (w *WriteBuffer) Append(args []interface{}) error { - w.buf = append(w.buf, ArrayReply) - w.buf = strconv.AppendUint(w.buf, uint64(len(args)), 10) - w.buf = append(w.buf, '\r', '\n') - - for _, arg := range args { - if err := w.append(arg); err != nil { - return err - } - } - return nil -} - -func (w *WriteBuffer) append(val interface{}) error { - switch v := val.(type) { - case nil: - w.AppendString("") - case string: - w.AppendString(v) - case []byte: - w.AppendBytes(v) - case int: - w.AppendString(formatInt(int64(v))) - case int8: - w.AppendString(formatInt(int64(v))) - case int16: - w.AppendString(formatInt(int64(v))) - case int32: - w.AppendString(formatInt(int64(v))) - case int64: - w.AppendString(formatInt(v)) - case uint: - w.AppendString(formatUint(uint64(v))) - case uint8: - w.AppendString(formatUint(uint64(v))) - case uint16: - w.AppendString(formatUint(uint64(v))) - case uint32: - w.AppendString(formatUint(uint64(v))) - case uint64: - w.AppendString(formatUint(v)) - case float32: - w.AppendString(formatFloat(float64(v))) - case float64: - w.AppendString(formatFloat(v)) - case bool: - if v { - w.AppendString("1") - } else { - w.AppendString("0") - } - case encoding.BinaryMarshaler: - b, err := v.MarshalBinary() - if err != nil { - return err - } - w.AppendBytes(b) - default: - return fmt.Errorf( - "redis: can't marshal %T (consider implementing encoding.BinaryMarshaler)", val) - } - return nil -} - -func (w *WriteBuffer) AppendString(s string) { - w.buf = append(w.buf, StringReply) - w.buf = strconv.AppendUint(w.buf, uint64(len(s)), 10) - w.buf = append(w.buf, '\r', '\n') - w.buf = append(w.buf, s...) - w.buf = append(w.buf, '\r', '\n') -} - -func (w *WriteBuffer) AppendBytes(p []byte) { - w.buf = append(w.buf, StringReply) - w.buf = strconv.AppendUint(w.buf, uint64(len(p)), 10) - w.buf = append(w.buf, '\r', '\n') - w.buf = append(w.buf, p...) - w.buf = append(w.buf, '\r', '\n') -} - -func formatInt(n int64) string { - return strconv.FormatInt(n, 10) -} - -func formatUint(u uint64) string { - return strconv.FormatUint(u, 10) -} - -func formatFloat(f float64) string { - return strconv.FormatFloat(f, 'f', -1, 64) -} diff --git a/internal/proto/write_buffer_test.go b/internal/proto/write_buffer_test.go index 84799ff..29aa469 100644 --- a/internal/proto/write_buffer_test.go +++ b/internal/proto/write_buffer_test.go @@ -1,6 +1,8 @@ package proto_test import ( + "bytes" + "io/ioutil" "testing" "time" @@ -11,21 +13,16 @@ import ( ) var _ = Describe("WriteBuffer", func() { - var buf *proto.WriteBuffer + var buf *bytes.Buffer + var wr *proto.Writer BeforeEach(func() { - buf = proto.NewWriteBuffer() + buf = new(bytes.Buffer) + wr = proto.NewWriter(buf) }) - It("should reset", func() { - buf.AppendString("string") - Expect(buf.Len()).To(Equal(12)) - buf.Reset() - Expect(buf.Len()).To(Equal(0)) - }) - - It("should append args", func() { - err := buf.Append([]interface{}{ + It("should write args", func() { + err := wr.WriteArgs([]interface{}{ "string", 12, 34.56, @@ -34,6 +31,10 @@ var _ = Describe("WriteBuffer", func() { nil, }) Expect(err).NotTo(HaveOccurred()) + + err = wr.Flush() + Expect(err).NotTo(HaveOccurred()) + Expect(buf.Bytes()).To(Equal([]byte("*6\r\n" + "$6\r\nstring\r\n" + "$2\r\n12\r\n" + @@ -45,19 +46,30 @@ var _ = Describe("WriteBuffer", func() { }) It("should append marshalable args", func() { - err := buf.Append([]interface{}{time.Unix(1414141414, 0)}) + err := wr.WriteArgs([]interface{}{time.Unix(1414141414, 0)}) Expect(err).NotTo(HaveOccurred()) + + err = wr.Flush() + Expect(err).NotTo(HaveOccurred()) + Expect(buf.Len()).To(Equal(26)) }) }) func BenchmarkWriteBuffer_Append(b *testing.B) { - buf := proto.NewWriteBuffer() + buf := proto.NewWriter(ioutil.Discard) args := []interface{}{"hello", "world", "foo", "bar"} for i := 0; i < b.N; i++ { - buf.Append(args) - buf.Reset() + err := buf.WriteArgs(args) + if err != nil { + panic(err) + } + + err = buf.Flush() + if err != nil { + panic(err) + } } } diff --git a/internal/proto/writer.go b/internal/proto/writer.go new file mode 100644 index 0000000..d106ce0 --- /dev/null +++ b/internal/proto/writer.go @@ -0,0 +1,159 @@ +package proto + +import ( + "bufio" + "encoding" + "fmt" + "io" + "strconv" + + "github.com/go-redis/redis/internal/util" +) + +type Writer struct { + wr *bufio.Writer + + lenBuf []byte + numBuf []byte +} + +func NewWriter(wr io.Writer) *Writer { + return &Writer{ + wr: bufio.NewWriter(wr), + + lenBuf: make([]byte, 64), + numBuf: make([]byte, 64), + } +} + +func (w *Writer) WriteArgs(args []interface{}) error { + err := w.wr.WriteByte(ArrayReply) + if err != nil { + return err + } + + err = w.writeLen(len(args)) + if err != nil { + return err + } + + for _, arg := range args { + err := w.writeArg(arg) + if err != nil { + return err + } + } + + return nil +} + +func (w *Writer) writeLen(n int) error { + w.lenBuf = strconv.AppendUint(w.lenBuf[:0], uint64(n), 10) + w.lenBuf = append(w.lenBuf, '\r', '\n') + _, err := w.wr.Write(w.lenBuf) + return err +} + +func (w *Writer) writeArg(v interface{}) error { + switch v := v.(type) { + case nil: + return w.string("") + case string: + return w.string(v) + case []byte: + return w.bytes(v) + case int: + return w.int(int64(v)) + case int8: + return w.int(int64(v)) + case int16: + return w.int(int64(v)) + case int32: + return w.int(int64(v)) + case int64: + return w.int(v) + case uint: + return w.uint(uint64(v)) + case uint8: + return w.uint(uint64(v)) + case uint16: + return w.uint(uint64(v)) + case uint32: + return w.uint(uint64(v)) + case uint64: + return w.uint(v) + case float32: + return w.float(float64(v)) + case float64: + return w.float(v) + case bool: + if v { + return w.int(1) + } else { + return w.int(0) + } + case encoding.BinaryMarshaler: + b, err := v.MarshalBinary() + if err != nil { + return err + } + return w.bytes(b) + default: + return fmt.Errorf( + "redis: can't marshal %T (implement encoding.BinaryMarshaler)", v) + } +} + +func (w *Writer) bytes(b []byte) error { + err := w.wr.WriteByte(StringReply) + if err != nil { + return err + } + + err = w.writeLen(len(b)) + if err != nil { + return err + } + + _, err = w.wr.Write(b) + if err != nil { + return err + } + + return w.crlf() +} + +func (w *Writer) string(s string) error { + return w.bytes(util.StringToBytes(s)) +} + +func (w *Writer) uint(n uint64) error { + w.numBuf = strconv.AppendUint(w.numBuf[:0], n, 10) + return w.bytes(w.numBuf) +} + +func (w *Writer) int(n int64) error { + w.numBuf = strconv.AppendInt(w.numBuf[:0], n, 10) + return w.bytes(w.numBuf) +} + +func (w *Writer) float(f float64) error { + w.numBuf = strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64) + return w.bytes(w.numBuf) +} + +func (w *Writer) crlf() error { + err := w.wr.WriteByte('\r') + if err != nil { + return err + } + return w.wr.WriteByte('\n') +} + +func (w *Writer) Reset(wr io.Writer) { + w.wr.Reset(wr) +} + +func (w *Writer) Flush() error { + return w.wr.Flush() +} diff --git a/internal/util/safe.go b/internal/util/safe.go index cd89183..1b3060e 100644 --- a/internal/util/safe.go +++ b/internal/util/safe.go @@ -5,3 +5,7 @@ package util func BytesToString(b []byte) string { return string(b) } + +func StringToBytes(s string) []byte { + return []byte(s) +} diff --git a/internal/util/unsafe.go b/internal/util/unsafe.go index 93a89c5..c9868aa 100644 --- a/internal/util/unsafe.go +++ b/internal/util/unsafe.go @@ -10,3 +10,13 @@ import ( func BytesToString(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } + +// StringToBytes converts string to byte slice. +func StringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/pubsub.go b/pubsub.go index d8ad82c..b08f34a 100644 --- a/pubsub.go +++ b/pubsub.go @@ -63,7 +63,6 @@ func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) { if err != nil { return nil, err } - cn.LockReaderBuffer() if err := c.resubscribe(cn); err != nil { _ = c.closeConn(cn) @@ -75,8 +74,8 @@ func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) { } func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error { - return cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { - return writeCmd(wb, cmd) + return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmd) }) } @@ -341,7 +340,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { return nil, err } - err = cn.WithReader(timeout, func(rd proto.Reader) error { + err = cn.WithReader(timeout, func(rd *proto.Reader) error { return c.cmd.readReply(rd) }) diff --git a/redis.go b/redis.go index d4ed075..3e72bf0 100644 --- a/redis.go +++ b/redis.go @@ -156,8 +156,8 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { return err } - err = cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { - return writeCmd(wb, cmd) + err = cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmd) }) if err != nil { c.releaseConn(cn, err) @@ -168,7 +168,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { return err } - err = cn.WithReader(c.cmdTimeout(cmd), func(rd proto.Reader) error { + err = cn.WithReader(c.cmdTimeout(cmd), func(rd *proto.Reader) error { return cmd.readReply(rd) }) c.releaseConn(cn, err) @@ -259,21 +259,21 @@ func (c *baseClient) generalProcessPipeline(cmds []Cmder, p pipelineProcessor) e } func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { - err := cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { - return writeCmd(wb, cmds...) + err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmds...) }) if err != nil { setCmdsErr(cmds, err) return true, err } - err = cn.WithReader(c.opt.ReadTimeout, func(rd proto.Reader) error { + err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { return pipelineReadCmds(rd, cmds) }) return true, err } -func pipelineReadCmds(rd proto.Reader, cmds []Cmder) error { +func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { for _, cmd := range cmds { err := cmd.readReply(rd) if err != nil && !internal.IsRedisError(err) { @@ -284,15 +284,15 @@ func pipelineReadCmds(rd proto.Reader, cmds []Cmder) error { } func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { - err := cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { - return txPipelineWriteMulti(wb, cmds) + err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return txPipelineWriteMulti(wr, cmds) }) if err != nil { setCmdsErr(cmds, err) return true, err } - err = cn.WithReader(c.opt.ReadTimeout, func(rd proto.Reader) error { + err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { err := txPipelineReadQueued(rd, cmds) if err != nil { setCmdsErr(cmds, err) @@ -303,15 +303,15 @@ func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, e return false, err } -func txPipelineWriteMulti(wb *proto.WriteBuffer, cmds []Cmder) error { +func txPipelineWriteMulti(wr *proto.Writer, cmds []Cmder) error { multiExec := make([]Cmder, 0, len(cmds)+2) multiExec = append(multiExec, NewStatusCmd("MULTI")) multiExec = append(multiExec, cmds...) multiExec = append(multiExec, NewSliceCmd("EXEC")) - return writeCmd(wb, multiExec...) + return writeCmd(wr, multiExec...) } -func txPipelineReadQueued(rd proto.Reader, cmds []Cmder) error { +func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error { // Parse queued replies. var statusCmd StatusCmd err := statusCmd.readReply(rd)