Merge pull request #571 from go-redis/fix/on-connect

Fix/on connect
This commit is contained in:
Vladimir Mihailenco 2017-05-26 14:56:40 +03:00 committed by GitHub
commit ee42c3d5d3
9 changed files with 124 additions and 37 deletions

View File

@ -35,6 +35,8 @@ type ClusterOptions struct {
// Following options are copied from Options struct. // Following options are copied from Options struct.
OnConnect func(*Conn) error
MaxRetries int MaxRetries int
Password string Password string
@ -65,6 +67,8 @@ func (opt *ClusterOptions) clientOptions() *Options {
const disableIdleCheck = -1 const disableIdleCheck = -1
return &Options{ return &Options{
OnConnect: opt.OnConnect,
MaxRetries: opt.MaxRetries, MaxRetries: opt.MaxRetries,
Password: opt.Password, Password: opt.Password,
ReadOnly: opt.ReadOnly, ReadOnly: opt.ReadOnly,
@ -77,7 +81,6 @@ func (opt *ClusterOptions) clientOptions() *Options {
PoolTimeout: opt.PoolTimeout, PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout, IdleTimeout: opt.IdleTimeout,
// IdleCheckFrequency is not copied to disable reaper
IdleCheckFrequency: disableIdleCheck, IdleCheckFrequency: disableIdleCheck,
} }
} }
@ -349,7 +352,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
opt: opt, opt: opt,
nodes: newClusterNodes(opt), nodes: newClusterNodes(opt),
} }
c.cmdable.process = c.Process c.setProcessor(c.Process)
// Add initial nodes. // Add initial nodes.
for _, addr := range opt.Addrs { for _, addr := range opt.Addrs {
@ -678,8 +681,7 @@ func (c *ClusterClient) Pipeline() Pipeliner {
pipe := Pipeline{ pipe := Pipeline{
exec: c.pipelineExec, exec: c.pipelineExec,
} }
pipe.cmdable.process = pipe.Process pipe.setProcessor(pipe.Process)
pipe.statefulCmdable.process = pipe.Process
return &pipe return &pipe
} }
@ -801,8 +803,7 @@ func (c *ClusterClient) TxPipeline() Pipeliner {
pipe := Pipeline{ pipe := Pipeline{
exec: c.txPipelineExec, exec: c.txPipelineExec,
} }
pipe.cmdable.process = pipe.Process pipe.setProcessor(pipe.Process)
pipe.statefulCmdable.process = pipe.Process
return &pipe return &pipe
} }

View File

