diff --git a/export_test.go b/export_test.go index d043b0df..b88e41be 100644 --- a/export_test.go +++ b/export_test.go @@ -1,6 +1,7 @@ package redis import ( + "net" "time" "github.com/go-redis/redis/internal/pool" @@ -10,8 +11,8 @@ func (c *baseClient) Pool() pool.Pooler { return c.connPool } -func (c *PubSub) Pool() pool.Pooler { - return c.base.connPool +func (c *PubSub) SetNetConn(netConn net.Conn) { + c.cn = pool.NewConn(netConn) } func (c *PubSub) ReceiveMessageTimeout(timeout time.Duration) (*Message, error) { diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index 610e12c7..5c021693 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -1,7 +1,6 @@ package pool_test import ( - "errors" "testing" "time" @@ -40,7 +39,6 @@ func BenchmarkPoolGetPut1000Conns(b *testing.B) { func benchmarkPoolGetRemove(b *testing.B, poolSize int) { connPool := pool.NewConnPool(dummyDialer, poolSize, time.Second, time.Hour, time.Hour) - removeReason := errors.New("benchmark") b.ResetTimer() @@ -50,7 +48,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) { if err != nil { b.Fatal(err) } - if err := connPool.Remove(cn, removeReason); err != nil { + if err := connPool.Remove(cn); err != nil { b.Fatal(err) } } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 45743fed..da8337a4 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -2,7 +2,6 @@ package pool import ( "errors" - "fmt" "net" "sync" "sync/atomic" @@ -11,11 +10,8 @@ import ( "github.com/go-redis/redis/internal" ) -var ( - ErrClosed = errors.New("redis: client is closed") - ErrPoolTimeout = errors.New("redis: connection pool timeout") - errConnStale = errors.New("connection is stale") -) +var ErrClosed = errors.New("redis: client is closed") +var ErrPoolTimeout = errors.New("redis: connection pool timeout") var timers = sync.Pool{ New: func() interface{} { @@ -36,12 +32,17 @@ type Stats struct { } type Pooler interface { + NewConn() (*Conn, error) + CloseConn(*Conn) error + Get() (*Conn, bool, error) Put(*Conn) error - Remove(*Conn, error) error + Remove(*Conn) error + Len() int FreeLen() int Stats() *Stats + Close() error } @@ -87,11 +88,21 @@ func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout, idleCheckF } func (p *ConnPool) NewConn() (*Conn, error) { + if p.closed() { + return nil, ErrClosed + } + netConn, err := p.dial() if err != nil { return nil, err } - return NewConn(netConn), nil + + cn := NewConn(netConn) + p.connsMu.Lock() + p.conns = append(p.conns, cn) + p.connsMu.Unlock() + + return cn, nil } func (p *ConnPool) PopFree() *Conn { @@ -164,7 +175,7 @@ func (p *ConnPool) Get() (*Conn, bool, error) { } if cn.IsStale(p.idleTimeout) { - p.remove(cn, errConnStale) + p.CloseConn(cn) continue } @@ -178,18 +189,13 @@ func (p *ConnPool) Get() (*Conn, bool, error) { return nil, false, err } - p.connsMu.Lock() - p.conns = append(p.conns, newcn) - p.connsMu.Unlock() - return newcn, true, nil } func (p *ConnPool) Put(cn *Conn) error { if data := cn.Rd.PeekBuffered(); data != nil { - err := fmt.Errorf("connection has unread data: %q", data) - internal.Logf(err.Error()) - return p.Remove(cn, err) + internal.Logf("connection has unread data: %q", data) + return p.Remove(cn) } p.freeConnsMu.Lock() p.freeConns = append(p.freeConns, cn) @@ -198,15 +204,13 @@ func (p *ConnPool) Put(cn *Conn) error { return nil } -func (p *ConnPool) Remove(cn *Conn, reason error) error { - p.remove(cn, reason) +func (p *ConnPool) Remove(cn *Conn) error { + _ = p.CloseConn(cn) <-p.queue return nil } -func (p *ConnPool) remove(cn *Conn, reason error) { - _ = p.closeConn(cn, reason) - +func (p *ConnPool) CloseConn(cn *Conn) error { p.connsMu.Lock() for i, c := range p.conns { if c == cn { @@ -215,6 +219,15 @@ func (p *ConnPool) remove(cn *Conn, reason error) { } } p.connsMu.Unlock() + + return p.closeConn(cn) +} + +func (p *ConnPool) closeConn(cn *Conn) error { + if p.OnClose != nil { + _ = p.OnClose(cn) + } + return cn.Close() } // Len returns total number of connections. @@ -258,7 +271,7 @@ func (p *ConnPool) Close() error { if cn == nil { continue } - if err := p.closeConn(cn, ErrClosed); err != nil && firstErr == nil { + if err := p.closeConn(cn); err != nil && firstErr == nil { firstErr = err } } @@ -272,13 +285,6 @@ func (p *ConnPool) Close() error { return firstErr } -func (p *ConnPool) closeConn(cn *Conn, reason error) error { - if p.OnClose != nil { - _ = p.OnClose(cn) - } - return cn.Close() -} - func (p *ConnPool) reapStaleConn() bool { if len(p.freeConns) == 0 { return false @@ -289,7 +295,7 @@ func (p *ConnPool) reapStaleConn() bool { return false } - p.remove(cn, errConnStale) + p.CloseConn(cn) p.freeConns = append(p.freeConns[:0], p.freeConns[1:]...) return true diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 22eaba9d..ff91279b 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -12,6 +12,14 @@ func NewSingleConnPool(cn *Conn) *SingleConnPool { } } +func (p *SingleConnPool) NewConn() (*Conn, error) { + panic("not implemented") +} + +func (p *SingleConnPool) CloseConn(*Conn) error { + panic("not implemented") +} + func (p *SingleConnPool) Get() (*Conn, bool, error) { return p.cn, false, nil } @@ -23,7 +31,7 @@ func (p *SingleConnPool) Put(cn *Conn) error { return nil } -func (p *SingleConnPool) Remove(cn *Conn, _ error) error { +func (p *SingleConnPool) Remove(cn *Conn) error { if p.cn != cn { panic("p.cn != cn") } diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index 7426cd26..17f16385 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -1,9 +1,6 @@ package pool -import ( - "errors" - "sync" -) +import "sync" type StickyConnPool struct { pool *ConnPool @@ -23,6 +20,14 @@ func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool { } } +func (p *StickyConnPool) NewConn() (*Conn, error) { + panic("not implemented") +} + +func (p *StickyConnPool) CloseConn(*Conn) error { + panic("not implemented") +} + func (p *StickyConnPool) Get() (*Conn, bool, error) { p.mu.Lock() defer p.mu.Unlock() @@ -58,20 +63,20 @@ func (p *StickyConnPool) Put(cn *Conn) error { return nil } -func (p *StickyConnPool) removeUpstream(reason error) error { - err := p.pool.Remove(p.cn, reason) +func (p *StickyConnPool) removeUpstream() error { + err := p.pool.Remove(p.cn) p.cn = nil return err } -func (p *StickyConnPool) Remove(cn *Conn, reason error) error { +func (p *StickyConnPool) Remove(cn *Conn) error { p.mu.Lock() defer p.mu.Unlock() if p.closed { return nil } - return p.removeUpstream(reason) + return p.removeUpstream() } func (p *StickyConnPool) Len() int { @@ -111,8 +116,7 @@ func (p *StickyConnPool) Close() error { if p.reusable { err = p.putUpstream() } else { - reason := errors.New("redis: unreusable sticky connection") - err = p.removeUpstream(reason) + err = p.removeUpstream() } } return err diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index c2983dd0..c8fbeb9b 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -1,7 +1,6 @@ package pool_test import ( - "errors" "testing" "time" @@ -59,7 +58,7 @@ var _ = Describe("ConnPool", func() { // ok } - err = connPool.Remove(cn, errors.New("test")) + err = connPool.Remove(cn) Expect(err).NotTo(HaveOccurred()) // Check that Ping is unblocked. @@ -169,7 +168,7 @@ var _ = Describe("conns reaper", func() { Expect(connPool.Len()).To(Equal(4)) Expect(connPool.FreeLen()).To(Equal(0)) - err = connPool.Remove(cn, errors.New("test")) + err = connPool.Remove(cn) Expect(err).NotTo(HaveOccurred()) Expect(connPool.Len()).To(Equal(3)) @@ -219,7 +218,7 @@ var _ = Describe("race", func() { cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) if err == nil { - Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred()) + Expect(connPool.Remove(cn)).NotTo(HaveOccurred()) } } }) diff --git a/options.go b/options.go index 03c64c78..8c30a67c 100644 --- a/options.go +++ b/options.go @@ -84,7 +84,7 @@ func (opt *Options) init() { } } if opt.PoolSize == 0 { - opt.PoolSize = 100 + opt.PoolSize = 10 } if opt.DialTimeout == 0 { opt.DialTimeout = 5 * time.Second diff --git a/pool_test.go b/pool_test.go index c6731e48..5363c400 100644 --- a/pool_test.go +++ b/pool_test.go @@ -77,18 +77,6 @@ var _ = Describe("pool", func() { Expect(pool.Len()).To(Equal(pool.FreeLen())) }) - It("respects max size on pubsub", func() { - connPool := client.Pool() - - perform(1000, func(id int) { - pubsub := client.Subscribe("test") - Expect(pubsub.Close()).NotTo(HaveOccurred()) - }) - - Expect(connPool.Len()).To(Equal(connPool.FreeLen())) - Expect(connPool.Len()).To(BeNumerically("<=", 10)) - }) - It("removes broken connections", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) diff --git a/pubsub.go b/pubsub.go index 8705a137..497b27cc 100644 --- a/pubsub.go +++ b/pubsub.go @@ -3,6 +3,7 @@ package redis import ( "fmt" "net" + "sync" "time" "github.com/go-redis/redis/internal" @@ -14,25 +15,72 @@ import ( // multiple goroutines. type PubSub struct { base baseClient - cmd *Cmd + + mu sync.Mutex + cn *pool.Conn + closed bool + + cmd *Cmd channels []string patterns []string } -func (c *PubSub) conn() (*pool.Conn, bool, error) { - cn, isNew, err := c.base.conn() +func (c *PubSub) conn() (*pool.Conn, error) { + cn, isNew, err := c._conn() if err != nil { - return nil, false, err + return nil, err } + if isNew { c.resubscribe() } - return cn, isNew, nil + + return cn, nil +} + +func (c *PubSub) resubscribe() { + if len(c.channels) > 0 { + if err := c.subscribe("subscribe", c.channels...); err != nil { + internal.Logf("Subscribe failed: %s", err) + } + } + if len(c.patterns) > 0 { + if err := c.subscribe("psubscribe", c.patterns...); err != nil { + internal.Logf("PSubscribe failed: %s", err) + } + } +} + +func (c *PubSub) _conn() (*pool.Conn, bool, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return nil, false, pool.ErrClosed + } + + if c.cn != nil { + return c.cn, false, nil + } + + cn, err := c.base.connPool.NewConn() + if err != nil { + return nil, false, err + } + c.cn = cn + + return cn, true, nil } func (c *PubSub) putConn(cn *pool.Conn, err error) { - c.base.putConn(cn, err, true) + if internal.IsBadConn(err, true) { + c.mu.Lock() + if c.cn == cn { + _ = c.closeConn() + } + c.mu.Unlock() + } } func (c *PubSub) subscribe(redisCmd string, channels ...string) error { @@ -43,7 +91,7 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { } cmd := NewSliceCmd(args...) - cn, _, err := c.conn() + cn, err := c.conn() if err != nil { return err } @@ -56,14 +104,14 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { // Subscribes the client to the specified channels. func (c *PubSub) Subscribe(channels ...string) error { - err := c.subscribe("SUBSCRIBE", channels...) + err := c.subscribe("subscribe", channels...) c.channels = appendIfNotExists(c.channels, channels...) return err } // Subscribes the client to the given patterns. func (c *PubSub) PSubscribe(patterns ...string) error { - err := c.subscribe("PSUBSCRIBE", patterns...) + err := c.subscribe("psubscribe", patterns...) c.patterns = appendIfNotExists(c.patterns, patterns...) return err } @@ -71,7 +119,7 @@ func (c *PubSub) PSubscribe(patterns ...string) error { // Unsubscribes the client from the given channels, or from all of // them if none is given. func (c *PubSub) Unsubscribe(channels ...string) error { - err := c.subscribe("UNSUBSCRIBE", channels...) + err := c.subscribe("unsubscribe", channels...) c.channels = remove(c.channels, channels...) return err } @@ -79,23 +127,41 @@ func (c *PubSub) Unsubscribe(channels ...string) error { // Unsubscribes the client from the given patterns, or from all of // them if none is given. func (c *PubSub) PUnsubscribe(patterns ...string) error { - err := c.subscribe("PUNSUBSCRIBE", patterns...) + err := c.subscribe("punsubscribe", patterns...) c.patterns = remove(c.patterns, patterns...) return err } func (c *PubSub) Close() error { - return c.base.Close() + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return pool.ErrClosed + } + c.closed = true + + if c.cn != nil { + _ = c.closeConn() + } + + return nil +} + +func (c *PubSub) closeConn() error { + err := c.base.connPool.CloseConn(c.cn) + c.cn = nil + return err } func (c *PubSub) Ping(payload ...string) error { - args := []interface{}{"PING"} + args := []interface{}{"ping"} if len(payload) == 1 { args = append(args, payload[0]) } cmd := NewCmd(args...) - cn, _, err := c.conn() + cn, err := c.conn() if err != nil { return err } @@ -188,7 +254,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { c.cmd = NewCmd() } - cn, _, err := c.conn() + cn, err := c.conn() if err != nil { return nil, err } @@ -259,19 +325,6 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) { } } -func (c *PubSub) resubscribe() { - if len(c.channels) > 0 { - if err := c.Subscribe(c.channels...); err != nil { - internal.Logf("Subscribe failed: %s", err) - } - } - if len(c.patterns) > 0 { - if err := c.PSubscribe(c.patterns...); err != nil { - internal.Logf("PSubscribe failed: %s", err) - } - } -} - // Channel returns a channel for concurrently receiving messages. // The channel is closed with PubSub. func (c *PubSub) Channel() <-chan *Message { @@ -292,6 +345,19 @@ func (c *PubSub) Channel() <-chan *Message { return ch } +func appendIfNotExists(ss []string, es ...string) []string { +loop: + for _, e := range es { + for _, s := range ss { + if s == e { + continue loop + } + } + ss = append(ss, e) + } + return ss +} + func remove(ss []string, es ...string) []string { if len(es) == 0 { return ss[:0] @@ -306,16 +372,3 @@ func remove(ss []string, es ...string) []string { } return ss } - -func appendIfNotExists(ss []string, es ...string) []string { -loop: - for _, e := range es { - for _, s := range ss { - if s == e { - continue loop - } - } - ss = append(ss, e) - } - return ss -} diff --git a/pubsub_test.go b/pubsub_test.go index 0164805a..b17ca7ad 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -274,18 +274,15 @@ var _ = Describe("PubSub", func() { Eventually(done).Should(Receive()) stats := client.PoolStats() - Expect(stats.Requests).To(Equal(uint32(3))) + Expect(stats.Requests).To(Equal(uint32(2))) Expect(stats.Hits).To(Equal(uint32(1))) }) expectReceiveMessageOnError := func(pubsub *redis.PubSub) { - cn, _, err := pubsub.Pool().Get() - Expect(err).NotTo(HaveOccurred()) - cn.SetNetConn(&badConn{ + pubsub.SetNetConn(&badConn{ readErr: io.EOF, writeErr: io.EOF, }) - pubsub.Pool().Put(cn) done := make(chan bool, 1) go func() { @@ -305,10 +302,6 @@ var _ = Describe("PubSub", func() { Expect(msg.Payload).To(Equal("hello")) Eventually(done).Should(Receive()) - - stats := client.PoolStats() - Expect(stats.Requests).To(Equal(uint32(4))) - Expect(stats.Hits).To(Equal(uint32(1))) } It("Subscribe should reconnect on ReceiveMessage error", func() { diff --git a/race_test.go b/race_test.go index 3af7226f..0ec6a140 100644 --- a/race_test.go +++ b/race_test.go @@ -136,35 +136,6 @@ var _ = Describe("races", func() { }) }) - It("should PubSub", func() { - connPool := client.Pool() - - perform(C, func(id int) { - for i := 0; i < N; i++ { - pubsub := client.Subscribe(fmt.Sprintf("mychannel%d", id)) - - go func() { - defer GinkgoRecover() - - time.Sleep(time.Millisecond) - err := pubsub.Close() - Expect(err).NotTo(HaveOccurred()) - }() - - _, err := pubsub.ReceiveMessage() - Expect(err.Error()).To(ContainSubstring("closed")) - - val := "echo" + strconv.Itoa(i) - echo, err := client.Echo(val).Result() - Expect(err).NotTo(HaveOccurred()) - Expect(echo).To(Equal(val)) - } - }) - - Expect(connPool.Len()).To(Equal(connPool.FreeLen())) - Expect(connPool.Len()).To(BeNumerically("<=", 10)) - }) - It("should select db", func() { err := client.Set("db", 1, 0).Err() Expect(err).NotTo(HaveOccurred()) diff --git a/redis.go b/redis.go index 7fbb7fd3..873d5805 100644 --- a/redis.go +++ b/redis.go @@ -31,9 +31,10 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) { if err != nil { return nil, false, err } + if !cn.Inited { if err := c.initConn(cn); err != nil { - _ = c.connPool.Remove(cn, err) + _ = c.connPool.Remove(cn) return nil, false, err } } @@ -42,7 +43,7 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) { func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { if internal.IsBadConn(err, allowTimeout) { - _ = c.connPool.Remove(cn, err) + _ = c.connPool.Remove(cn) return false } @@ -353,7 +354,7 @@ func (c *Client) pubSub() *PubSub { return &PubSub{ base: baseClient{ opt: c.opt, - connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), false), + connPool: c.connPool, }, } } diff --git a/redis_test.go b/redis_test.go index f0447222..2847963f 100644 --- a/redis_test.go +++ b/redis_test.go @@ -95,7 +95,7 @@ var _ = Describe("Client", func() { Expect(client.Close()).NotTo(HaveOccurred()) _, err := pubsub.Receive() - Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("redis: client is closed")) Expect(pubsub.Close()).NotTo(HaveOccurred()) }) @@ -217,6 +217,7 @@ var _ = Describe("Client", func() { }) var _ = Describe("Client timeout", func() { + var opt *redis.Options var client *redis.Client AfterEach(func() { @@ -240,7 +241,13 @@ var _ = Describe("Client timeout", func() { }) It("Subscribe timeouts", func() { + if opt.WriteTimeout == 0 { + return + } + pubsub := client.Subscribe() + defer pubsub.Close() + err := pubsub.Subscribe("_") Expect(err).To(HaveOccurred()) Expect(err.(net.Error).Timeout()).To(BeTrue()) @@ -269,7 +276,7 @@ var _ = Describe("Client timeout", func() { Context("read timeout", func() { BeforeEach(func() { - opt := redisOptions() + opt = redisOptions() opt.ReadTimeout = time.Nanosecond opt.WriteTimeout = -1 client = redis.NewClient(opt) @@ -280,7 +287,7 @@ var _ = Describe("Client timeout", func() { Context("write timeout", func() { BeforeEach(func() { - opt := redisOptions() + opt = redisOptions() opt.ReadTimeout = -1 opt.WriteTimeout = time.Nanosecond client = redis.NewClient(opt) diff --git a/sentinel.go b/sentinel.go index 8070b464..799f530f 100644 --- a/sentinel.go +++ b/sentinel.go @@ -2,7 +2,6 @@ package redis import ( "errors" - "fmt" "net" "strings" "sync" @@ -111,7 +110,7 @@ func (c *sentinelClient) PubSub() *PubSub { return &PubSub{ base: baseClient{ opt: c.opt, - connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), false), + connPool: c.connPool, }, } } @@ -268,12 +267,11 @@ func (d *sentinelFailover) closeOldConns(newMaster string) { break } if cn.RemoteAddr().String() != newMaster { - err := fmt.Errorf( + internal.Logf( "sentinel: closing connection to the old master %s", cn.RemoteAddr(), ) - internal.Logf(err.Error()) - d.pool.Remove(cn, err) + d.pool.Remove(cn) } else { cnsToPut = append(cnsToPut, cn) }