From 4a3a3006655b3e92d89cdf45c339b1b1d9fb2f66 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Thu, 25 May 2017 14:16:39 +0300 Subject: [PATCH] Add Options.OnConnect --- cluster.go | 5 +++- commands.go | 4 +-- options.go | 3 +++ redis.go | 72 +++++++++++++++++++++++++++++++++++++++++++-------- redis_test.go | 23 ++++++++++++++++ ring.go | 4 +++ sentinel.go | 4 +++ 7 files changed, 101 insertions(+), 14 deletions(-) diff --git a/cluster.go b/cluster.go index 6a4bbe8..e3c5832 100644 --- a/cluster.go +++ b/cluster.go @@ -35,6 +35,8 @@ type ClusterOptions struct { // Following options are copied from Options struct. + OnConnect func(*Conn) error + MaxRetries int Password string @@ -65,6 +67,8 @@ func (opt *ClusterOptions) clientOptions() *Options { const disableIdleCheck = -1 return &Options{ + OnConnect: opt.OnConnect, + MaxRetries: opt.MaxRetries, Password: opt.Password, ReadOnly: opt.ReadOnly, @@ -77,7 +81,6 @@ func (opt *ClusterOptions) clientOptions() *Options { PoolTimeout: opt.PoolTimeout, IdleTimeout: opt.IdleTimeout, - // IdleCheckFrequency is not copied to disable reaper IdleCheckFrequency: disableIdleCheck, } } diff --git a/commands.go b/commands.go index 51dcf94..3956cf7 100644 --- a/commands.go +++ b/commands.go @@ -42,6 +42,7 @@ type Cmdable interface { Pipeline() Pipeliner Pipelined(fn func(Pipeliner) error) ([]Cmder, error) + ClientGetName() *StringCmd Echo(message interface{}) *StringCmd Ping() *StatusCmd Quit() *StatusCmd @@ -242,7 +243,6 @@ type StatefulCmdable interface { Auth(password string) *StatusCmd Select(index int) *StatusCmd ClientSetName(name string) *BoolCmd - ClientGetName() *StringCmd ReadOnly() *StatusCmd ReadWrite() *StatusCmd } @@ -1649,7 +1649,7 @@ func (c *statefulCmdable) ClientSetName(name string) *BoolCmd { } // ClientGetName returns the name of the connection. -func (c *statefulCmdable) ClientGetName() *StringCmd { +func (c *cmdable) ClientGetName() *StringCmd { cmd := NewStringCmd("client", "getname") c.process(cmd) return cmd diff --git a/options.go b/options.go index d2aefb4..1695c0b 100644 --- a/options.go +++ b/options.go @@ -24,6 +24,9 @@ type Options struct { // Network and Addr options. Dialer func() (net.Conn, error) + // Hook that is called when new connection is established. + OnConnect func(*Conn) error + // Optional password. Must match the password specified in the // requirepass server configuration option. Password string diff --git a/redis.go b/redis.go index 89f985e..303877f 100644 --- a/redis.go +++ b/redis.go @@ -21,11 +21,6 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) } -// Options returns read-only Options that were used to create the client. -func (c *baseClient) Options() *Options { - return c.opt -} - func (c *baseClient) conn() (*pool.Conn, bool, error) { cn, isNew, err := c.connPool.Get() if err != nil { @@ -55,13 +50,23 @@ func (c *baseClient) putConn(cn *pool.Conn, err error) bool { func (c *baseClient) initConn(cn *pool.Conn) error { cn.Inited = true - if c.opt.Password == "" && c.opt.DB == 0 && !c.opt.ReadOnly { + if c.opt.Password == "" && + c.opt.DB == 0 && + !c.opt.ReadOnly && + c.opt.OnConnect == nil { return nil } - // Temp client for Auth and Select. - client := newClient(c.opt, pool.NewSingleConnPool(cn)) - _, err := client.Pipelined(func(pipe Pipeliner) error { + // Temp client to initialize connection. + conn := &Conn{ + baseClient: baseClient{ + opt: c.opt, + connPool: pool.NewSingleConnPool(cn), + }, + } + conn.setProcessor(conn.Process) + + _, err := conn.Pipelined(func(pipe Pipeliner) error { if c.opt.Password != "" { pipe.Auth(c.opt.Password) } @@ -76,7 +81,14 @@ func (c *baseClient) initConn(cn *pool.Conn) error { return nil }) - return err + if err != nil { + return err + } + + if c.opt.OnConnect != nil { + return c.opt.OnConnect(conn) + } + return nil } func (c *baseClient) Process(cmd Cmder) error { @@ -182,7 +194,7 @@ func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer { } } -func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) { +func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { cn.SetWriteTimeout(c.opt.WriteTimeout) if err := writeCmd(cn, cmds...); err != nil { setCmdsErr(cmds, err) @@ -311,6 +323,11 @@ func (c *Client) copy() *Client { return c2 } +// Options returns read-only Options that were used to create the client. +func (c *Client) Options() *Options { + return c.opt +} + // PoolStats returns connection pool stats. func (c *Client) PoolStats() *PoolStats { s := c.connPool.Stats() @@ -375,3 +392,36 @@ func (c *Client) PSubscribe(channels ...string) *PubSub { } return pubsub } + +//------------------------------------------------------------------------------ + +// Conn is like Client, but its pool contains single connection. +type Conn struct { + baseClient + statefulCmdable +} + +func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { + return c.Pipeline().pipelined(fn) +} + +func (c *Conn) Pipeline() Pipeliner { + pipe := Pipeline{ + exec: c.pipelineExecer(c.pipelineProcessCmds), + } + pipe.setProcessor(pipe.Process) + return &pipe +} + +func (c *Conn) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { + return c.TxPipeline().pipelined(fn) +} + +// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. +func (c *Conn) TxPipeline() Pipeliner { + pipe := Pipeline{ + exec: c.pipelineExecer(c.txPipelineProcessCmds), + } + pipe.setProcessor(pipe.Process) + return &pipe +} diff --git a/redis_test.go b/redis_test.go index a27e3bc..407d378 100644 --- a/redis_test.go +++ b/redis_test.go @@ -296,3 +296,26 @@ var _ = Describe("Client timeout", func() { testTimeout() }) }) + +var _ = Describe("Client OnConnect", func() { + var client *redis.Client + + BeforeEach(func() { + opt := redisOptions() + opt.OnConnect = func(cn *redis.Conn) error { + return cn.ClientSetName("on_connect").Err() + } + + client = redis.NewClient(opt) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("calls OnConnect", func() { + name, err := client.ClientGetName().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(name).To(Equal("on_connect")) + }) +}) diff --git a/ring.go b/ring.go index 9c57430..a9666bc 100644 --- a/ring.go +++ b/ring.go @@ -29,6 +29,8 @@ type RingOptions struct { // Following options are copied from Options struct. + OnConnect func(*Conn) error + DB int Password string @@ -52,6 +54,8 @@ func (opt *RingOptions) init() { func (opt *RingOptions) clientOptions() *Options { return &Options{ + OnConnect: opt.OnConnect, + DB: opt.DB, Password: opt.Password, diff --git a/sentinel.go b/sentinel.go index da3a431..b28c370 100644 --- a/sentinel.go +++ b/sentinel.go @@ -23,6 +23,8 @@ type FailoverOptions struct { // Following options are copied from Options struct. + OnConnect func(*Conn) error + Password string DB int @@ -42,6 +44,8 @@ func (opt *FailoverOptions) options() *Options { return &Options{ Addr: "FailoverClient", + OnConnect: opt.OnConnect, + DB: opt.DB, Password: opt.Password,