From ce4fd8b6774e97e61a1f432af3ed49b70437ebe4 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 8 Feb 2017 11:24:09 +0200 Subject: [PATCH] Fix ReceiveMessage to work without any subscriptions. --- Makefile | 1 + command.go | 2 +- internal/pool/conn.go | 62 +++++++++++----- internal/pool/pool.go | 7 +- internal/pool/pool_single.go | 8 -- internal/pool/pool_sticky.go | 23 ------ internal/pool/pool_test.go | 2 +- internal/proto/reader.go | 29 +++++--- internal/proto/reader_test.go | 14 ++-- .../proto/{writebuffer.go => write_buffer.go} | 8 +- ...itebuffer_test.go => write_buffer_test.go} | 8 +- pool_test.go | 2 +- pubsub.go | 73 +++++++++++-------- pubsub_test.go | 32 +++++++- redis.go | 4 - redis_test.go | 8 +- sentinel.go | 8 +- tx_test.go | 2 +- 18 files changed, 164 insertions(+), 129 deletions(-) rename internal/proto/{writebuffer.go => write_buffer.go} (95%) rename internal/proto/{writebuffer_test.go => write_buffer_test.go} (94%) diff --git a/Makefile b/Makefile index 4562692..50fdc55 100644 --- a/Makefile +++ b/Makefile @@ -15,4 +15,5 @@ testdata/redis: wget -qO- https://github.com/antirez/redis/archive/unstable.tar.gz | tar xvz --strip-components=1 -C $@ testdata/redis/src/redis-server: testdata/redis + sed -i 's/libjemalloc.a/libjemalloc.a -lrt/g' $ 0 && time.Since(cn.UsedAt) > timeout + return timeout > 0 && time.Since(cn.UsedAt()) > timeout } func (cn *Conn) SetReadTimeout(timeout time.Duration) error { - cn.UsedAt = time.Now() + now := time.Now() + cn.SetUsedAt(now) if timeout > 0 { - return cn.NetConn.SetReadDeadline(cn.UsedAt.Add(timeout)) + return cn.netConn.SetReadDeadline(now.Add(timeout)) } - return cn.NetConn.SetReadDeadline(noDeadline) - + return cn.netConn.SetReadDeadline(noDeadline) } func (cn *Conn) SetWriteTimeout(timeout time.Duration) error { - cn.UsedAt = time.Now() + now := time.Now() + cn.SetUsedAt(now) if timeout > 0 { - return cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(timeout)) + return cn.netConn.SetWriteDeadline(now.Add(timeout)) } - return cn.NetConn.SetWriteDeadline(noDeadline) + return cn.netConn.SetWriteDeadline(noDeadline) +} + +func (cn *Conn) Write(b []byte) (int, error) { + return cn.netConn.Write(b) +} + +func (cn *Conn) RemoteAddr() net.Addr { + return cn.netConn.RemoteAddr() } func (cn *Conn) Close() error { - return cn.NetConn.Close() + return cn.netConn.Close() } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 6a0e057..4033e58 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -41,7 +41,6 @@ type Pooler interface { FreeLen() int Stats() *Stats Close() error - Closed() bool } type dialer func() (net.Conn, error) @@ -132,7 +131,7 @@ func (p *ConnPool) popFree() *Conn { // Get returns existed connection from the pool or creates a new one. func (p *ConnPool) Get() (*Conn, bool, error) { - if p.Closed() { + if p.closed() { return nil, false, ErrClosed } @@ -241,7 +240,7 @@ func (p *ConnPool) Stats() *Stats { } } -func (p *ConnPool) Closed() bool { +func (p *ConnPool) closed() bool { return atomic.LoadInt32(&p._closed) == 1 } @@ -318,7 +317,7 @@ func (p *ConnPool) reaper(frequency time.Duration) { defer ticker.Stop() for _ = range ticker.C { - if p.Closed() { + if p.closed() { break } n, err := p.ReapStaleConns() diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 18ca616..22eaba9 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -12,10 +12,6 @@ func NewSingleConnPool(cn *Conn) *SingleConnPool { } } -func (p *SingleConnPool) First() *Conn { - return p.cn -} - func (p *SingleConnPool) Get() (*Conn, bool, error) { return p.cn, false, nil } @@ -49,7 +45,3 @@ func (p *SingleConnPool) Stats() *Stats { func (p *SingleConnPool) Close() error { return nil } - -func (p *SingleConnPool) Closed() bool { - return false -} diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index 9fb9971..7426cd2 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -23,13 +23,6 @@ func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool { } } -func (p *StickyConnPool) First() *Conn { - p.mu.Lock() - cn := p.cn - p.mu.Unlock() - return cn -} - func (p *StickyConnPool) Get() (*Conn, bool, error) { p.mu.Lock() defer p.mu.Unlock() @@ -62,9 +55,6 @@ func (p *StickyConnPool) Put(cn *Conn) error { if p.closed { return ErrClosed } - if p.cn != cn { - panic("p.cn != cn") - } return nil } @@ -81,12 +71,6 @@ func (p *StickyConnPool) Remove(cn *Conn, reason error) error { if p.closed { return nil } - if p.cn == nil { - panic("p.cn == nil") - } - if cn != nil && p.cn != cn { - panic("p.cn != cn") - } return p.removeUpstream(reason) } @@ -133,10 +117,3 @@ func (p *StickyConnPool) Close() error { } return err } - -func (p *StickyConnPool) Closed() bool { - p.mu.Lock() - closed := p.closed - p.mu.Unlock() - return closed -} diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index ece9b8f..f24f855 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -100,7 +100,7 @@ var _ = Describe("conns reaper", func() { for i := 0; i < 3; i++ { cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) - cn.UsedAt = time.Now().Add(-2 * idleTimeout) + cn.SetUsedAt(time.Now().Add(-2 * idleTimeout)) conns = append(conns, cn) idleConns = append(idleConns, cn) } diff --git a/internal/proto/reader.go b/internal/proto/reader.go index ee811c8..78f3231 100644 --- a/internal/proto/reader.go +++ b/internal/proto/reader.go @@ -26,13 +26,17 @@ type Reader struct { buf []byte } -func NewReader(rd io.Reader) *Reader { +func NewReader(rd io.Reader, buf []byte) *Reader { return &Reader{ src: bufio.NewReader(rd), - buf: make([]byte, 0, bufferSize), + buf: buf, } } +func (r *Reader) Reset(rd io.Reader) { + r.src.Reset(rd) +} + func (p *Reader) PeekBuffered() []byte { if n := p.src.Buffered(); n != 0 { b, _ := p.src.Peek(n) @@ -42,7 +46,12 @@ func (p *Reader) PeekBuffered() []byte { } func (p *Reader) ReadN(n int) ([]byte, error) { - return readN(p.src, p.buf, n) + b, err := readN(p.src, p.buf, n) + if err != nil { + return nil, err + } + p.buf = b + return b, nil } func (p *Reader) ReadLine() ([]byte, error) { @@ -72,11 +81,11 @@ func (p *Reader) ReadReply(m MultiBulkParse) (interface{}, error) { case ErrorReply: return nil, ParseErrorReply(line) case StatusReply: - return parseStatusValue(line) + return parseStatusValue(line), nil case IntReply: return parseInt(line[1:], 10, 64) case StringReply: - return p.readBytesValue(line) + return p.readTmpBytesValue(line) case ArrayReply: n, err := parseArrayLen(line) if err != nil { @@ -111,9 +120,9 @@ func (p *Reader) ReadTmpBytesReply() ([]byte, error) { case ErrorReply: return nil, ParseErrorReply(line) case StringReply: - return p.readBytesValue(line) + return p.readTmpBytesValue(line) case StatusReply: - return parseStatusValue(line) + return parseStatusValue(line), nil default: return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) } @@ -210,7 +219,7 @@ func (p *Reader) ReadScanReply() ([]string, uint64, error) { return keys, cursor, err } -func (p *Reader) readBytesValue(line []byte) ([]byte, error) { +func (p *Reader) readTmpBytesValue(line []byte) ([]byte, error) { if isNilReply(line) { return nil, internal.Nil } @@ -297,8 +306,8 @@ func ParseErrorReply(line []byte) error { return internal.RedisError(string(line[1:])) } -func parseStatusValue(line []byte) ([]byte, error) { - return line[1:], nil +func parseStatusValue(line []byte) []byte { + return line[1:] } func parseArrayLen(line []byte) (int64, error) { diff --git a/internal/proto/reader_test.go b/internal/proto/reader_test.go index 421344b..4835a62 100644 --- a/internal/proto/reader_test.go +++ b/internal/proto/reader_test.go @@ -5,27 +5,27 @@ import ( "strings" "testing" + "gopkg.in/redis.v5/internal/proto" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - - "gopkg.in/redis.v5/internal/proto" ) var _ = Describe("Reader", func() { It("should read n bytes", func() { - data, err := proto.NewReader(strings.NewReader("ABCDEFGHIJKLMNO")).ReadN(10) + data, err := proto.NewReader(strings.NewReader("ABCDEFGHIJKLMNO"), nil).ReadN(10) Expect(err).NotTo(HaveOccurred()) Expect(len(data)).To(Equal(10)) Expect(string(data)).To(Equal("ABCDEFGHIJ")) - data, err = proto.NewReader(strings.NewReader(strings.Repeat("x", 8192))).ReadN(6000) + data, err = proto.NewReader(strings.NewReader(strings.Repeat("x", 8192)), nil).ReadN(6000) Expect(err).NotTo(HaveOccurred()) Expect(len(data)).To(Equal(6000)) }) It("should read lines", func() { - p := proto.NewReader(strings.NewReader("$5\r\nhello\r\n")) + p := proto.NewReader(strings.NewReader("$5\r\nhello\r\n"), nil) data, err := p.ReadLine() Expect(err).NotTo(HaveOccurred()) @@ -59,11 +59,11 @@ func BenchmarkReader_ParseReply_Slice(b *testing.B) { } func benchmarkParseReply(b *testing.B, reply string, m proto.MultiBulkParse, wanterr bool) { - buf := &bytes.Buffer{} + buf := new(bytes.Buffer) for i := 0; i < b.N; i++ { buf.WriteString(reply) } - p := proto.NewReader(buf) + p := proto.NewReader(buf, nil) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/internal/proto/writebuffer.go b/internal/proto/write_buffer.go similarity index 95% rename from internal/proto/writebuffer.go rename to internal/proto/write_buffer.go index 1e0f8e6..93fb367 100644 --- a/internal/proto/writebuffer.go +++ b/internal/proto/write_buffer.go @@ -8,11 +8,13 @@ import ( const bufferSize = 4096 -type WriteBuffer struct{ b []byte } +type WriteBuffer struct { + b []byte +} -func NewWriteBuffer() *WriteBuffer { +func NewWriteBuffer(b []byte) *WriteBuffer { return &WriteBuffer{ - b: make([]byte, 0, bufferSize), + b: b, } } diff --git a/internal/proto/writebuffer_test.go b/internal/proto/write_buffer_test.go similarity index 94% rename from internal/proto/writebuffer_test.go rename to internal/proto/write_buffer_test.go index 36593af..fd70cd7 100644 --- a/internal/proto/writebuffer_test.go +++ b/internal/proto/write_buffer_test.go @@ -4,17 +4,17 @@ import ( "testing" "time" + "gopkg.in/redis.v5/internal/proto" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - - "gopkg.in/redis.v5/internal/proto" ) var _ = Describe("WriteBuffer", func() { var buf *proto.WriteBuffer BeforeEach(func() { - buf = proto.NewWriteBuffer() + buf = proto.NewWriteBuffer(nil) }) It("should reset", func() { @@ -53,7 +53,7 @@ var _ = Describe("WriteBuffer", func() { }) func BenchmarkWriteBuffer_Append(b *testing.B) { - buf := proto.NewWriteBuffer() + buf := proto.NewWriteBuffer(nil) args := []interface{}{"hello", "world", "foo", "bar"} for i := 0; i < b.N; i++ { diff --git a/pool_test.go b/pool_test.go index 0876029..683f7c8 100644 --- a/pool_test.go +++ b/pool_test.go @@ -93,7 +93,7 @@ var _ = Describe("pool", func() { It("removes broken connections", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.NetConn = &badConn{} + cn.SetNetConn(&badConn{}) Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) err = client.Ping().Err() diff --git a/pubsub.go b/pubsub.go index c9205a0..f98566e 100644 --- a/pubsub.go +++ b/pubsub.go @@ -3,6 +3,7 @@ package redis import ( "fmt" "net" + "sync" "time" "gopkg.in/redis.v5/internal" @@ -14,7 +15,9 @@ import ( // multiple goroutines. type PubSub struct { base baseClient + cmd *Cmd + mu sync.Mutex channels []string patterns []string } @@ -150,31 +153,40 @@ func (p *Pong) String() string { return "Pong" } -func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) { - switch kind := reply[0].(string); kind { - case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": - return &Subscription{ - Kind: kind, - Channel: reply[1].(string), - Count: int(reply[2].(int64)), - }, nil - case "message": - return &Message{ - Channel: reply[1].(string), - Payload: reply[2].(string), - }, nil - case "pmessage": - return &Message{ - Pattern: reply[1].(string), - Channel: reply[2].(string), - Payload: reply[3].(string), - }, nil - case "pong": +func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { + switch reply := reply.(type) { + case string: return &Pong{ - Payload: reply[1].(string), + Payload: reply, }, nil + case []interface{}: + switch kind := reply[0].(string); kind { + case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": + return &Subscription{ + Kind: kind, + Channel: reply[1].(string), + Count: int(reply[2].(int64)), + }, nil + case "message": + return &Message{ + Channel: reply[1].(string), + Payload: reply[2].(string), + }, nil + case "pmessage": + return &Message{ + Pattern: reply[1].(string), + Channel: reply[2].(string), + Payload: reply[3].(string), + }, nil + case "pong": + return &Pong{ + Payload: reply[1].(string), + }, nil + default: + return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind) + } default: - return nil, fmt.Errorf("redis: unsupported pubsub notification: %q", kind) + return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply) } } @@ -182,7 +194,9 @@ func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) { // is not received in time. This is low-level API and most clients // should use ReceiveMessage. func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { - cmd := NewSliceCmd() + if c.cmd == nil { + c.cmd = NewCmd() + } cn, _, err := c.conn() if err != nil { @@ -190,13 +204,13 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { } cn.SetReadTimeout(timeout) - err = cmd.readReply(cn) + err = c.cmd.readReply(cn) c.putConn(cn, err) if err != nil { return nil, err } - return c.newMessage(cmd.Val()) + return c.newMessage(c.cmd.Val()) } // Receive returns a message as a Subscription, Message, Pong or error. @@ -225,14 +239,14 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) { errNum++ if errNum < 3 { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - err := c.Ping("") + err := c.Ping("hello") if err != nil { internal.Logf("PubSub.Ping failed: %s", err) } } } else { - // 3 consequent errors - connection is bad - // and/or Redis Server is down. + // 3 consequent errors - connection is broken or + // Redis Server is down. // Sleep to not exceed max number of open connections. time.Sleep(time.Second) } @@ -256,9 +270,6 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) { } func (c *PubSub) resubscribe() { - if c.base.closed() { - return - } if len(c.channels) > 0 { if err := c.Subscribe(c.channels...); err != nil { internal.Logf("Subscribe failed: %s", err) diff --git a/pubsub_test.go b/pubsub_test.go index f366213..9490688 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -288,12 +288,13 @@ var _ = Describe("PubSub", func() { }) expectReceiveMessageOnError := func(pubsub *redis.PubSub) { - cn1, _, err := pubsub.Pool().Get() + cn, _, err := pubsub.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn1.NetConn = &badConn{ + cn.SetNetConn(&badConn{ readErr: io.EOF, writeErr: io.EOF, - } + }) + pubsub.Pool().Put(cn) done := make(chan bool, 1) go func() { @@ -315,7 +316,7 @@ var _ = Describe("PubSub", func() { Eventually(done).Should(Receive()) stats := client.PoolStats() - Expect(stats.Requests).To(Equal(uint32(3))) + Expect(stats.Requests).To(Equal(uint32(4))) Expect(stats.Hits).To(Equal(uint32(1))) } @@ -362,4 +363,27 @@ var _ = Describe("PubSub", func() { wg.Wait() }) + It("should ReceiveMessage without a subscription", func() { + timeout := 100 * time.Millisecond + + pubsub, err := client.Subscribe() + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + go func() { + defer GinkgoRecover() + + time.Sleep(2 * timeout) + err = pubsub.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + + err := client.Publish("mychannel", "hello").Err() + Expect(err).NotTo(HaveOccurred()) + }() + + msg, err := pubsub.ReceiveMessageTimeout(timeout) + Expect(err).NotTo(HaveOccurred()) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + }) }) diff --git a/redis.go b/redis.go index 32b83d5..1ddd754 100644 --- a/redis.go +++ b/redis.go @@ -126,10 +126,6 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { } } -func (c *baseClient) closed() bool { - return c.connPool.Closed() -} - // Close closes the client, releasing any open resources. // // It is rare to Close a Client, as the Client is meant to be diff --git a/redis_test.go b/redis_test.go index 69c68df..c7ee7be 100644 --- a/redis_test.go +++ b/redis_test.go @@ -148,7 +148,7 @@ var _ = Describe("Client", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.NetConn = &badConn{} + cn.SetNetConn(&badConn{}) err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) @@ -160,11 +160,11 @@ var _ = Describe("Client", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) Expect(cn.UsedAt).NotTo(BeZero()) - createdAt := cn.UsedAt + createdAt := cn.UsedAt() err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) - Expect(cn.UsedAt.Equal(createdAt)).To(BeTrue()) + Expect(cn.UsedAt().Equal(createdAt)).To(BeTrue()) err = client.Ping().Err() Expect(err).NotTo(HaveOccurred()) @@ -172,7 +172,7 @@ var _ = Describe("Client", func() { cn, _, err = client.Pool().Get() Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) - Expect(cn.UsedAt.After(createdAt)).To(BeTrue()) + Expect(cn.UsedAt().After(createdAt)).To(BeTrue()) }) It("should process command with special chars", func() { diff --git a/sentinel.go b/sentinel.go index 77b892d..b39d365 100644 --- a/sentinel.go +++ b/sentinel.go @@ -258,7 +258,7 @@ func (d *sentinelFailover) discoverSentinels(sentinel *sentinelClient) { // closeOldConns closes connections to the old master after failover switch. func (d *sentinelFailover) closeOldConns(newMaster string) { // Good connections that should be put back to the pool. They - // can't be put immediately, because pool.First will return them + // can't be put immediately, because pool.PopFree will return them // again on next iteration. cnsToPut := make([]*pool.Conn, 0) @@ -267,10 +267,10 @@ func (d *sentinelFailover) closeOldConns(newMaster string) { if cn == nil { break } - if cn.NetConn.RemoteAddr().String() != newMaster { + if cn.RemoteAddr().String() != newMaster { err := fmt.Errorf( "sentinel: closing connection to the old master %s", - cn.NetConn.RemoteAddr(), + cn.RemoteAddr(), ) internal.Logf(err.Error()) d.pool.Remove(cn, err) @@ -289,8 +289,10 @@ func (d *sentinelFailover) listen(sentinel *sentinelClient) { for { if pubsub == nil { pubsub = sentinel.PubSub() + if err := pubsub.Subscribe("+switch-master"); err != nil { internal.Logf("sentinel: Subscribe failed: %s", err) + pubsub.Close() d.resetSentinel() return } diff --git a/tx_test.go b/tx_test.go index 156a890..5ca22fd 100644 --- a/tx_test.go +++ b/tx_test.go @@ -127,7 +127,7 @@ var _ = Describe("Tx", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.NetConn = &badConn{} + cn.SetNetConn(&badConn{}) err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred())