diff --git a/cluster.go b/cluster.go index f758b01b..2641b16e 100644 --- a/cluster.go +++ b/cluster.go @@ -327,7 +327,7 @@ func (c *clusterState) slotClosestNode(slot int) (*clusterNode, error) { } func (c *clusterState) slotNodes(slot int) []*clusterNode { - if slot < len(c.slots) { + if slot >= 0 && slot < len(c.slots) { return c.slots[slot] } return nil @@ -720,14 +720,14 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { failedCmds := make(map[*clusterNode][]Cmder) for node, cmds := range cmdsMap { - cn, _, err := node.Client.conn() + cn, _, err := node.Client.getConn() if err != nil { setCmdsErr(cmds, err) continue } err = c.pipelineProcessCmds(cn, cmds, failedCmds) - node.Client.putConn(cn, err) + node.Client.releaseConn(cn, err) } if len(failedCmds) == 0 { @@ -855,14 +855,14 @@ func (c *ClusterClient) txPipelineExec(cmds []Cmder) error { failedCmds := make(map[*clusterNode][]Cmder) for node, cmds := range cmdsMap { - cn, _, err := node.Client.conn() + cn, _, err := node.Client.getConn() if err != nil { setCmdsErr(cmds, err) continue } err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds) - node.Client.putConn(cn, err) + node.Client.releaseConn(cn, err) } if len(failedCmds) == 0 { @@ -966,6 +966,56 @@ func (c *ClusterClient) txPipelineReadQueued( return firstErr } +func (c *ClusterClient) pubSub(channels []string) *PubSub { + opt := c.opt.clientOptions() + + var node *clusterNode + return &PubSub{ + opt: opt, + + newConn: func(channels []string) (*pool.Conn, error) { + if node == nil { + var slot int + if len(channels) > 0 { + slot = hashtag.Slot(channels[0]) + } else { + slot = -1 + } + + masterNode, err := c.state().slotMasterNode(slot) + if err != nil { + return nil, err + } + node = masterNode + } + return node.Client.newConn() + }, + closeConn: func(cn *pool.Conn) error { + return node.Client.connPool.CloseConn(cn) + }, + } +} + +// Subscribe subscribes the client to the specified channels. +// Channels can be omitted to create empty subscription. +func (c *ClusterClient) Subscribe(channels ...string) *PubSub { + pubsub := c.pubSub(channels) + if len(channels) > 0 { + _ = pubsub.Subscribe(channels...) + } + return pubsub +} + +// PSubscribe subscribes the client to the given patterns. +// Patterns can be omitted to create empty subscription. +func (c *ClusterClient) PSubscribe(channels ...string) *PubSub { + pubsub := c.pubSub(channels) + if len(channels) > 0 { + _ = pubsub.PSubscribe(channels...) + } + return pubsub +} + func isLoopbackAddr(addr string) bool { host, _, err := net.SplitHostPort(addr) if err != nil { diff --git a/cluster_test.go b/cluster_test.go index 3a69255a..1dc72295 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -472,6 +472,28 @@ var _ = Describe("ClusterClient", func() { }) }) + It("supports PubSub", func() { + pubsub := client.Subscribe("mychannel") + defer pubsub.Close() + + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("subscribe")) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Count).To(Equal(1)) + + n, err := client.Publish("mychannel", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + msgi, err = pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + msg := msgi.(*redis.Message) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + }) + It("calls fn for every master node", func() { for i := 0; i < 10; i++ { Expect(client.Set(strconv.Itoa(i), "", 0).Err()).NotTo(HaveOccurred()) diff --git a/pubsub.go b/pubsub.go index 4872b4e8..74ac51c1 100644 --- a/pubsub.go +++ b/pubsub.go @@ -17,7 +17,10 @@ import ( // PubSub automatically resubscribes to the channels and patterns // when Redis becomes unavailable. type PubSub struct { - base baseClient + opt *Options + + newConn func([]string) (*pool.Conn, error) + closeConn func(*pool.Conn) error mu sync.Mutex cn *pool.Conn @@ -30,12 +33,12 @@ type PubSub struct { func (c *PubSub) conn() (*pool.Conn, error) { c.mu.Lock() - cn, err := c._conn() + cn, err := c._conn(nil) c.mu.Unlock() return cn, err } -func (c *PubSub) _conn() (*pool.Conn, error) { +func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { if c.closed { return nil, pool.ErrClosed } @@ -44,20 +47,13 @@ func (c *PubSub) _conn() (*pool.Conn, error) { return c.cn, nil } - cn, err := c.base.connPool.NewConn() + cn, err := c.newConn(channels) if err != nil { return nil, err } - if !cn.Inited { - if err := c.base.initConn(cn); err != nil { - _ = c.base.connPool.CloseConn(cn) - return nil, err - } - } - if err := c.resubscribe(cn); err != nil { - _ = c.base.connPool.CloseConn(cn) + _ = c.closeConn(cn) return nil, err } @@ -88,7 +84,7 @@ func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) } cmd := NewSliceCmd(args...) - cn.SetWriteTimeout(c.base.opt.WriteTimeout) + cn.SetWriteTimeout(c.opt.WriteTimeout) return writeCmd(cn, cmd) } @@ -99,13 +95,13 @@ func (c *PubSub) putConn(cn *pool.Conn, err error) { c.mu.Lock() if c.cn == cn { - _ = c.closeConn() + _ = c.releaseConn() } c.mu.Unlock() } -func (c *PubSub) closeConn() error { - err := c.base.connPool.CloseConn(c.cn) +func (c *PubSub) releaseConn() error { + err := c.closeConn(c.cn) c.cn = nil return err } @@ -120,7 +116,7 @@ func (c *PubSub) Close() error { c.closed = true if c.cn != nil { - return c.closeConn() + return c.releaseConn() } return nil } @@ -166,7 +162,7 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error { } func (c *PubSub) subscribe(redisCmd string, channels ...string) error { - cn, err := c._conn() + cn, err := c._conn(channels) if err != nil { return err } @@ -188,7 +184,7 @@ func (c *PubSub) Ping(payload ...string) error { return err } - cn.SetWriteTimeout(c.base.opt.WriteTimeout) + cn.SetWriteTimeout(c.opt.WriteTimeout) err = writeCmd(cn, cmd) c.putConn(cn, err) return err diff --git a/pubsub_test.go b/pubsub_test.go index e8589f46..3cb9627b 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -159,9 +159,9 @@ var _ = Describe("PubSub", func() { { msgi, err := pubsub.ReceiveTimeout(time.Second) Expect(err).NotTo(HaveOccurred()) - subscr := msgi.(*redis.Message) - Expect(subscr.Channel).To(Equal("mychannel")) - Expect(subscr.Payload).To(Equal("hello")) + msg := msgi.(*redis.Message) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) } { diff --git a/redis.go b/redis.go index 9812daf6..1a2ecb0b 100644 --- a/redis.go +++ b/redis.go @@ -21,7 +21,23 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) } -func (c *baseClient) conn() (*pool.Conn, bool, error) { +func (c *baseClient) newConn() (*pool.Conn, error) { + cn, err := c.connPool.NewConn() + if err != nil { + return nil, err + } + + if !cn.Inited { + if err := c.initConn(cn); err != nil { + _ = c.connPool.CloseConn(cn) + return nil, err + } + } + + return cn, nil +} + +func (c *baseClient) getConn() (*pool.Conn, bool, error) { cn, isNew, err := c.connPool.Get() if err != nil { return nil, false, err @@ -37,7 +53,7 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) { return cn, isNew, nil } -func (c *baseClient) putConn(cn *pool.Conn, err error) bool { +func (c *baseClient) releaseConn(cn *pool.Conn, err error) bool { if internal.IsBadConn(err, false) { _ = c.connPool.Remove(cn) return false @@ -112,7 +128,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { time.Sleep(internal.RetryBackoff(i, c.opt.MaxRetryBackoff)) } - cn, _, err := c.conn() + cn, _, err := c.getConn() if err != nil { cmd.setErr(err) if internal.IsRetryableError(err) { @@ -123,7 +139,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { cn.SetWriteTimeout(c.opt.WriteTimeout) if err := writeCmd(cn, cmd); err != nil { - c.putConn(cn, err) + c.releaseConn(cn, err) cmd.setErr(err) if internal.IsRetryableError(err) { continue @@ -133,7 +149,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { cn.SetReadTimeout(c.cmdTimeout(cmd)) err = cmd.readReply(cn) - c.putConn(cn, err) + c.releaseConn(cn, err) if err != nil && internal.IsRetryableError(err) { continue } @@ -179,14 +195,14 @@ func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer { return func(cmds []Cmder) error { var firstErr error for i := 0; i <= c.opt.MaxRetries; i++ { - cn, _, err := c.conn() + cn, _, err := c.getConn() if err != nil { setCmdsErr(cmds, err) return err } canRetry, err := p(cn, cmds) - c.putConn(cn, err) + c.releaseConn(cn, err) if err == nil { return nil } @@ -375,10 +391,12 @@ func (c *Client) TxPipeline() Pipeliner { func (c *Client) pubSub() *PubSub { return &PubSub{ - base: baseClient{ - opt: c.opt, - connPool: c.connPool, + opt: c.opt, + + newConn: func(channels []string) (*pool.Conn, error) { + return c.newConn() }, + closeConn: c.connPool.CloseConn, } } diff --git a/ring.go b/ring.go index be925109..72d52bf7 100644 --- a/ring.go +++ b/ring.go @@ -423,7 +423,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) { continue } - cn, _, err := shard.Client.conn() + cn, _, err := shard.Client.getConn() if err != nil { setCmdsErr(cmds, err) if firstErr == nil { @@ -433,7 +433,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) { } canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds) - shard.Client.putConn(cn, err) + shard.Client.releaseConn(cn, err) if err == nil { continue } diff --git a/sentinel.go b/sentinel.go index ed6e7ffb..3bfdb4a3 100644 --- a/sentinel.go +++ b/sentinel.go @@ -112,10 +112,12 @@ func newSentinel(opt *Options) *sentinelClient { func (c *sentinelClient) PubSub() *PubSub { return &PubSub{ - base: baseClient{ - opt: c.opt, - connPool: c.connPool, + opt: c.opt, + + newConn: func(channels []string) (*pool.Conn, error) { + return c.newConn() }, + closeConn: c.connPool.CloseConn, } } @@ -149,14 +151,6 @@ func (d *sentinelFailover) Close() error { return d.resetSentinel() } -func (d *sentinelFailover) dial() (net.Conn, error) { - addr, err := d.MasterAddr() - if err != nil { - return nil, err - } - return net.DialTimeout("tcp", addr, d.opt.DialTimeout) -} - func (d *sentinelFailover) Pool() *pool.ConnPool { d.poolOnce.Do(func() { d.opt.Dialer = d.dial @@ -165,6 +159,14 @@ func (d *sentinelFailover) Pool() *pool.ConnPool { return d.pool } +func (d *sentinelFailover) dial() (net.Conn, error) { + addr, err := d.MasterAddr() + if err != nil { + return nil, err + } + return net.DialTimeout("tcp", addr, d.opt.DialTimeout) +} + func (d *sentinelFailover) MasterAddr() (string, error) { d.mu.Lock() defer d.mu.Unlock()