@ -42,6 +42,7 @@ type Cmdable interface {
Pipeline() Pipeliner Pipeline() Pipeliner
Pipelined(fn func(Pipeliner) error) ([]Cmder, error) Pipelined(fn func(Pipeliner) error) ([]Cmder, error)
ClientGetName() *StringCmd
Echo(message interface{}) *StringCmd Echo(message interface{}) *StringCmd
Ping() *StatusCmd Ping() *StatusCmd
Quit() *StatusCmd Quit() *StatusCmd
@ -238,10 +239,10 @@ type Cmdable interface {
} }
type StatefulCmdable interface { type StatefulCmdable interface {
Cmdable
Auth(password string) *StatusCmd Auth(password string) *StatusCmd
Select(index int) *StatusCmd Select(index int) *StatusCmd
ClientSetName(name string) *BoolCmd ClientSetName(name string) *BoolCmd
ClientGetName() *StringCmd
ReadOnly() *StatusCmd ReadOnly() *StatusCmd
ReadWrite() *StatusCmd ReadWrite() *StatusCmd
} }
@ -255,10 +256,20 @@ type cmdable struct {
process func(cmd Cmder) error process func(cmd Cmder) error
} }
func (c *cmdable) setProcessor(fn func(Cmder) error) {
c.process = fn
}
type statefulCmdable struct { type statefulCmdable struct {
cmdable
process func(cmd Cmder) error process func(cmd Cmder) error
} }
func (c *statefulCmdable) setProcessor(fn func(Cmder) error) {
c.process = fn
c.cmdable.setProcessor(fn)
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func (c *statefulCmdable) Auth(password string) *StatusCmd { func (c *statefulCmdable) Auth(password string) *StatusCmd {
@ -280,7 +291,6 @@ func (c *cmdable) Ping() *StatusCmd {
} }
func (c *cmdable) Wait(numSlaves int, timeout time.Duration) *IntCmd { func (c *cmdable) Wait(numSlaves int, timeout time.Duration) *IntCmd {
cmd := NewIntCmd("wait", numSlaves, int(timeout/time.Millisecond)) cmd := NewIntCmd("wait", numSlaves, int(timeout/time.Millisecond))
c.process(cmd) c.process(cmd)
return cmd return cmd
@ -1639,7 +1649,7 @@ func (c *statefulCmdable) ClientSetName(name string) *BoolCmd {
} }
// ClientGetName returns the name of the connection. // ClientGetName returns the name of the connection.
func (c *statefulCmdable) ClientGetName() *StringCmd { func (c *cmdable) ClientGetName() *StringCmd {
cmd := NewStringCmd("client", "getname") cmd := NewStringCmd("client", "getname")
c.process(cmd) c.process(cmd)
return cmd return cmd

View File

@ -24,6 +24,9 @@ type Options struct {
// Network and Addr options. // Network and Addr options.
Dialer func() (net.Conn, error) Dialer func() (net.Conn, error)
// Hook that is called when new connection is established.
OnConnect func(*Conn) error
// Optional password. Must match the password specified in the // Optional password. Must match the password specified in the
// requirepass server configuration option. // requirepass server configuration option.
Password string Password string

View File

@ -10,7 +10,6 @@ import (
type pipelineExecer func([]Cmder) error type pipelineExecer func([]Cmder) error
type Pipeliner interface { type Pipeliner interface {
Cmdable
StatefulCmdable StatefulCmdable
Process(cmd Cmder) error Process(cmd Cmder) error
Close() error Close() error
@ -26,7 +25,6 @@ var _ Pipeliner = (*Pipeline)(nil)
// http://redis.io/topics/pipelining. It's safe for concurrent use // http://redis.io/topics/pipelining. It's safe for concurrent use
// by multiple goroutines. // by multiple goroutines.
type Pipeline struct { type Pipeline struct {
cmdable
statefulCmdable statefulCmdable
exec pipelineExecer exec pipelineExecer

View File

@ -21,11 +21,6 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
} }
// Options returns read-only Options that were used to create the client.
func (c *baseClient) Options() *Options {
return c.opt
}
func (c *baseClient) conn() (*pool.Conn, bool, error) { func (c *baseClient) conn() (*pool.Conn, bool, error) {
cn, isNew, err := c.connPool.Get() cn, isNew, err := c.connPool.Get()
if err != nil { if err != nil {
@ -55,13 +50,23 @@ func (c *baseClient) putConn(cn *pool.Conn, err error) bool {
func (c *baseClient) initConn(cn *pool.Conn) error { func (c *baseClient) initConn(cn *pool.Conn) error {
cn.Inited = true cn.Inited = true
if c.opt.Password == "" && c.opt.DB == 0 && !c.opt.ReadOnly { if c.opt.Password == "" &&
c.opt.DB == 0 &&
!c.opt.ReadOnly &&
c.opt.OnConnect == nil {
return nil return nil
} }
// Temp client for Auth and Select. // Temp client to initialize connection.
client := newClient(c.opt, pool.NewSingleConnPool(cn)) conn := &Conn{
_, err := client.Pipelined(func(pipe Pipeliner) error { baseClient: baseClient{
opt: c.opt,
connPool: pool.NewSingleConnPool(cn),
},
}
conn.setProcessor(conn.Process)
_, err := conn.Pipelined(func(pipe Pipeliner) error {
if c.opt.Password != "" { if c.opt.Password != "" {
pipe.Auth(c.opt.Password) pipe.Auth(c.opt.Password)
} }
@ -76,7 +81,14 @@ func (c *baseClient) initConn(cn *pool.Conn) error {
return nil return nil
}) })
if err != nil {
return err return err
}
if c.opt.OnConnect != nil {
return c.opt.OnConnect(conn)
}
return nil
} }
func (c *baseClient) Process(cmd Cmder) error { func (c *baseClient) Process(cmd Cmder) error {
@ -189,7 +201,7 @@ func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer {
} }
} }
func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) { func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
cn.SetWriteTimeout(c.opt.WriteTimeout) cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil { if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
@ -301,7 +313,7 @@ func newClient(opt *Options, pool pool.Pooler) *Client {
connPool: pool, connPool: pool,
}, },
} }
client.cmdable.process = client.Process client.setProcessor(client.Process)
return &client return &client
} }
@ -314,10 +326,15 @@ func NewClient(opt *Options) *Client {
func (c *Client) copy() *Client { func (c *Client) copy() *Client {
c2 := new(Client) c2 := new(Client)
*c2 = *c *c2 = *c
c2.cmdable.process = c2.Process c2.setProcessor(c2.Process)
return c2 return c2
} }
// Options returns read-only Options that were used to create the client.
func (c *Client) Options() *Options {
return c.opt
}
// PoolStats returns connection pool stats. // PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats { func (c *Client) PoolStats() *PoolStats {
s := c.connPool.Stats() s := c.connPool.Stats()
@ -339,8 +356,7 @@ func (c *Client) Pipeline() Pipeliner {
pipe := Pipeline{ pipe := Pipeline{
exec: c.pipelineExecer(c.pipelineProcessCmds), exec: c.pipelineExecer(c.pipelineProcessCmds),
} }
pipe.cmdable.process = pipe.Process pipe.setProcessor(pipe.Process)
pipe.statefulCmdable.process = pipe.Process
return &pipe return &pipe
} }
@ -353,8 +369,7 @@ func (c *Client) TxPipeline() Pipeliner {
pipe := Pipeline{ pipe := Pipeline{
exec: c.pipelineExecer(c.txPipelineProcessCmds), exec: c.pipelineExecer(c.txPipelineProcessCmds),
} }
pipe.cmdable.process = pipe.Process pipe.setProcessor(pipe.Process)
pipe.statefulCmdable.process = pipe.Process
return &pipe return &pipe
} }
@ -384,3 +399,36 @@ func (c *Client) PSubscribe(channels ...string) *PubSub {
} }
return pubsub return pubsub
} }
//------------------------------------------------------------------------------
// Conn is like Client, but its pool contains single connection.
type Conn struct {
baseClient
statefulCmdable
}
func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn)
}
func (c *Conn) Pipeline() Pipeliner {
pipe := Pipeline{
exec: c.pipelineExecer(c.pipelineProcessCmds),
}
pipe.setProcessor(pipe.Process)
return &pipe
}
func (c *Conn) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().pipelined(fn)
}
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
func (c *Conn) TxPipeline() Pipeliner {
pipe := Pipeline{
exec: c.pipelineExecer(c.txPipelineProcessCmds),
}
pipe.setProcessor(pipe.Process)
return &pipe
}

View File

@ -338,3 +338,26 @@ var _ = Describe("Client timeout", func() {
testTimeout() testTimeout()
}) })
}) })
var _ = Describe("Client OnConnect", func() {
var client *redis.Client
BeforeEach(func() {
opt := redisOptions()
opt.OnConnect = func(cn *redis.Conn) error {
return cn.ClientSetName("on_connect").Err()
}
client = redis.NewClient(opt)
})
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
It("calls OnConnect", func() {
name, err := client.ClientGetName().Result()
Expect(err).NotTo(HaveOccurred())
Expect(name).To(Equal("on_connect"))
})
})

