diff --git a/cluster.go b/cluster.go index 12ff1d9..ca8ca5b 100644 --- a/cluster.go +++ b/cluster.go @@ -1333,19 +1333,19 @@ func (c *ClusterClient) remapCmds(cmds []Cmder, failedCmds map[*clusterNode][]Cm func (c *ClusterClient) pipelineProcessCmds( node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) error { - cn.SetWriteTimeout(c.opt.WriteTimeout) - - err := writeCmd(cn, cmds...) + err := cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { + return writeCmd(wb, cmds...) + }) if err != nil { setCmdsErr(cmds, err) failedCmds[node] = cmds return err } - // Set read timeout for all commands. - cn.SetReadTimeout(c.opt.ReadTimeout) - - return c.pipelineReadCmds(cn.Rd, cmds, failedCmds) + err = cn.WithReader(c.opt.ReadTimeout, func(rd proto.Reader) error { + return c.pipelineReadCmds(rd, cmds, failedCmds) + }) + return err } func (c *ClusterClient) pipelineReadCmds( @@ -1476,23 +1476,24 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { func (c *ClusterClient) txPipelineProcessCmds( node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) error { - cn.SetWriteTimeout(c.opt.WriteTimeout) - if err := txPipelineWriteMulti(cn, cmds); err != nil { + err := cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { + return txPipelineWriteMulti(wb, cmds) + }) + if err != nil { setCmdsErr(cmds, err) failedCmds[node] = cmds return err } - // Set read timeout for all commands. - cn.SetReadTimeout(c.opt.ReadTimeout) - - err := c.txPipelineReadQueued(cn.Rd, cmds, failedCmds) - if err != nil { - setCmdsErr(cmds, err) - return err - } - - return pipelineReadCmds(cn.Rd, cmds) + err = cn.WithReader(c.opt.ReadTimeout, func(rd proto.Reader) error { + err := c.txPipelineReadQueued(rd, cmds, failedCmds) + if err != nil { + setCmdsErr(cmds, err) + return err + } + return pipelineReadCmds(rd, cmds) + }) + return err } func (c *ClusterClient) txPipelineReadQueued( diff --git a/command.go b/command.go index 522d6bf..4921857 100644 --- a/command.go +++ b/command.go @@ -9,7 +9,6 @@ import ( "time" "github.com/go-redis/redis/internal" - "github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/proto" "github.com/go-redis/redis/internal/util" ) @@ -44,17 +43,14 @@ func cmdsFirstErr(cmds []Cmder) error { return nil } -func writeCmd(cn *pool.Conn, cmds ...Cmder) error { - wb := cn.PrepareWriteBuffer() +func writeCmd(wb *proto.WriteBuffer, cmds ...Cmder) error { for _, cmd := range cmds { err := wb.Append(cmd.Args()) if err != nil { return err } } - - err := cn.FlushWriteBuffer(wb) - return err + return nil } func cmdString(cmd Cmder, val interface{}) string { diff --git a/internal/pool/conn.go b/internal/pool/conn.go index cfdf60d..5d361d1 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -8,15 +8,20 @@ import ( "github.com/go-redis/redis/internal/proto" ) +func makeBuffer() []byte { + const defaulBufSize = 4096 + return make([]byte, defaulBufSize) +} + var noDeadline = time.Time{} type Conn struct { netConn net.Conn - Rd proto.Reader - wb *proto.WriteBuffer - - concurrentReadWrite bool + buf []byte + rd proto.Reader + rdLocked bool + wb *proto.WriteBuffer InitedAt time.Time pooled bool @@ -26,9 +31,9 @@ type Conn struct { func NewConn(netConn net.Conn) *Conn { cn := &Conn{ netConn: netConn, + buf: makeBuffer(), } - buf := proto.NewElasticBufReader(netConn) - cn.Rd = proto.NewReader(buf) + cn.rd = proto.NewReader(proto.NewElasticBufReader(netConn)) cn.wb = proto.NewWriteBuffer() cn.SetUsedAt(time.Now()) return cn @@ -44,27 +49,25 @@ func (cn *Conn) SetUsedAt(tm time.Time) { func (cn *Conn) SetNetConn(netConn net.Conn) { cn.netConn = netConn - cn.Rd.Reset(netConn) + cn.rd.Reset(netConn) } -func (cn *Conn) SetReadTimeout(timeout time.Duration) { +func (cn *Conn) setReadTimeout(timeout time.Duration) error { now := time.Now() cn.SetUsedAt(now) if timeout > 0 { - cn.netConn.SetReadDeadline(now.Add(timeout)) - } else { - cn.netConn.SetReadDeadline(noDeadline) + return cn.netConn.SetReadDeadline(now.Add(timeout)) } + return cn.netConn.SetReadDeadline(noDeadline) } -func (cn *Conn) SetWriteTimeout(timeout time.Duration) { +func (cn *Conn) setWriteTimeout(timeout time.Duration) error { now := time.Now() cn.SetUsedAt(now) if timeout > 0 { - cn.netConn.SetWriteDeadline(now.Add(timeout)) - } else { - cn.netConn.SetWriteDeadline(noDeadline) + return cn.netConn.SetWriteDeadline(now.Add(timeout)) } + return cn.netConn.SetWriteDeadline(noDeadline) } func (cn *Conn) Write(b []byte) (int, error) { @@ -75,28 +78,43 @@ func (cn *Conn) RemoteAddr() net.Addr { return cn.netConn.RemoteAddr() } -func (cn *Conn) EnableConcurrentReadWrite() { - cn.concurrentReadWrite = true - cn.wb.ResetBuffer(make([]byte, 4096)) +func (cn *Conn) LockReaderBuffer() { + cn.rdLocked = true + cn.rd.ResetBuffer(makeBuffer()) } -func (cn *Conn) PrepareWriteBuffer() *proto.WriteBuffer { - if cn.concurrentReadWrite { - cn.wb.Reset() - } else { - cn.wb.ResetBuffer(cn.Rd.Buffer()) - } - return cn.wb -} +func (cn *Conn) WithReader(timeout time.Duration, fn func(rd proto.Reader) error) error { + _ = cn.setReadTimeout(timeout) -func (cn *Conn) FlushWriteBuffer(wb *proto.WriteBuffer) error { - _, err := cn.netConn.Write(wb.Bytes()) - if !cn.concurrentReadWrite { - cn.Rd.ResetBuffer(wb.Buffer()) + if !cn.rdLocked { + cn.rd.ResetBuffer(cn.buf) } + + err := fn(cn.rd) + + if !cn.rdLocked { + cn.buf = cn.rd.Buffer() + } + return err } +func (cn *Conn) WithWriter(timeout time.Duration, fn func(wb *proto.WriteBuffer) error) error { + _ = cn.setWriteTimeout(timeout) + + cn.wb.ResetBuffer(cn.buf) + + firstErr := fn(cn.wb) + + _, err := cn.netConn.Write(cn.wb.Bytes()) + cn.buf = cn.wb.Buffer() + if err != nil && firstErr == nil { + firstErr = err + } + + return firstErr +} + func (cn *Conn) Close() error { return cn.netConn.Close() } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index c39ac9f..9cecee8 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -288,13 +288,6 @@ func (p *ConnPool) popIdle() *Conn { } func (p *ConnPool) Put(cn *Conn) { - buf := cn.Rd.Bytes() - if len(buf) > 0 { - internal.Logf("connection has unread data: %.100q", buf) - p.Remove(cn) - return - } - if !cn.pooled { p.Remove(cn) return diff --git a/internal/proto/elastic_reader.go b/internal/proto/elastic_reader.go index ce9ea3d..c89a7eb 100644 --- a/internal/proto/elastic_reader.go +++ b/internal/proto/elastic_reader.go @@ -19,8 +19,7 @@ type ElasticBufReader struct { func NewElasticBufReader(rd io.Reader) *ElasticBufReader { return &ElasticBufReader{ - buf: make([]byte, defaultBufSize), - rd: rd, + rd: rd, } } @@ -89,44 +88,6 @@ func (b *ElasticBufReader) readErr() error { return err } -func (b *ElasticBufReader) Read(p []byte) (n int, err error) { - n = len(p) - if n == 0 { - return 0, b.readErr() - } - if b.r == b.w { - if b.err != nil { - return 0, b.readErr() - } - if len(p) >= len(b.buf) { - // Large read, empty buffer. - // Read directly into p to avoid copy. - n, b.err = b.rd.Read(p) - if n < 0 { - panic(errNegativeRead) - } - return n, b.readErr() - } - // One read. - // Do not use b.fill, which will loop. - b.r = 0 - b.w = 0 - n, b.err = b.rd.Read(b.buf) - if n < 0 { - panic(errNegativeRead) - } - if n == 0 { - return 0, b.readErr() - } - b.w += n - } - - // copy as much as we can - n = copy(p, b.buf[b.r:b.w]) - b.r += n - return n, nil -} - func (b *ElasticBufReader) ReadSlice(delim byte) (line []byte, err error) { for { // Search buffer. diff --git a/pubsub.go b/pubsub.go index 0d146a7..d8ad82c 100644 --- a/pubsub.go +++ b/pubsub.go @@ -7,9 +7,10 @@ import ( "github.com/go-redis/redis/internal" "github.com/go-redis/redis/internal/pool" + "github.com/go-redis/redis/internal/proto" ) -// PubSub implements Pub/Sub commands as described in +// PubSub implements Pub/Sub commands bas described in // http://redis.io/topics/pubsub. Message receiving is NOT safe // for concurrent use by multiple goroutines. // @@ -62,7 +63,7 @@ func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) { if err != nil { return nil, err } - cn.EnableConcurrentReadWrite() + cn.LockReaderBuffer() if err := c.resubscribe(cn); err != nil { _ = c.closeConn(cn) @@ -74,8 +75,9 @@ func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) { } func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error { - cn.SetWriteTimeout(c.opt.WriteTimeout) - return writeCmd(cn, cmd) + return cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { + return writeCmd(wb, cmd) + }) } func (c *PubSub) resubscribe(cn *pool.Conn) error { @@ -339,8 +341,10 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { return nil, err } - cn.SetReadTimeout(timeout) - err = c.cmd.readReply(cn.Rd) + err = cn.WithReader(timeout, func(rd proto.Reader) error { + return c.cmd.readReply(rd) + }) + c.releaseConn(cn, err, timeout > 0) if err != nil { return nil, err diff --git a/redis.go b/redis.go index 32daab1..d4ed075 100644 --- a/redis.go +++ b/redis.go @@ -156,8 +156,10 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { return err } - cn.SetWriteTimeout(c.opt.WriteTimeout) - if err := writeCmd(cn, cmd); err != nil { + err = cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { + return writeCmd(wb, cmd) + }) + if err != nil { c.releaseConn(cn, err) cmd.setErr(err) if internal.IsRetryableError(err, true) { @@ -166,8 +168,9 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { return err } - cn.SetReadTimeout(c.cmdTimeout(cmd)) - err = cmd.readReply(cn.Rd) + err = cn.WithReader(c.cmdTimeout(cmd), func(rd proto.Reader) error { + return cmd.readReply(rd) + }) c.releaseConn(cn, err) if err != nil && internal.IsRetryableError(err, cmd.readTimeout() == nil) { continue @@ -256,15 +259,18 @@ func (c *baseClient) generalProcessPipeline(cmds []Cmder, p pipelineProcessor) e } func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { - cn.SetWriteTimeout(c.opt.WriteTimeout) - if err := writeCmd(cn, cmds...); err != nil { + err := cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { + return writeCmd(wb, cmds...) + }) + if err != nil { setCmdsErr(cmds, err) return true, err } - // Set read timeout for all commands. - cn.SetReadTimeout(c.opt.ReadTimeout) - return true, pipelineReadCmds(cn.Rd, cmds) + err = cn.WithReader(c.opt.ReadTimeout, func(rd proto.Reader) error { + return pipelineReadCmds(rd, cmds) + }) + return true, err } func pipelineReadCmds(rd proto.Reader, cmds []Cmder) error { @@ -278,34 +284,34 @@ func pipelineReadCmds(rd proto.Reader, cmds []Cmder) error { } func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { - cn.SetWriteTimeout(c.opt.WriteTimeout) - err := txPipelineWriteMulti(cn, cmds) + err := cn.WithWriter(c.opt.WriteTimeout, func(wb *proto.WriteBuffer) error { + return txPipelineWriteMulti(wb, cmds) + }) if err != nil { setCmdsErr(cmds, err) return true, err } - // Set read timeout for all commands. - cn.SetReadTimeout(c.opt.ReadTimeout) - - err = c.txPipelineReadQueued(cn.Rd, cmds) - if err != nil { - setCmdsErr(cmds, err) - return false, err - } - - return false, pipelineReadCmds(cn.Rd, cmds) + err = cn.WithReader(c.opt.ReadTimeout, func(rd proto.Reader) error { + err := txPipelineReadQueued(rd, cmds) + if err != nil { + setCmdsErr(cmds, err) + return err + } + return pipelineReadCmds(rd, cmds) + }) + return false, err } -func txPipelineWriteMulti(cn *pool.Conn, cmds []Cmder) error { +func txPipelineWriteMulti(wb *proto.WriteBuffer, cmds []Cmder) error { multiExec := make([]Cmder, 0, len(cmds)+2) multiExec = append(multiExec, NewStatusCmd("MULTI")) multiExec = append(multiExec, cmds...) multiExec = append(multiExec, NewSliceCmd("EXEC")) - return writeCmd(cn, multiExec...) + return writeCmd(wb, multiExec...) } -func (c *baseClient) txPipelineReadQueued(rd proto.Reader, cmds []Cmder) error { +func txPipelineReadQueued(rd proto.Reader, cmds []Cmder) error { // Parse queued replies. var statusCmd StatusCmd err := statusCmd.readReply(rd)