Extract pipeline and multi/exec support to separate files.

This commit is contained in:
Vladimir Mihailenco 2012-08-11 17:42:10 +03:00
parent 83664bb3a8
commit 2f4156dd04
8 changed files with 1273 additions and 1021 deletions

View File

@ -20,6 +20,13 @@ Install:
go get github.com/vmihailenco/redis go get github.com/vmihailenco/redis
Contributing
------------
Configure Redis to allow maximum 10 clients:
maxclients 10
Run tests: Run tests:
go test -gocheck.v go test -gocheck.v

View File

@ -1,7 +1,6 @@
package redis package redis
import ( import (
"fmt"
"strconv" "strconv"
) )
@ -843,138 +842,3 @@ func (c *Client) ZUnionStore(
c.Process(req) c.Process(req)
return req return req
} }
//------------------------------------------------------------------------------
func (c *Client) PubSubClient() (*PubSubClient, error) {
return newPubSubClient(c)
}
func (c *Client) Publish(channel, message string) *IntReq {
req := NewIntReq("PUBLISH", channel, message)
c.Process(req)
return req
}
//------------------------------------------------------------------------------
func (c *Client) PipelineClient() (*Client, error) {
return &Client{
ConnPool: c.ConnPool,
InitConn: c.InitConn,
reqs: make([]Req, 0),
}, nil
}
//------------------------------------------------------------------------------
func (c *Client) MultiClient() (*Client, error) {
return &Client{
ConnPool: NewSingleConnPool(c.ConnPool),
InitConn: c.InitConn,
}, nil
}
func (c *Client) Multi() {
c.reqs = make([]Req, 0)
}
func (c *Client) Watch(keys ...string) *StatusReq {
args := append([]string{"WATCH"}, keys...)
req := NewStatusReq(args...)
c.Process(req)
return req
}
func (c *Client) Unwatch(keys ...string) *StatusReq {
args := append([]string{"UNWATCH"}, keys...)
req := NewStatusReq(args...)
c.Process(req)
return req
}
func (c *Client) Discard() {
c.mtx.Lock()
c.reqs = c.reqs[:0]
c.mtx.Unlock()
}
func (c *Client) Exec() ([]Req, error) {
c.mtx.Lock()
if len(c.reqs) == 0 {
c.mtx.Unlock()
return c.reqs, nil
}
reqs := c.reqs
c.reqs = nil
c.mtx.Unlock()
conn, err := c.conn()
if err != nil {
return nil, err
}
err = c.ExecReqs(reqs, conn)
if err != nil {
c.ConnPool.Remove(conn)
return nil, err
}
c.ConnPool.Add(conn)
return reqs, nil
}
func (c *Client) ExecReqs(reqs []Req, conn *Conn) error {
multiReq := make([]byte, 0, 1024)
multiReq = append(multiReq, PackReq([]string{"MULTI"})...)
for _, req := range reqs {
multiReq = append(multiReq, req.Req()...)
}
multiReq = append(multiReq, PackReq([]string{"EXEC"})...)
err := c.WriteReq(multiReq, conn)
if err != nil {
return err
}
statusReq := NewStatusReq()
// Parse MULTI command reply.
_, err = statusReq.ParseReply(conn.Rd)
if err != nil {
return err
}
// Parse queued replies.
for _ = range reqs {
_, err = statusReq.ParseReply(conn.Rd)
if err != nil {
return err
}
}
// Parse number of replies.
line, err := readLine(conn.Rd)
if err != nil {
return err
}
if line[0] != '*' {
return fmt.Errorf("Expected '*', but got line %q", line)
}
if isNilReplies(line) {
return Nil
}
// Parse replies.
for i := 0; i < len(reqs); i++ {
req := reqs[i]
val, err := req.ParseReply(conn.Rd)
if err != nil {
req.SetErr(err)
} else {
req.SetVal(val)
}
}
return nil
}

View File

