redis/redis.go

248 lines
4.2 KiB
Go
Raw Normal View History

2012-07-25 17:00:50 +04:00
package redis
import (
2012-07-26 19:16:17 +04:00
"fmt"
2012-07-25 17:00:50 +04:00
"io"
"sync"
2012-07-27 15:43:30 +04:00
"github.com/vmihailenco/bufreader"
2012-07-25 17:00:50 +04:00
)
type (
	// connectFunc is invoked by Client whenever it needs a connection and
	// has none; it returns the read/write stream to use, or an error.
	connectFunc func() (io.ReadWriter, error)

	// disconnectFunc is invoked by Client.Close with the connection that is
	// being dropped.
	disconnectFunc func(io.ReadWriter)
)
2012-07-26 22:43:21 +04:00
// createReader is the factory handed to bufreader.NewReaderPool by the
// Client constructors. It builds a reader with bufreader.NewSizedReader and a
// size of 8192 (presumably bytes — confirm against bufreader). It never
// returns an error; the error result presumably exists only to match the
// pool's factory signature.
func createReader() (*bufreader.Reader, error) {
	const readerSize = 8192
	rd := bufreader.NewSizedReader(readerSize)
	return rd, nil
}
2012-07-25 17:00:50 +04:00
type Client struct {
mtx sync.Mutex
connect connectFunc
disconnect disconnectFunc
currConn io.ReadWriter
2012-07-26 22:43:21 +04:00
readerPool *bufreader.ReaderPool
2012-07-26 19:16:17 +04:00
reqs []Req
2012-07-25 17:00:50 +04:00
}
func NewClient(connect connectFunc, disconnect disconnectFunc) *Client {
return &Client{
2012-07-27 15:43:30 +04:00
readerPool: bufreader.NewReaderPool(100, createReader),
2012-07-25 17:00:50 +04:00
connect: connect,
disconnect: disconnect,
2012-07-29 13:42:00 +04:00
reqs: make([]Req, 0),
2012-07-25 17:00:50 +04:00
}
}
2012-07-26 19:16:17 +04:00
func NewMultiClient(connect connectFunc, disconnect disconnectFunc) *Client {
return &Client{
2012-07-27 15:43:30 +04:00
readerPool: bufreader.NewReaderPool(100, createReader),
2012-07-26 19:16:17 +04:00
connect: connect,
disconnect: disconnect,
reqs: make([]Req, 0),
}
}
2012-07-25 17:00:50 +04:00
// Close drops the current connection, if there is one. It passes the
// connection to the disconnect callback (when one was supplied) and clears
// it, so the next request opens a new connection. Calling Close on a Client
// that never connected, or calling it twice, does nothing beyond clearing
// the (already nil) connection. It always returns nil.
//
// Close does not lock c.mtx. It is also called from WriteReq and ReadReply,
// and WriteRead runs those while holding that mutex, which is not reentrant.
func (c *Client) Close() error {
	// Never hand a nil connection to disconnect (e.g. Close before first
	// use, or a second Close).
	if c.currConn != nil && c.disconnect != nil {
		c.disconnect(c.currConn)
	}
	c.currConn = nil
	return nil
}
// conn returns the live connection. If there is none, it first opens one
// with c.connect and remembers it. A connect error is returned unchanged and
// leaves currConn nil. conn does not lock c.mtx itself.
func (c *Client) conn() (io.ReadWriter, error) {
	if c.currConn != nil {
		return c.currConn, nil
	}
	fresh, err := c.connect()
	if err != nil {
		return nil, err
	}
	c.currConn = fresh
	return fresh, nil
}
func (c *Client) WriteReq(buf []byte) error {
conn, err := c.conn()
if err != nil {
return err
}
2012-07-26 22:43:21 +04:00
2012-07-25 17:00:50 +04:00
_, err = conn.Write(buf)
if err != nil {
c.Close()
}
return err
}
2012-07-26 22:43:21 +04:00
func (c *Client) ReadReply(rd *bufreader.Reader) error {
2012-07-25 17:00:50 +04:00
conn, err := c.conn()
if err != nil {
2012-07-26 22:43:21 +04:00
return err
2012-07-25 17:00:50 +04:00
}
2012-07-26 22:43:21 +04:00
_, err = rd.ReadFrom(conn)
2012-07-25 17:00:50 +04:00
if err != nil {
c.Close()
2012-07-26 22:43:21 +04:00
return err
2012-07-25 17:00:50 +04:00
}
2012-07-26 22:43:21 +04:00
return nil
2012-07-25 17:00:50 +04:00
}
2012-07-26 22:43:21 +04:00
func (c *Client) WriteRead(buf []byte, rd *bufreader.Reader) error {
2012-07-26 19:16:17 +04:00
c.mtx.Lock()
defer c.mtx.Unlock()
2012-07-25 17:00:50 +04:00
if err := c.WriteReq(buf); err != nil {
2012-07-26 22:43:21 +04:00
return err
2012-07-25 17:00:50 +04:00
}
2012-07-26 22:43:21 +04:00
return c.ReadReply(rd)
2012-07-25 17:00:50 +04:00
}
2012-07-29 13:42:00 +04:00
// Queue binds req to this client via SetClient and appends it to the pending
// queue. It sends nothing; RunQueued or Exec send the queue.
func (c *Client) Queue(req Req) {
	req.SetClient(c)

	c.mtx.Lock()
	defer c.mtx.Unlock()
	c.reqs = append(c.reqs, req)
}
2012-07-26 19:16:17 +04:00
2012-07-29 13:42:00 +04:00
func (c *Client) Run(req Req) {
2012-07-26 22:43:21 +04:00
rd, err := c.readerPool.Get()
if err != nil {
req.SetErr(err)
return
}
defer c.readerPool.Add(rd)
err = c.WriteRead(req.Req(), rd)
if err != nil {
req.SetErr(err)
return
}
val, err := req.ParseReply(rd)
2012-07-26 19:16:17 +04:00
if err != nil {
req.SetErr(err)
return
}
2012-07-26 22:43:21 +04:00
req.SetVal(val)
2012-07-26 19:16:17 +04:00
}
2012-07-29 13:42:00 +04:00
func (c *Client) RunQueued() ([]Req, error) {
2012-07-29 13:51:29 +04:00
if len(c.reqs) == 0 {
return c.reqs, nil
}
2012-07-29 13:42:00 +04:00
c.mtx.Lock()
reqs := c.reqs
c.reqs = make([]Req, 0)
c.mtx.Unlock()
2012-07-26 19:16:17 +04:00
2012-07-29 13:42:00 +04:00
var multiReq []byte
if len(reqs) == 1 {
multiReq = reqs[0].Req()
} else {
multiReq = make([]byte, 0, 1024)
for _, req := range reqs {
multiReq = append(multiReq, req.Req()...)
}
2012-07-26 19:16:17 +04:00
}
2012-07-29 13:42:00 +04:00
rd, err := c.readerPool.Get()
if err != nil {
return nil, err
}
defer c.readerPool.Add(rd)
err = c.WriteRead(multiReq, rd)
if err != nil {
return nil, err
}
for _, req := range reqs {
val, err := req.ParseReply(rd)
if err != nil {
req.SetErr(err)
} else {
req.SetVal(val)
}
}
return reqs, nil
}
//------------------------------------------------------------------------------
func (c *Client) Discard() {
2012-07-25 17:00:50 +04:00
c.mtx.Lock()
2012-07-26 19:16:17 +04:00
c.reqs = c.reqs[:0]
2012-07-25 17:00:50 +04:00
c.mtx.Unlock()
}
2012-07-26 19:16:17 +04:00
func (c *Client) Exec() ([]Req, error) {
2012-07-29 13:51:29 +04:00
if len(c.reqs) == 0 {
return c.reqs, nil
}
2012-07-26 19:16:17 +04:00
c.mtx.Lock()
reqs := c.reqs
c.reqs = make([]Req, 0)
c.mtx.Unlock()
multiReq := make([]byte, 0, 1024)
multiReq = append(multiReq, PackReq([]string{"MULTI"})...)
for _, req := range reqs {
multiReq = append(multiReq, req.Req()...)
}
multiReq = append(multiReq, PackReq([]string{"EXEC"})...)
2012-07-26 22:43:21 +04:00
rd, err := c.readerPool.Get()
if err != nil {
return nil, err
}
defer c.readerPool.Add(rd)
err = c.WriteRead(multiReq, rd)
2012-07-25 17:00:50 +04:00
if err != nil {
2012-07-26 19:16:17 +04:00
return nil, err
2012-07-25 17:00:50 +04:00
}
2012-07-26 19:16:17 +04:00
statusReq := NewStatusReq()
2012-07-29 13:42:00 +04:00
// Parse MULTI command reply.
2012-07-26 22:43:21 +04:00
_, err = statusReq.ParseReply(rd)
2012-07-26 19:16:17 +04:00
if err != nil {
return nil, err
}
2012-07-29 13:42:00 +04:00
// Parse queued replies.
2012-07-26 19:16:17 +04:00
for _ = range reqs {
2012-07-26 22:43:21 +04:00
_, err = statusReq.ParseReply(rd)
2012-07-26 19:16:17 +04:00
if err != nil {
return nil, err
}
}
2012-07-29 13:42:00 +04:00
// Parse number of replies.
2012-07-26 19:16:17 +04:00
line, err := rd.ReadLine('\n')
if err != nil {
return nil, err
}
if line[0] != '*' {
return nil, fmt.Errorf("Expected '*', but got line %q of %q.", line, rd.Bytes())
}
2012-07-29 13:42:00 +04:00
// Parse replies.
2012-07-26 19:16:17 +04:00
for _, req := range reqs {
2012-07-26 22:43:21 +04:00
val, err := req.ParseReply(rd)
if err != nil {
req.SetErr(err)
2012-07-29 13:42:00 +04:00
} else {
req.SetVal(val)
2012-07-26 22:43:21 +04:00
}
2012-07-26 19:16:17 +04:00
}
return reqs, nil
2012-07-25 17:00:50 +04:00
}