diff --git a/cluster.go b/cluster.go index 7a1af143..8c404c9c 100644 --- a/cluster.go +++ b/cluster.go @@ -8,6 +8,7 @@ import ( "math" "math/rand" "net" + "runtime" "sort" "sync" "sync/atomic" @@ -82,6 +83,10 @@ func (opt *ClusterOptions) init() { opt.ReadOnly = true } + if opt.PoolSize == 0 { + opt.PoolSize = 5 * runtime.NumCPU() + } + switch opt.ReadTimeout { case -1: opt.ReadTimeout = 0 diff --git a/command.go b/command.go index 992e6143..7fcbe5a9 100644 --- a/command.go +++ b/command.go @@ -46,14 +46,15 @@ func firstCmdsErr(cmds []Cmder) error { } func writeCmd(cn *pool.Conn, cmds ...Cmder) error { - cn.Wb.Reset() + cn.WB.Reset() for _, cmd := range cmds { - if err := cn.Wb.Append(cmd.Args()); err != nil { + err := cn.WB.Append(cmd.Args()) + if err != nil { return err } } - _, err := cn.Write(cn.Wb.Bytes()) + _, err := cn.Write(cn.WB.Flush()) return err } diff --git a/internal/pool/conn.go b/internal/pool/conn.go index acaf3665..f39f2658 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -14,7 +14,7 @@ type Conn struct { netConn net.Conn Rd *proto.Reader - Wb *proto.WriteBuffer + WB *proto.WriteBuffer Inited bool usedAt atomic.Value @@ -23,9 +23,10 @@ type Conn struct { func NewConn(netConn net.Conn) *Conn { cn := &Conn{ netConn: netConn, - Wb: proto.NewWriteBuffer(), } - cn.Rd = proto.NewReader(cn.netConn) + buf := proto.NewBufioReader(netConn) + cn.Rd = proto.NewReader(buf) + cn.WB = proto.NewWriteBuffer(buf) cn.SetUsedAt(time.Now()) return cn } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index cab66904..b7b383be 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -246,8 +246,8 @@ func (p *ConnPool) popIdle() *Conn { } func (p *ConnPool) Put(cn *Conn) { - buf := cn.Rd.PeekBuffered() - if buf != nil { + buf := cn.Rd.Bytes() + if len(buf) > 0 { internal.Logf("connection has unread data: %.100q", buf) p.Remove(cn) return diff --git a/internal/proto/bufio_reader.go b/internal/proto/bufio_reader.go new file mode 100644 index 00000000..6e392e6c --- /dev/null +++ b/internal/proto/bufio_reader.go @@ -0,0 +1,231 @@ +package proto + +import ( + "bufio" + "bytes" + "errors" + "io" +) + +const defaultBufSize = 4096 + +type BufioReader struct { + buf []byte + rd io.Reader // reader provided by the client + r, w int // buf read and write positions + err error +} + +func NewBufioReader(rd io.Reader) *BufioReader { + r := new(BufioReader) + r.reset(make([]byte, defaultBufSize), rd) + return r +} + +func (b *BufioReader) Reset(rd io.Reader) { + b.reset(b.buf, rd) +} + +func (b *BufioReader) Buffer() []byte { + return b.buf +} + +func (b *BufioReader) ResetBuffer(buf []byte) { + b.reset(buf, b.rd) +} + +func (b *BufioReader) reset(buf []byte, rd io.Reader) { + *b = BufioReader{ + buf: buf, + rd: rd, + } +} + +// Buffered returns the number of bytes that can be read from the current buffer. +func (b *BufioReader) Buffered() int { return b.w - b.r } + +func (b *BufioReader) Bytes() []byte { + return b.buf[b.r:b.w] +} + +var errNegativeRead = errors.New("bufio: reader returned negative count from Read") + +// fill reads a new chunk into the buffer. +func (b *BufioReader) fill() { + // Slide existing data to beginning. + if b.r > 0 { + copy(b.buf, b.buf[b.r:b.w]) + b.w -= b.r + b.r = 0 + } + + if b.w >= len(b.buf) { + panic("bufio: tried to fill full buffer") + } + + // Read new data: try a limited number of times. + const maxConsecutiveEmptyReads = 100 + for i := maxConsecutiveEmptyReads; i > 0; i-- { + n, err := b.rd.Read(b.buf[b.w:]) + if n < 0 { + panic(errNegativeRead) + } + b.w += n + if err != nil { + b.err = err + return + } + if n > 0 { + return + } + } + b.err = io.ErrNoProgress +} + +func (b *BufioReader) readErr() error { + err := b.err + b.err = nil + return err +} + +func (b *BufioReader) Read(p []byte) (n int, err error) { + n = len(p) + if n == 0 { + return 0, b.readErr() + } + if b.r == b.w { + if b.err != nil { + return 0, b.readErr() + } + if len(p) >= len(b.buf) { + // Large read, empty buffer. + // Read directly into p to avoid copy. + n, b.err = b.rd.Read(p) + if n < 0 { + panic(errNegativeRead) + } + return n, b.readErr() + } + // One read. + // Do not use b.fill, which will loop. + b.r = 0 + b.w = 0 + n, b.err = b.rd.Read(b.buf) + if n < 0 { + panic(errNegativeRead) + } + if n == 0 { + return 0, b.readErr() + } + b.w += n + } + + // copy as much as we can + n = copy(p, b.buf[b.r:b.w]) + b.r += n + return n, nil +} + +func (b *BufioReader) ReadSlice(delim byte) (line []byte, err error) { + for { + // Search buffer. + if i := bytes.IndexByte(b.buf[b.r:b.w], delim); i >= 0 { + line = b.buf[b.r : b.r+i+1] + b.r += i + 1 + break + } + + // Pending error? + if b.err != nil { + line = b.buf[b.r:b.w] + b.r = b.w + err = b.readErr() + break + } + + // Buffer full? + if b.Buffered() >= len(b.buf) { + b.r = b.w + line = b.buf + err = bufio.ErrBufferFull + break + } + + b.fill() // buffer is not full + } + + return +} + +func (b *BufioReader) ReadLine() (line []byte, isPrefix bool, err error) { + line, err = b.ReadSlice('\n') + if err == bufio.ErrBufferFull { + // Handle the case where "\r\n" straddles the buffer. + if len(line) > 0 && line[len(line)-1] == '\r' { + // Put the '\r' back on buf and drop it from line. + // Let the next call to ReadLine check for "\r\n". + if b.r == 0 { + // should be unreachable + panic("bufio: tried to rewind past start of buffer") + } + b.r-- + line = line[:len(line)-1] + } + return line, true, nil + } + + if len(line) == 0 { + if err != nil { + line = nil + } + return + } + err = nil + + if line[len(line)-1] == '\n' { + drop := 1 + if len(line) > 1 && line[len(line)-2] == '\r' { + drop = 2 + } + line = line[:len(line)-drop] + } + return +} + +func (b *BufioReader) ReadN(n int) ([]byte, error) { + b.grow(n) + for b.Buffered() < n { + // Pending error? + if b.err != nil { + buf := b.buf[b.r:b.w] + b.r = b.w + return buf, b.readErr() + } + + // Buffer is full? + if b.Buffered() >= len(b.buf) { + b.r = b.w + return b.buf, bufio.ErrBufferFull + } + + b.fill() + } + + buf := b.buf[b.r : b.r+n] + b.r += n + return buf, nil +} + +func (b *BufioReader) grow(n int) { + // Slide existing data to beginning. + if b.r > 0 { + copy(b.buf, b.buf[b.r:b.w]) + b.w -= b.r + b.r = 0 + } + + // Extend buffer if needed. + if d := n - len(b.buf); d > 0 { + b.buf = append(b.buf, make([]byte, d)...) + } +} diff --git a/internal/proto/reader.go b/internal/proto/reader.go index 8c28c7b7..70fef5e0 100644 --- a/internal/proto/reader.go +++ b/internal/proto/reader.go @@ -32,14 +32,12 @@ func (e RedisError) Error() string { return string(e) } type MultiBulkParse func(*Reader, int64) (interface{}, error) type Reader struct { - src *bufio.Reader - buf []byte + src *BufioReader } -func NewReader(rd io.Reader) *Reader { +func NewReader(src *BufioReader) *Reader { return &Reader{ - src: bufio.NewReader(rd), - buf: make([]byte, 4096), + src: src, } } @@ -47,21 +45,12 @@ func (r *Reader) Reset(rd io.Reader) { r.src.Reset(rd) } -func (r *Reader) PeekBuffered() []byte { - if n := r.src.Buffered(); n != 0 { - b, _ := r.src.Peek(n) - return b - } - return nil +func (r *Reader) Bytes() []byte { + return r.src.Bytes() } func (r *Reader) ReadN(n int) ([]byte, error) { - b, err := readN(r.src, r.buf, n) - if err != nil { - return nil, err - } - r.buf = b - return b, nil + return r.src.ReadN(n) } func (r *Reader) ReadLine() ([]byte, error) { @@ -262,38 +251,6 @@ func (r *Reader) ReadUint() (uint64, error) { return util.ParseUint(b, 10, 64) } -// -------------------------------------------------------------------- - -func readN(r io.Reader, b []byte, n int) ([]byte, error) { - if n == 0 && b == nil { - return make([]byte, 0), nil - } - - if cap(b) >= n { - b = b[:n] - _, err := io.ReadFull(r, b) - return b, err - } - b = b[:cap(b)] - - pos := 0 - for pos < n { - diff := n - len(b) - if diff > bytesAllocLimit { - diff = bytesAllocLimit - } - b = append(b, make([]byte, diff)...) - - nn, err := io.ReadFull(r, b[pos:]) - if err != nil { - return nil, err - } - pos += nn - } - - return b, nil -} - func isNilReply(b []byte) bool { return len(b) == 3 && (b[0] == StringReply || b[0] == ArrayReply) && diff --git a/internal/proto/reader_test.go b/internal/proto/reader_test.go index 8d2d71be..c36e42fe 100644 --- a/internal/proto/reader_test.go +++ b/internal/proto/reader_test.go @@ -11,27 +11,31 @@ import ( . "github.com/onsi/gomega" ) +func newReader(s string) *proto.Reader { + return proto.NewReader(proto.NewBufioReader(strings.NewReader(s))) +} + var _ = Describe("Reader", func() { It("should read n bytes", func() { - data, err := proto.NewReader(strings.NewReader("ABCDEFGHIJKLMNO")).ReadN(10) + data, err := newReader("ABCDEFGHIJKLMNO").ReadN(10) Expect(err).NotTo(HaveOccurred()) Expect(len(data)).To(Equal(10)) Expect(string(data)).To(Equal("ABCDEFGHIJ")) - data, err = proto.NewReader(strings.NewReader(strings.Repeat("x", 8192))).ReadN(6000) + data, err = newReader(strings.Repeat("x", 8192)).ReadN(6000) Expect(err).NotTo(HaveOccurred()) Expect(len(data)).To(Equal(6000)) }) It("should read lines", func() { - p := proto.NewReader(strings.NewReader("$5\r\nhello\r\n")) + r := newReader("$5\r\nhello\r\n") - data, err := p.ReadLine() + data, err := r.ReadLine() Expect(err).NotTo(HaveOccurred()) Expect(string(data)).To(Equal("$5")) - data, err = p.ReadLine() + data, err = r.ReadLine() Expect(err).NotTo(HaveOccurred()) Expect(string(data)).To(Equal("hello")) }) @@ -63,7 +67,7 @@ func benchmarkParseReply(b *testing.B, reply string, m proto.MultiBulkParse, wan for i := 0; i < b.N; i++ { buf.WriteString(reply) } - p := proto.NewReader(buf) + p := proto.NewReader(proto.NewBufioReader(buf)) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/internal/proto/write_buffer.go b/internal/proto/write_buffer.go index 664f4c33..51c5480b 100644 --- a/internal/proto/write_buffer.go +++ b/internal/proto/write_buffer.go @@ -7,23 +7,50 @@ import ( ) type WriteBuffer struct { - b []byte + rb *BufioReader + buf []byte } -func NewWriteBuffer() *WriteBuffer { +func NewWriteBuffer(rb *BufioReader) *WriteBuffer { return &WriteBuffer{ - b: make([]byte, 0, 4096), + rb: rb, } } -func (w *WriteBuffer) Len() int { return len(w.b) } -func (w *WriteBuffer) Bytes() []byte { return w.b } -func (w *WriteBuffer) Reset() { w.b = w.b[:0] } +func (w *WriteBuffer) Len() int { + return len(w.buf) +} + +func (w *WriteBuffer) Bytes() []byte { + return w.buf +} + +func (w *WriteBuffer) AllocBuffer() { + w.rb = nil + w.buf = make([]byte, defaultBufSize) +} + +func (w *WriteBuffer) Reset() { + if w.rb != nil { + w.buf = w.rb.Buffer()[:0] + } else { + w.buf = w.buf[:0] + } +} + +func (w *WriteBuffer) Flush() []byte { + b := w.buf + if w.rb != nil { + w.rb.ResetBuffer(w.buf[:cap(w.buf)]) + w.buf = nil + } + return b +} func (w *WriteBuffer) Append(args []interface{}) error { - w.b = append(w.b, ArrayReply) - w.b = strconv.AppendUint(w.b, uint64(len(args)), 10) - w.b = append(w.b, '\r', '\n') + w.buf = append(w.buf, ArrayReply) + w.buf = strconv.AppendUint(w.buf, uint64(len(args)), 10) + w.buf = append(w.buf, '\r', '\n') for _, arg := range args { if err := w.append(arg); err != nil { @@ -85,19 +112,19 @@ func (w *WriteBuffer) append(val interface{}) error { } func (w *WriteBuffer) AppendString(s string) { - w.b = append(w.b, StringReply) - w.b = strconv.AppendUint(w.b, uint64(len(s)), 10) - w.b = append(w.b, '\r', '\n') - w.b = append(w.b, s...) - w.b = append(w.b, '\r', '\n') + w.buf = append(w.buf, StringReply) + w.buf = strconv.AppendUint(w.buf, uint64(len(s)), 10) + w.buf = append(w.buf, '\r', '\n') + w.buf = append(w.buf, s...) + w.buf = append(w.buf, '\r', '\n') } func (w *WriteBuffer) AppendBytes(p []byte) { - w.b = append(w.b, StringReply) - w.b = strconv.AppendUint(w.b, uint64(len(p)), 10) - w.b = append(w.b, '\r', '\n') - w.b = append(w.b, p...) - w.b = append(w.b, '\r', '\n') + w.buf = append(w.buf, StringReply) + w.buf = strconv.AppendUint(w.buf, uint64(len(p)), 10) + w.buf = append(w.buf, '\r', '\n') + w.buf = append(w.buf, p...) + w.buf = append(w.buf, '\r', '\n') } func formatInt(n int64) string { diff --git a/internal/proto/write_buffer_test.go b/internal/proto/write_buffer_test.go index 84799ff3..dba8be6d 100644 --- a/internal/proto/write_buffer_test.go +++ b/internal/proto/write_buffer_test.go @@ -1,6 +1,7 @@ package proto_test import ( + "strings" "testing" "time" @@ -14,7 +15,7 @@ var _ = Describe("WriteBuffer", func() { var buf *proto.WriteBuffer BeforeEach(func() { - buf = proto.NewWriteBuffer() + buf = proto.NewWriteBuffer(proto.NewBufioReader(strings.NewReader(""))) }) It("should reset", func() { @@ -53,7 +54,7 @@ var _ = Describe("WriteBuffer", func() { }) func BenchmarkWriteBuffer_Append(b *testing.B) { - buf := proto.NewWriteBuffer() + buf := proto.NewWriteBuffer(proto.NewBufioReader(strings.NewReader(""))) args := []interface{}{"hello", "world", "foo", "bar"} for i := 0; i < b.N; i++ { diff --git a/pubsub.go b/pubsub.go index f0fcb8a2..0289265b 100644 --- a/pubsub.go +++ b/pubsub.go @@ -46,7 +46,7 @@ func (c *PubSub) conn() (*pool.Conn, error) { return cn, err } -func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { +func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) { if c.closed { return nil, pool.ErrClosed } @@ -55,10 +55,14 @@ func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { return c.cn, nil } + channels := mapKeys(c.channels) + channels = append(channels, newChannels...) + cn, err := c.newConn(channels) if err != nil { return nil, err } + cn.WB.AllocBuffer() if err := c.resubscribe(cn); err != nil { _ = c.closeConn(cn) @@ -69,20 +73,23 @@ func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { return cn, nil } +func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error { + cn.SetWriteTimeout(c.opt.WriteTimeout) + return writeCmd(cn, cmd) +} + func (c *PubSub) resubscribe(cn *pool.Conn) error { var firstErr error if len(c.channels) > 0 { - channels := mapKeys(c.channels) - err := c._subscribe(cn, "subscribe", channels...) + err := c._subscribe(cn, "subscribe", mapKeys(c.channels)) if err != nil && firstErr == nil { firstErr = err } } if len(c.patterns) > 0 { - patterns := mapKeys(c.patterns) - err := c._subscribe(cn, "psubscribe", patterns...) + err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns)) if err != nil && firstErr == nil { firstErr = err } @@ -101,16 +108,16 @@ func mapKeys(m map[string]struct{}) []string { return s } -func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error { - args := make([]interface{}, 1+len(channels)) - args[0] = redisCmd - for i, channel := range channels { - args[1+i] = channel +func (c *PubSub) _subscribe( + cn *pool.Conn, redisCmd string, channels []string, +) error { + args := make([]interface{}, 0, 1+len(channels)) + args = append(args, redisCmd) + for _, channel := range channels { + args = append(args, channel) } cmd := NewSliceCmd(args...) - - cn.SetWriteTimeout(c.opt.WriteTimeout) - return writeCmd(cn, cmd) + return c.writeCmd(cn, cmd) } func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { @@ -166,8 +173,8 @@ func (c *PubSub) Subscribe(channels ...string) error { if c.channels == nil { c.channels = make(map[string]struct{}) } - for _, channel := range channels { - c.channels[channel] = struct{}{} + for _, s := range channels { + c.channels[s] = struct{}{} } return err } @@ -182,8 +189,8 @@ func (c *PubSub) PSubscribe(patterns ...string) error { if c.patterns == nil { c.patterns = make(map[string]struct{}) } - for _, pattern := range patterns { - c.patterns[pattern] = struct{}{} + for _, s := range patterns { + c.patterns[s] = struct{}{} } return err } @@ -194,10 +201,10 @@ func (c *PubSub) Unsubscribe(channels ...string) error { c.mu.Lock() defer c.mu.Unlock() - err := c.subscribe("unsubscribe", channels...) for _, channel := range channels { delete(c.channels, channel) } + err := c.subscribe("unsubscribe", channels...) return err } @@ -207,10 +214,10 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error { c.mu.Lock() defer c.mu.Unlock() - err := c.subscribe("punsubscribe", patterns...) for _, pattern := range patterns { delete(c.patterns, pattern) } + err := c.subscribe("punsubscribe", patterns...) return err } @@ -220,7 +227,7 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { return err } - err = c._subscribe(cn, redisCmd, channels...) + err = c._subscribe(cn, redisCmd, channels) c._releaseConn(cn, err, false) return err } @@ -237,8 +244,7 @@ func (c *PubSub) Ping(payload ...string) error { return err } - cn.SetWriteTimeout(c.opt.WriteTimeout) - err = writeCmd(cn, cmd) + err = c.writeCmd(cn, cmd) c.releaseConn(cn, err, false) return err } diff --git a/pubsub_test.go b/pubsub_test.go index 059b4a60..5f1fb543 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -404,4 +404,33 @@ var _ = Describe("PubSub", func() { Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Payload).To(Equal(string(bigVal))) }) + + It("supports concurrent Ping and Receive", func() { + const N = 100 + + pubsub := client.Subscribe("mychannel") + defer pubsub.Close() + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + + for i := 0; i < N; i++ { + _, err := pubsub.ReceiveTimeout(5 * time.Second) + Expect(err).NotTo(HaveOccurred()) + } + close(done) + }() + + for i := 0; i < N; i++ { + err := pubsub.Ping() + Expect(err).NotTo(HaveOccurred()) + } + + select { + case <-done: + case <-time.After(30 * time.Second): + Fail("timeout") + } + }) })