@ -9,11 +9,11 @@ import (
) )
type Conn struct { type Conn struct {
RW io.ReadWriter RW io.ReadWriteCloser
Rd *bufio.Reader Rd *bufio.Reader
} }
func NewConn(rw io.ReadWriter) *Conn { func NewConn(rw io.ReadWriteCloser) *Conn {
return &Conn{ return &Conn{
RW: rw, RW: rw,
Rd: bufio.NewReaderSize(rw, 1024), Rd: bufio.NewReaderSize(rw, 1024),
@ -22,10 +22,10 @@ func NewConn(rw io.ReadWriter) *Conn {
type ConnPool interface { type ConnPool interface {
Get() (*Conn, bool, error) Get() (*Conn, bool, error)
Add(*Conn) Add(*Conn) error
Remove(*Conn) Remove(*Conn) error
Len() int Len() int
Close() Close() error
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -36,10 +36,10 @@ type MultiConnPool struct {
conns []*Conn conns []*Conn
OpenConn OpenConnFunc OpenConn OpenConnFunc
CloseConn CloseConnFunc CloseConn CloseConnFunc
cap, MaxCap int64 cap, MaxCap int
} }
func NewMultiConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int64) *MultiConnPool { func NewMultiConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int) *MultiConnPool {
logger := log.New( logger := log.New(
os.Stdout, os.Stdout,
"redis.connpool: ", "redis.connpool: ",
@ -81,29 +81,54 @@ func (p *MultiConnPool) Get() (*Conn, bool, error) {
return conn, false, nil return conn, false, nil
} }
func (p *MultiConnPool) Add(conn *Conn) { func (p *MultiConnPool) Add(conn *Conn) error {
p.cond.L.Lock() p.cond.L.Lock()
defer p.cond.L.Unlock() defer p.cond.L.Unlock()
p.conns = append(p.conns, conn) p.conns = append(p.conns, conn)
p.cond.Signal() p.cond.Signal()
return nil
} }
func (p *MultiConnPool) Remove(conn *Conn) { func (p *MultiConnPool) Remove(conn *Conn) error {
defer func() {
p.cond.L.Lock() p.cond.L.Lock()
p.cap-- p.cap--
p.cond.Signal() p.cond.Signal()
p.cond.L.Unlock() p.cond.L.Unlock()
}()
if p.CloseConn != nil && conn != nil { if conn == nil {
p.CloseConn(conn.RW) return nil
} }
return p.closeConn(conn)
} }
func (p *MultiConnPool) Len() int { func (p *MultiConnPool) Len() int {
return len(p.conns) return len(p.conns)
} }
func (p *MultiConnPool) Close() {} func (p *MultiConnPool) Close() error {
p.cond.L.Lock()
defer p.cond.L.Unlock()
for _, conn := range p.conns {
err := p.closeConn(conn)
if err != nil {
return err
}
}
p.conns = make([]*Conn, 0)
p.cap = 0
return nil
}
func (p *MultiConnPool) closeConn(conn *Conn) error {
if p.CloseConn != nil {
err := p.CloseConn(conn.RW)
if err != nil {
return err
}
}
return conn.RW.Close()
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -111,17 +136,19 @@ type SingleConnPool struct {
mtx sync.Mutex mtx sync.Mutex
pool ConnPool pool ConnPool
conn *Conn conn *Conn
isReusable bool
} }
func NewSingleConnPoolConn(pool ConnPool, conn *Conn) *SingleConnPool { func NewSingleConnPoolConn(pool ConnPool, conn *Conn, isReusable bool) *SingleConnPool {
return &SingleConnPool{ return &SingleConnPool{
pool: pool, pool: pool,
conn: conn, conn: conn,
isReusable: isReusable,
} }
} }
func NewSingleConnPool(pool ConnPool) *SingleConnPool { func NewSingleConnPool(pool ConnPool, isReusable bool) *SingleConnPool {
return NewSingleConnPoolConn(pool, nil) return NewSingleConnPoolConn(pool, nil, isReusable)
} }
func (p *SingleConnPool) Get() (*Conn, bool, error) { func (p *SingleConnPool) Get() (*Conn, bool, error) {
@ -138,17 +165,32 @@ func (p *SingleConnPool) Get() (*Conn, bool, error) {
return p.conn, isNew, nil return p.conn, isNew, nil
} }
func (p *SingleConnPool) Add(conn *Conn) {} func (p *SingleConnPool) Add(conn *Conn) error {
return nil
}
func (p *SingleConnPool) Remove(conn *Conn) {} func (p *SingleConnPool) Remove(conn *Conn) error {
return nil
}
func (p *SingleConnPool) Len() int { func (p *SingleConnPool) Len() int {
return 1 return 1
} }
func (p *SingleConnPool) Close() { func (p *SingleConnPool) Close() error {
p.mtx.Lock() p.mtx.Lock()
defer p.mtx.Unlock() defer p.mtx.Unlock()
p.pool.Add(p.conn)
p.conn = nil if p.conn == nil {
return nil
}
var err error
if p.isReusable {
err = p.pool.Add(p.conn)
} else {
err = p.pool.Remove(p.conn)
}
p.conn = nil
return err
} }

124
multi.go Normal file
View File

@ -0,0 +1,124 @@
package redis
import (
"fmt"
)
type MultiClient struct {
*Client
}
func (c *Client) MultiClient() (*MultiClient, error) {
return &MultiClient{
Client: &Client{
BaseClient: &BaseClient{
ConnPool: NewSingleConnPool(c.ConnPool, true),
InitConn: c.InitConn,
},
},
}, nil
}
func (c *MultiClient) Multi() {
c.reqs = make([]Req, 0)
}
func (c *MultiClient) Watch(keys ...string) *StatusReq {
args := append([]string{"WATCH"}, keys...)
req := NewStatusReq(args...)
c.Process(req)
return req
}
func (c *MultiClient) Unwatch(keys ...string) *StatusReq {
args := append([]string{"UNWATCH"}, keys...)
req := NewStatusReq(args...)
c.Process(req)
return req
}
func (c *MultiClient) Discard() {
c.mtx.Lock()
c.reqs = c.reqs[:0]
c.mtx.Unlock()
}
func (c *MultiClient) Exec() ([]Req, error) {
c.mtx.Lock()
if len(c.reqs) == 0 {
c.mtx.Unlock()
return c.reqs, nil
}
reqs := c.reqs
c.reqs = nil
c.mtx.Unlock()
conn, err := c.conn()
if err != nil {
return nil, err
}
err = c.ExecReqs(reqs, conn)
if err != nil {
c.ConnPool.Remove(conn)
return nil, err
}
c.ConnPool.Add(conn)
return reqs, nil
}
func (c *MultiClient) ExecReqs(reqs []Req, conn *Conn) error {
multiReq := make([]byte, 0, 1024)
multiReq = append(multiReq, PackReq([]string{"MULTI"})...)
for _, req := range reqs {
multiReq = append(multiReq, req.Req()...)
}
multiReq = append(multiReq, PackReq([]string{"EXEC"})...)
err := c.WriteReq(multiReq, conn)
if err != nil {
return err
}
statusReq := NewStatusReq()
// Parse MULTI command reply.
_, err = statusReq.ParseReply(conn.Rd)
if err != nil {
return err
}
// Parse queued replies.
for _ = range reqs {
_, err = statusReq.ParseReply(conn.Rd)
if err != nil {
return err
}
}
// Parse number of replies.
line, err := readLine(conn.Rd)
if err != nil {
return err
}
if line[0] != '*' {
return fmt.Errorf("Expected '*', but got line %q", line)
}
if isNilReplies(line) {
return Nil
}
// Parse replies.
for i := 0; i < len(reqs); i++ {
req := reqs[i]
val, err := req.ParseReply(conn.Rd)
if err != nil {
req.SetErr(err)
} else {
req.SetVal(val)
}
}
return nil
}

75
pipeline.go Normal file
View File

@ -0,0 +1,75 @@
package redis
type PipelineClient struct {
*Client
}
func (c *Client) PipelineClient() (*PipelineClient, error) {
return &PipelineClient{
Client: &Client{
BaseClient: &BaseClient{
ConnPool: c.ConnPool,
InitConn: c.InitConn,
reqs: make([]Req, 0),
},
},
}, nil
}
func (c *PipelineClient) Close() error {
return nil
}
func (c *PipelineClient) RunQueued() ([]Req, error) {
c.mtx.Lock()
if len(c.reqs) == 0 {
c.mtx.Unlock()
return c.reqs, nil
}
reqs := c.reqs
c.reqs = make([]Req, 0)
c.mtx.Unlock()
conn, err := c.conn()
if err != nil {
return nil, err
}
err = c.RunReqs(reqs, conn)
if err != nil {
c.ConnPool.Remove(conn)
return nil, err
}
c.ConnPool.Add(conn)
return reqs, nil
}
func (c *PipelineClient) RunReqs(reqs []Req, conn *Conn) error {
var multiReq []byte
if len(reqs) == 1 {
multiReq = reqs[0].Req()
} else {
multiReq = make([]byte, 0, 1024)
for _, req := range reqs {
multiReq = append(multiReq, req.Req()...)
}
}
err := c.WriteReq(multiReq, conn)
if err != nil {
return err
}
for i := 0; i < len(reqs); i++ {
req := reqs[i]
val, err := req.ParseReply(conn.Rd)
if err != nil {
req.SetErr(err)
} else {
req.SetVal(val)
}
}
return nil
}

View File

@ -6,19 +6,29 @@ import (
) )
type PubSubClient struct { type PubSubClient struct {
*Client *BaseClient
ch chan *Message ch chan *Message
once sync.Once once sync.Once
} }
func newPubSubClient(client *Client) (*PubSubClient, error) { func newPubSubClient(client *Client) (*PubSubClient, error) {
c := &PubSubClient{ return &PubSubClient{
Client: &Client{ BaseClient: &BaseClient{
ConnPool: NewSingleConnPool(client.ConnPool), ConnPool: NewSingleConnPool(client.ConnPool, false),
InitConn: client.InitConn,
}, },
ch: make(chan *Message), ch: make(chan *Message),
}, nil
} }
return c, nil
func (c *Client) PubSubClient() (*PubSubClient, error) {
return newPubSubClient(c)
}
func (c *Client) Publish(channel, message string) *IntReq {
req := NewIntReq("PUBLISH", channel, message)
c.Process(req)
return req
} }
type Message struct { type Message struct {
@ -28,12 +38,7 @@ type Message struct {
Err error Err error
} }
func (c *PubSubClient) consumeMessages() { func (c *PubSubClient) consumeMessages(conn *Conn) {
conn, err := c.conn()
// SignleConnPool never returns error.
if err != nil {
panic(err)
}
req := NewMultiBulkReq() req := NewMultiBulkReq()
for { for {
@ -89,7 +94,7 @@ func (c *PubSubClient) subscribe(cmd string, channels ...string) (chan *Message,
} }
c.once.Do(func() { c.once.Do(func() {
go c.consumeMessages() go c.consumeMessages(conn)
}) })
return c.ch, nil return c.ch, nil

118
redis.go
View File

@ -12,18 +12,18 @@ var (
ErrReaderTooSmall = errors.New("redis: Reader is too small") ErrReaderTooSmall = errors.New("redis: Reader is too small")
) )
type OpenConnFunc func() (io.ReadWriter, error) type OpenConnFunc func() (io.ReadWriteCloser, error)
type CloseConnFunc func(io.ReadWriter) type CloseConnFunc func(io.ReadWriteCloser) error
type InitConnFunc func(*Client) error type InitConnFunc func(*Client) error
func TCPConnector(addr string) OpenConnFunc { func TCPConnector(addr string) OpenConnFunc {
return func() (io.ReadWriter, error) { return func() (io.ReadWriteCloser, error) {
return net.Dial("tcp", addr) return net.Dial("tcp", addr)
} }
} }
func TLSConnector(addr string, tlsConfig *tls.Config) OpenConnFunc { func TLSConnector(addr string, tlsConfig *tls.Config) OpenConnFunc {
return func() (io.ReadWriter, error) { return func() (io.ReadWriteCloser, error) {
return tls.Dial("tcp", addr, tlsConfig) return tls.Dial("tcp", addr, tlsConfig)
} }
} }
@ -54,45 +54,29 @@ func AuthSelectFunc(password string, db int64) InitConnFunc {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type Client struct { type BaseClient struct {
mtx sync.Mutex mtx sync.Mutex
ConnPool ConnPool ConnPool ConnPool
InitConn InitConnFunc InitConn InitConnFunc
reqs []Req reqs []Req
} }
func NewClient(openConn OpenConnFunc, closeConn CloseConnFunc, initConn InitConnFunc) *Client { func (c *BaseClient) WriteReq(buf []byte, conn *Conn) error {
return &Client{ _, err := conn.RW.Write(buf)
ConnPool: NewMultiConnPool(openConn, closeConn, 10), return err
InitConn: initConn,
}
} }
func NewTCPClient(addr string, password string, db int64) *Client { func (c *BaseClient) conn() (*Conn, error) {
return NewClient(TCPConnector(addr), nil, AuthSelectFunc(password, db))
}
func NewTLSClient(addr string, tlsConfig *tls.Config, password string, db int64) *Client {
return NewClient(
TLSConnector(addr, tlsConfig),
nil,
AuthSelectFunc(password, db),
)
}
func (c *Client) Close() {
c.ConnPool.Close()
}
func (c *Client) conn() (*Conn, error) {
conn, isNew, err := c.ConnPool.Get() conn, isNew, err := c.ConnPool.Get()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if isNew && c.InitConn != nil { if isNew && c.InitConn != nil {
client := &Client{ client := &Client{
ConnPool: NewSingleConnPoolConn(c.ConnPool, conn), BaseClient: &BaseClient{
ConnPool: NewSingleConnPoolConn(c.ConnPool, conn, true),
},
} }
err = c.InitConn(client) err = c.InitConn(client)
if err != nil { if err != nil {
@ -102,12 +86,7 @@ func (c *Client) conn() (*Conn, error) {
return conn, nil return conn, nil
} }
func (c *Client) WriteReq(buf []byte, conn *Conn) error { func (c *BaseClient) Process(req Req) {
_, err := conn.RW.Write(buf)
return err
}
func (c *Client) Process(req Req) {
if c.reqs == nil { if c.reqs == nil {
c.Run(req) c.Run(req)
} else { } else {
@ -115,13 +94,7 @@ func (c *Client) Process(req Req) {
} }
} }
func (c *Client) Queue(req Req) { func (c *BaseClient) Run(req Req) {
c.mtx.Lock()
c.reqs = append(c.reqs, req)
c.mtx.Unlock()
}
func (c *Client) Run(req Req) {
conn, err := c.conn() conn, err := c.conn()
if err != nil { if err != nil {
req.SetErr(err) req.SetErr(err)
@ -146,56 +119,39 @@ func (c *Client) Run(req Req) {
req.SetVal(val) req.SetVal(val)
} }
func (c *Client) RunQueued() ([]Req, error) { func (c *BaseClient) Queue(req Req) {
c.mtx.Lock() c.mtx.Lock()
if len(c.reqs) == 0 { c.reqs = append(c.reqs, req)
c.mtx.Unlock() c.mtx.Unlock()
return c.reqs, nil
}
reqs := c.reqs
c.reqs = make([]Req, 0)
c.mtx.Unlock()
conn, err := c.conn()
if err != nil {
return nil, err
} }
err = c.RunReqs(reqs, conn) func (c *BaseClient) Close() error {
if err != nil { return c.ConnPool.Close()
c.ConnPool.Remove(conn)
return nil, err
} }
c.ConnPool.Add(conn) //------------------------------------------------------------------------------
return reqs, nil
type Client struct {
*BaseClient
} }
func (c *Client) RunReqs(reqs []Req, conn *Conn) error { func NewClient(openConn OpenConnFunc, closeConn CloseConnFunc, initConn InitConnFunc) *Client {
var multiReq []byte return &Client{
if len(reqs) == 1 { BaseClient: &BaseClient{
multiReq = reqs[0].Req() ConnPool: NewMultiConnPool(openConn, closeConn, 10),
} else { InitConn: initConn,
multiReq = make([]byte, 0, 1024) },
for _, req := range reqs {
multiReq = append(multiReq, req.Req()...)
} }
} }
err := c.WriteReq(multiReq, conn) func NewTCPClient(addr string, password string, db int64) *Client {
if err != nil { return NewClient(TCPConnector(addr), nil, AuthSelectFunc(password, db))
return err
} }
for i := 0; i < len(reqs); i++ { func NewTLSClient(addr string, tlsConfig *tls.Config, password string, db int64) *Client {
req := reqs[i] return NewClient(
val, err := req.ParseReply(conn.Rd) TLSConnector(addr, tlsConfig),
if err != nil { nil,
req.SetErr(err) AuthSelectFunc(password, db),
} else { )
req.SetVal(val)
}
}
return nil
} }

File diff suppressed because it is too large Load Diff