View File

@ -29,6 +29,8 @@ type RingOptions struct {
// Following options are copied from Options struct. // Following options are copied from Options struct.
OnConnect func(*Conn) error
DB int DB int
Password string Password string
@ -52,6 +54,8 @@ func (opt *RingOptions) init() {
func (opt *RingOptions) clientOptions() *Options { func (opt *RingOptions) clientOptions() *Options {
return &Options{ return &Options{
OnConnect: opt.OnConnect,
DB: opt.DB, DB: opt.DB,
Password: opt.Password, Password: opt.Password,
@ -148,7 +152,7 @@ func NewRing(opt *RingOptions) *Ring {
cmdsInfoOnce: new(sync.Once), cmdsInfoOnce: new(sync.Once),
} }
ring.cmdable.process = ring.Process ring.setProcessor(ring.Process)
for name, addr := range opt.Addrs { for name, addr := range opt.Addrs {
clopt := opt.clientOptions() clopt := opt.clientOptions()
clopt.Addr = addr clopt.Addr = addr
@ -385,8 +389,7 @@ func (c *Ring) Pipeline() Pipeliner {
pipe := Pipeline{ pipe := Pipeline{
exec: c.pipelineExec, exec: c.pipelineExec,
} }
pipe.cmdable.process = pipe.Process pipe.setProcessor(pipe.Process)
pipe.statefulCmdable.process = pipe.Process
return &pipe return &pipe
} }

View File

@ -23,6 +23,8 @@ type FailoverOptions struct {
// Following options are copied from Options struct. // Following options are copied from Options struct.
OnConnect func(*Conn) error
Password string Password string
DB int DB int
@ -42,6 +44,8 @@ func (opt *FailoverOptions) options() *Options {
return &Options{ return &Options{
Addr: "FailoverClient", Addr: "FailoverClient",
OnConnect: opt.OnConnect,
DB: opt.DB, DB: opt.DB,
Password: opt.Password, Password: opt.Password,
@ -82,7 +86,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
}, },
}, },
} }
client.cmdable.process = client.Process client.setProcessor(client.Process)
return &client return &client
} }

7
tx.go
View File

@ -13,7 +13,6 @@ const TxFailedErr = internal.RedisError("redis: transaction failed")
// by multiple goroutines, because Exec resets list of watched keys. // by multiple goroutines, because Exec resets list of watched keys.
// If you don't need WATCH it is better to use Pipeline. // If you don't need WATCH it is better to use Pipeline.
type Tx struct { type Tx struct {
cmdable
statefulCmdable statefulCmdable
baseClient baseClient
} }
@ -25,8 +24,7 @@ func (c *Client) newTx() *Tx {
connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true), connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true),
}, },
} }
tx.cmdable.process = tx.Process tx.setProcessor(tx.Process)
tx.statefulCmdable.process = tx.Process
return &tx return &tx
} }
@ -80,8 +78,7 @@ func (c *Tx) Pipeline() Pipeliner {
pipe := Pipeline{ pipe := Pipeline{
exec: c.pipelineExecer(c.txPipelineProcessCmds), exec: c.pipelineExecer(c.txPipelineProcessCmds),
} }
pipe.cmdable.process = pipe.Process pipe.setProcessor(pipe.Process)
pipe.statefulCmdable.process = pipe.Process
return &pipe return &pipe
} }