diff --git a/example_test.go b/example_test.go index ca0c0b1..b75163c 100644 --- a/example_test.go +++ b/example_test.go @@ -118,8 +118,8 @@ func ExamplePubSub() { msg, err = pubsub.Receive() fmt.Println(msg, err) - // Output: &{subscribe mychannel 1} - // &{mychannel hello} + // Output: subscribe: mychannel + // Message } func ExampleScript() { diff --git a/pool.go b/pool.go index 452562d..a7ae796 100644 --- a/pool.go +++ b/pool.go @@ -27,15 +27,16 @@ type pool interface { Len() int Size() int Close() error + Filter(func(*conn) bool) } //------------------------------------------------------------------------------ type conn struct { - cn net.Conn + netcn net.Conn rd reader - inUse bool + inUse bool usedAt time.Time readTimeout time.Duration @@ -50,9 +51,8 @@ func newConnFunc(dial func() (net.Conn, error)) func() (*conn, error) { if err != nil { return nil, err } - cn := &conn{ - cn: netcn, + netcn: netcn, } cn.rd = bufio.NewReader(cn) return cn, nil @@ -61,24 +61,28 @@ func newConnFunc(dial func() (net.Conn, error)) func() (*conn, error) { func (cn *conn) Read(b []byte) (int, error) { if cn.readTimeout != 0 { - cn.cn.SetReadDeadline(time.Now().Add(cn.readTimeout)) + cn.netcn.SetReadDeadline(time.Now().Add(cn.readTimeout)) } else { - cn.cn.SetReadDeadline(zeroTime) + cn.netcn.SetReadDeadline(zeroTime) } - return cn.cn.Read(b) + return cn.netcn.Read(b) } func (cn *conn) Write(b []byte) (int, error) { if cn.writeTimeout != 0 { - cn.cn.SetWriteDeadline(time.Now().Add(cn.writeTimeout)) + cn.netcn.SetWriteDeadline(time.Now().Add(cn.writeTimeout)) } else { - cn.cn.SetWriteDeadline(zeroTime) + cn.netcn.SetWriteDeadline(zeroTime) } - return cn.cn.Write(b) + return cn.netcn.Write(b) +} + +func (cn *conn) RemoteAddr() net.Addr { + return cn.netcn.RemoteAddr() } func (cn *conn) Close() error { - return cn.cn.Close() + return cn.netcn.Close() } //------------------------------------------------------------------------------ @@ -87,30 +91,24 @@ type connPool struct { dial func() (*conn, error) rl *rateLimiter + opt *options + cond *sync.Cond conns *list.List - idleNum int - maxSize int - idleTimeout time.Duration - - closed bool + idleNum int + closed bool } -func newConnPool( - dial func() (*conn, error), - maxSize int, - idleTimeout time.Duration, -) *connPool { +func newConnPool(dial func() (*conn, error), opt *options) *connPool { return &connPool{ dial: dial, - rl: newRateLimiter(time.Second, 2*maxSize), + rl: newRateLimiter(time.Second, 2*opt.PoolSize), + + opt: opt, cond: sync.NewCond(&sync.Mutex{}), conns: list.New(), - - maxSize: maxSize, - idleTimeout: idleTimeout, } } @@ -131,13 +129,13 @@ func (p *connPool) Get() (*conn, bool, error) { return nil, false, errClosed } - if p.idleTimeout > 0 { + if p.opt.IdleTimeout > 0 { for el := p.conns.Front(); el != nil; el = el.Next() { cn := el.Value.(*conn) if cn.inUse { break } - if time.Since(cn.usedAt) > p.idleTimeout { + if time.Since(cn.usedAt) > p.opt.IdleTimeout { if err := p.remove(cn); err != nil { glog.Errorf("remove failed: %s", err) } @@ -145,7 +143,7 @@ func (p *connPool) Get() (*conn, bool, error) { } } - for p.conns.Len() >= p.maxSize && p.idleNum == 0 { + for p.conns.Len() >= p.opt.PoolSize && p.idleNum == 0 { p.cond.Wait() } @@ -163,8 +161,8 @@ func (p *connPool) Get() (*conn, bool, error) { return cn, false, nil } - if p.conns.Len() < p.maxSize { - cn, err := p.new() + if p.conns.Len() < p.opt.PoolSize { + cn, err := p.dial() if err != nil { p.cond.L.Unlock() return nil, false, err @@ -187,7 +185,7 @@ func (p *connPool) Put(cn *conn) error { return p.Remove(cn) } - if p.idleTimeout > 0 { + if p.opt.IdleTimeout > 0 { cn.usedAt = time.Now() } @@ -241,6 +239,18 @@ func (p *connPool) Size() int { return p.conns.Len() } +func (p *connPool) Filter(f func(*conn) bool) { + p.cond.L.Lock() + for el, next := p.conns.Front(), p.conns.Front(); el != nil; el = next { + next = el.Next() + cn := el.Value.(*conn) + if !f(cn) { + p.remove(cn) + } + } + p.cond.L.Unlock() +} + func (p *connPool) Close() error { defer p.cond.L.Unlock() p.cond.L.Lock() @@ -249,7 +259,11 @@ func (p *connPool) Close() error { } p.closed = true var retErr error - for e := p.conns.Front(); e != nil; e = e.Next() { + for { + e := p.conns.Front() + if e == nil { + break + } if err := p.remove(e.Value.(*conn)); err != nil { glog.Errorf("cn.Close failed: %s", err) retErr = err @@ -315,17 +329,24 @@ func (p *singleConnPool) Put(cn *conn) error { } func (p *singleConnPool) Remove(cn *conn) error { + defer p.l.Unlock() p.l.Lock() + if p.cn == nil { + panic("p.cn == nil") + } if p.cn != cn { panic("p.cn != cn") } if p.closed { - p.l.Unlock() return errClosed } + return p.remove() +} + +func (p *singleConnPool) remove() error { + err := p.pool.Remove(p.cn) p.cn = nil - p.l.Unlock() - return nil + return err } func (p *singleConnPool) Len() int { @@ -346,15 +367,23 @@ func (p *singleConnPool) Size() int { return 1 } +func (p *singleConnPool) Filter(f func(*conn) bool) { + p.l.Lock() + if p.cn != nil { + if !f(p.cn) { + p.remove() + } + } + p.l.Unlock() +} + func (p *singleConnPool) Close() error { defer p.l.Unlock() p.l.Lock() - if p.closed { return nil } p.closed = true - var err error if p.cn != nil { if p.reusable { @@ -364,6 +393,5 @@ func (p *singleConnPool) Close() error { } } p.cn = nil - return err } diff --git a/pubsub.go b/pubsub.go index cc8d7e0..bc69fb1 100644 --- a/pubsub.go +++ b/pubsub.go @@ -30,18 +30,30 @@ type Message struct { Payload string } +func (m *Message) String() string { + return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload) +} + type PMessage struct { Channel string Pattern string Payload string } +func (m *PMessage) String() string { + return fmt.Sprintf("PMessage<%s: %s>", m.Channel, m.Payload) +} + type Subscription struct { Kind string Channel string Count int } +func (m *Subscription) String() string { + return fmt.Sprintf("%s: %s", m.Kind, m.Channel) +} + func (c *PubSub) Receive() (interface{}, error) { return c.ReceiveTimeout(0) } diff --git a/redis.go b/redis.go index 6e945da..7e96a26 100644 --- a/redis.go +++ b/redis.go @@ -9,10 +9,8 @@ import ( type baseClient struct { connPool pool - - opt *Options - - cmds []Cmder + opt *options + cmds []Cmder } func (c *baseClient) writeCmd(cn *conn, cmds ...Cmder) error { @@ -133,17 +131,29 @@ func (c *baseClient) Close() error { //------------------------------------------------------------------------------ +type options struct { + Password string + DB int64 + + DialTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + + PoolSize int + IdleTimeout time.Duration +} + type Options struct { Addr string Password string DB int64 - PoolSize int - DialTimeout time.Duration ReadTimeout time.Duration WriteTimeout time.Duration - IdleTimeout time.Duration + + PoolSize int + IdleTimeout time.Duration } func (opt *Options) getPoolSize() int { @@ -160,32 +170,41 @@ func (opt *Options) getDialTimeout() time.Duration { return opt.DialTimeout } -//------------------------------------------------------------------------------ +func (opt *Options) options() *options { + return &options{ + DB: opt.DB, + Password: opt.Password, + + DialTimeout: opt.getDialTimeout(), + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + + PoolSize: opt.getPoolSize(), + IdleTimeout: opt.IdleTimeout, + } +} type Client struct { *baseClient } -func newClient(opt *Options, dial func() (net.Conn, error)) *Client { +func newClient(clOpt *Options, network string) *Client { + opt := clOpt.options() + dialer := func() (net.Conn, error) { + return net.DialTimeout(network, clOpt.Addr, opt.DialTimeout) + } return &Client{ baseClient: &baseClient{ - opt: opt, - - connPool: newConnPool(newConnFunc(dial), opt.getPoolSize(), opt.IdleTimeout), + opt: opt, + connPool: newConnPool(newConnFunc(dialer), opt), }, } } func NewTCPClient(opt *Options) *Client { - dial := func() (net.Conn, error) { - return net.DialTimeout("tcp", opt.Addr, opt.getDialTimeout()) - } - return newClient(opt, dial) + return newClient(opt, "tcp") } func NewUnixClient(opt *Options) *Client { - dial := func() (net.Conn, error) { - return net.DialTimeout("unix", opt.Addr, opt.getDialTimeout()) - } - return newClient(opt, dial) + return newClient(opt, "unix") } diff --git a/sentinel.go b/sentinel.go new file mode 100644 index 0000000..ca0499f --- /dev/null +++ b/sentinel.go @@ -0,0 +1,288 @@ +package redis + +import ( + "errors" + "net" + "strings" + "sync" + "time" + + "github.com/golang/glog" +) + +//------------------------------------------------------------------------------ + +type FailoverOptions struct { + MasterName string + SentinelAddrs []string + + Password string + DB int64 + + PoolSize int + + DialTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + IdleTimeout time.Duration +} + +func (opt *FailoverOptions) getPoolSize() int { + if opt.PoolSize == 0 { + return 10 + } + return opt.PoolSize +} + +func (opt *FailoverOptions) getDialTimeout() time.Duration { + if opt.DialTimeout == 0 { + return 5 * time.Second + } + return opt.DialTimeout +} + +func (opt *FailoverOptions) options() *options { + return &options{ + DB: opt.DB, + Password: opt.Password, + + DialTimeout: opt.getDialTimeout(), + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + + PoolSize: opt.getPoolSize(), + IdleTimeout: opt.IdleTimeout, + } +} + +func NewFailoverClient(failoverOpt *FailoverOptions) *Client { + opt := failoverOpt.options() + failover := &sentinelFailover{ + masterName: failoverOpt.MasterName, + sentinelAddrs: failoverOpt.SentinelAddrs, + + opt: opt, + } + return &Client{ + baseClient: &baseClient{ + opt: opt, + connPool: failover.Pool(), + }, + } +} + +//------------------------------------------------------------------------------ + +type sentinelClient struct { + *baseClient +} + +func newSentinel(clOpt *Options) *sentinelClient { + opt := clOpt.options() + dialer := func() (net.Conn, error) { + return net.DialTimeout("tcp", clOpt.Addr, opt.DialTimeout) + } + return &sentinelClient{ + baseClient: &baseClient{ + opt: opt, + connPool: newConnPool(newConnFunc(dialer), opt), + }, + } +} + +func (c *sentinelClient) PubSub() *PubSub { + return &PubSub{ + baseClient: &baseClient{ + opt: c.opt, + connPool: newSingleConnPool(c.connPool, nil, false), + }, + } +} + +func (c *sentinelClient) GetMasterAddrByName(name string) *StringSliceCmd { + cmd := NewStringSliceCmd("SENTINEL", "get-master-addr-by-name", name) + c.Process(cmd) + return cmd +} + +func (c *sentinelClient) Sentinels(name string) *SliceCmd { + cmd := NewSliceCmd("SENTINEL", "sentinels", name) + c.Process(cmd) + return cmd +} + +type sentinelFailover struct { + masterName string + sentinelAddrs []string + + opt *options + + pool pool + poolOnce sync.Once + + lock sync.RWMutex + _sentinel *sentinelClient +} + +func (d *sentinelFailover) dial() (net.Conn, error) { + addr, err := d.MasterAddr() + if err != nil { + return nil, err + } + return net.DialTimeout("tcp", addr, d.opt.DialTimeout) +} + +func (d *sentinelFailover) Pool() pool { + d.poolOnce.Do(func() { + d.pool = newConnPool(newConnFunc(d.dial), d.opt) + }) + return d.pool +} + +func (d *sentinelFailover) MasterAddr() (string, error) { + defer d.lock.Unlock() + d.lock.Lock() + + // Try last working sentinel. + if d._sentinel != nil { + addr, err := d._sentinel.GetMasterAddrByName(d.masterName).Result() + if err != nil { + glog.Errorf("redis-sentinel: GetMasterAddrByName %s failed: %s", d.masterName, err) + d.resetSentinel() + } else { + addr := net.JoinHostPort(addr[0], addr[1]) + glog.Infof("redis-sentinel: %s addr is %s", d.masterName, addr) + return addr, nil + } + } + + for i, addr := range d.sentinelAddrs { + sentinel := newSentinel(&Options{ + Addr: addr, + + DB: d.opt.DB, + Password: d.opt.Password, + + DialTimeout: d.opt.DialTimeout, + ReadTimeout: d.opt.ReadTimeout, + WriteTimeout: d.opt.WriteTimeout, + + PoolSize: d.opt.PoolSize, + IdleTimeout: d.opt.IdleTimeout, + }) + addr, err := sentinel.GetMasterAddrByName(d.masterName).Result() + if err != nil { + glog.Errorf("redis-sentinel: GetMasterAddrByName %s failed: %s", d.masterName, err) + } else { + // Push working sentinel to the top. + d.sentinelAddrs[0], d.sentinelAddrs[i] = d.sentinelAddrs[i], d.sentinelAddrs[0] + + d.setSentinel(sentinel) + addr := net.JoinHostPort(addr[0], addr[1]) + glog.Infof("redis-sentinel: %s addr is %s", d.masterName, addr) + return addr, nil + } + } + + return "", errors.New("redis: all sentinels are unreachable") +} + +func (d *sentinelFailover) setSentinel(sentinel *sentinelClient) { + d.discoverSentinels(sentinel) + d._sentinel = sentinel + go d.listen() +} + +func (d *sentinelFailover) discoverSentinels(sentinel *sentinelClient) { + sentinels, err := sentinel.Sentinels(d.masterName).Result() + if err != nil { + glog.Errorf("redis-sentinel: Sentinels %s failed: %s", d.masterName, err) + return + } + for _, sentinel := range sentinels { + vals := sentinel.([]interface{}) + for i := 0; i < len(vals); i += 2 { + key := vals[i].(string) + if key == "name" { + sentinelAddr := vals[i+1].(string) + if !contains(d.sentinelAddrs, sentinelAddr) { + glog.Infof( + "redis-sentinel: discovered new sentinel for %s: %s", + d.masterName, sentinelAddr, + ) + d.sentinelAddrs = append(d.sentinelAddrs, sentinelAddr) + } + } + } + } +} + +func (d *sentinelFailover) listen() { + var pubsub *PubSub + for { + if pubsub == nil { + pubsub = d._sentinel.PubSub() + if err := pubsub.Subscribe("+switch-master"); err != nil { + glog.Errorf("redis-sentinel: Subscribe failed: %s", err) + d.lock.Lock() + d.resetSentinel() + d.lock.Unlock() + return + } + } + + msgIface, err := pubsub.Receive() + if err != nil { + glog.Errorf("redis-sentinel: Receive failed: %s", err) + pubsub = nil + return + } + + switch msg := msgIface.(type) { + case *Message: + switch msg.Channel { + case "+switch-master": + parts := strings.Split(msg.Payload, " ") + if parts[0] != d.masterName { + glog.Errorf("redis-sentinel: ignore new %s addr", parts[0]) + continue + } + addr := net.JoinHostPort(parts[3], parts[4]) + glog.Infof( + "redis-sentinel: new %s addr is %s", + d.masterName, addr, + ) + d.pool.Filter(func(cn *conn) bool { + if cn.RemoteAddr().String() != addr { + glog.Infof( + "redis-sentinel: closing connection to old master %s", + cn.RemoteAddr(), + ) + return false + } + return true + }) + default: + glog.Errorf("redis-sentinel: unsupported message: %s", msg) + } + case *Subscription: + // Ignore. + default: + glog.Errorf("redis-sentinel: unsupported message: %s", msgIface) + } + } +} + +func (d *sentinelFailover) resetSentinel() { + d._sentinel.Close() + d._sentinel = nil +} + +func contains(slice []string, str string) bool { + for _, s := range slice { + if s == str { + return true + } + } + return false +}