Add support for watch in transactions.

This commit is contained in:
Vladimir Mihailenco 2012-08-09 17:06:26 +03:00
parent 052cef49d2
commit 83664bb3a8
7 changed files with 331 additions and 196 deletions

View File

@ -93,14 +93,18 @@ Running commands
Pipelining Pipelining
---------- ----------
Client has ability to run several commands with one read/write: Client has ability to run commands in batches:
multiClient := redisClient.Multi() pipeline, err := redisClient.PipelineClient()
if err != nil {
panic(err)
}
defer pipeline.Close()
setReq := multiClient.Set("foo1", "bar1") // queue command SET setReq := pipeline.Set("foo1", "bar1") // queue command SET
getReq := multiClient.Get("foo2") // queue command GET getReq := pipeline.Get("foo2") // queue command GET
reqs, err := multiClient.RunQueued() // run queued commands reqs, err := pipeline.RunQueued() // run queued commands
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -121,28 +125,33 @@ Multi/Exec
Example: Example:
multiClient := redisClient.Multi() multiClient, err := redisClient.MultiClient()
get1 := multiClient.Get("foo1")
get2 := multiClient.Get("foo2")
reqs, err := multiClient.Exec()
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer multiClient.Close()
watch := mutliClient.Watch("foo")
if watch.Err() != nil {
panic(watch.Err())
}
// Start transaction.
multiClient.Multi()
set := multiClient.Set("foo", watch.Val() + "1")
reqs, err := multiClient.Exec()
if err == redis.Nil {
// Repeat transaction.
} else if err != nil {
panic(err)
}
for _, req := range reqs { for _, req := range reqs {
// ... // ...
} }
ok := set.Val()
if get1.Err() != nil && get1.Err() != redis.Nil {
panic(get1.Err())
}
val1 := get1.Val()
if get2.Err() != nil && get2.Err() != redis.Nil {
panic(get2.Err())
}
val2 := get2.Val()
Pub/sub Pub/sub
------- -------
@ -160,6 +169,7 @@ Subscribe:
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer pubsub.Close()
ch, err := pubsub.Subscribe("mychannel") ch, err := pubsub.Subscribe("mychannel")
if err != nil { if err != nil {

View File

@ -1,6 +1,7 @@
package redis package redis
import ( import (
"fmt"
"strconv" "strconv"
) )
@ -857,11 +858,123 @@ func (c *Client) Publish(channel, message string) *IntReq {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func (c *Client) Multi() *Client { func (c *Client) PipelineClient() (*Client, error) {
return &Client{ return &Client{
ConnPool: c.ConnPool, ConnPool: c.ConnPool,
InitConn: c.InitConn, InitConn: c.InitConn,
reqs: make([]Req, 0),
reqs: make([]Req, 0), }, nil
} }
//------------------------------------------------------------------------------
func (c *Client) MultiClient() (*Client, error) {
return &Client{
ConnPool: NewSingleConnPool(c.ConnPool),
InitConn: c.InitConn,
}, nil
}
func (c *Client) Multi() {
c.reqs = make([]Req, 0)
}
func (c *Client) Watch(keys ...string) *StatusReq {
args := append([]string{"WATCH"}, keys...)
req := NewStatusReq(args...)
c.Process(req)
return req
}
func (c *Client) Unwatch(keys ...string) *StatusReq {
args := append([]string{"UNWATCH"}, keys...)
req := NewStatusReq(args...)
c.Process(req)
return req
}
func (c *Client) Discard() {
c.mtx.Lock()
c.reqs = c.reqs[:0]
c.mtx.Unlock()
}
func (c *Client) Exec() ([]Req, error) {
c.mtx.Lock()
if len(c.reqs) == 0 {
c.mtx.Unlock()
return c.reqs, nil
}
reqs := c.reqs
c.reqs = nil
c.mtx.Unlock()
conn, err := c.conn()
if err != nil {
return nil, err
}
err = c.ExecReqs(reqs, conn)
if err != nil {
c.ConnPool.Remove(conn)
return nil, err
}
c.ConnPool.Add(conn)
return reqs, nil
}
func (c *Client) ExecReqs(reqs []Req, conn *Conn) error {
multiReq := make([]byte, 0, 1024)
multiReq = append(multiReq, PackReq([]string{"MULTI"})...)
for _, req := range reqs {
multiReq = append(multiReq, req.Req()...)
}
multiReq = append(multiReq, PackReq([]string{"EXEC"})...)
err := c.WriteReq(multiReq, conn)
if err != nil {
return err
}
statusReq := NewStatusReq()
// Parse MULTI command reply.
_, err = statusReq.ParseReply(conn.Rd)
if err != nil {
return err
}
// Parse queued replies.
for _ = range reqs {
_, err = statusReq.ParseReply(conn.Rd)
if err != nil {
return err
}
}
// Parse number of replies.
line, err := readLine(conn.Rd)
if err != nil {
return err
}
if line[0] != '*' {
return fmt.Errorf("Expected '*', but got line %q", line)
}
if isNilReplies(line) {
return Nil
}
// Parse replies.
for i := 0; i < len(reqs); i++ {
req := reqs[i]
val, err := req.ParseReply(conn.Rd)
if err != nil {
req.SetErr(err)
} else {
req.SetVal(val)
}
}
return nil
} }

View File

@ -25,6 +25,7 @@ type ConnPool interface {
Add(*Conn) Add(*Conn)
Remove(*Conn) Remove(*Conn)
Len() int Len() int
Close()
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -102,18 +103,39 @@ func (p *MultiConnPool) Len() int {
return len(p.conns) return len(p.conns)
} }
func (p *MultiConnPool) Close() {}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type SingleConnPool struct { type SingleConnPool struct {
mtx sync.Mutex
pool ConnPool
conn *Conn conn *Conn
} }
func NewSingleConnPool(conn *Conn) *SingleConnPool { func NewSingleConnPoolConn(pool ConnPool, conn *Conn) *SingleConnPool {
return &SingleConnPool{conn: conn} return &SingleConnPool{
pool: pool,
conn: conn,
}
}
func NewSingleConnPool(pool ConnPool) *SingleConnPool {
return NewSingleConnPoolConn(pool, nil)
} }
func (p *SingleConnPool) Get() (*Conn, bool, error) { func (p *SingleConnPool) Get() (*Conn, bool, error) {
return p.conn, false, nil p.mtx.Lock()
defer p.mtx.Unlock()
if p.conn != nil {
return p.conn, false, nil
}
conn, isNew, err := p.pool.Get()
if err != nil {
return nil, false, err
}
p.conn = conn
return p.conn, isNew, nil
} }
func (p *SingleConnPool) Add(conn *Conn) {} func (p *SingleConnPool) Add(conn *Conn) {}
@ -123,3 +145,10 @@ func (p *SingleConnPool) Remove(conn *Conn) {}
func (p *SingleConnPool) Len() int { func (p *SingleConnPool) Len() int {
return 1 return 1
} }
func (p *SingleConnPool) Close() {
p.mtx.Lock()
defer p.mtx.Unlock()
p.pool.Add(p.conn)
p.conn = nil
}

View File

@ -12,15 +12,9 @@ type PubSubClient struct {
} }
func newPubSubClient(client *Client) (*PubSubClient, error) { func newPubSubClient(client *Client) (*PubSubClient, error) {
pubSubConn, _, err := client.ConnPool.Get()
if err != nil {
return nil, err
}
client.ConnPool.Remove(pubSubConn)
c := &PubSubClient{ c := &PubSubClient{
Client: &Client{ Client: &Client{
ConnPool: NewSingleConnPool(pubSubConn), ConnPool: NewSingleConnPool(client.ConnPool),
}, },
ch: make(chan *Message), ch: make(chan *Message),
} }

View File

@ -3,7 +3,6 @@ package redis
import ( import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"sync" "sync"
@ -82,6 +81,10 @@ func NewTLSClient(addr string, tlsConfig *tls.Config, password string, db int64)
) )
} }
func (c *Client) Close() {
c.ConnPool.Close()
}
func (c *Client) conn() (*Conn, error) { func (c *Client) conn() (*Conn, error) {
conn, isNew, err := c.ConnPool.Get() conn, isNew, err := c.ConnPool.Get()
if err != nil { if err != nil {
@ -89,7 +92,7 @@ func (c *Client) conn() (*Conn, error) {
} }
if isNew && c.InitConn != nil { if isNew && c.InitConn != nil {
client := &Client{ client := &Client{
ConnPool: NewSingleConnPool(conn), ConnPool: NewSingleConnPoolConn(c.ConnPool, conn),
} }
err = c.InitConn(client) err = c.InitConn(client)
if err != nil { if err != nil {
@ -196,89 +199,3 @@ func (c *Client) RunReqs(reqs []Req, conn *Conn) error {
return nil return nil
} }
//------------------------------------------------------------------------------
func (c *Client) Discard() {
c.mtx.Lock()
c.reqs = c.reqs[:0]
c.mtx.Unlock()
}
func (c *Client) Exec() ([]Req, error) {
c.mtx.Lock()
if len(c.reqs) == 0 {
c.mtx.Unlock()
return c.reqs, nil
}
reqs := c.reqs
c.reqs = make([]Req, 0)
c.mtx.Unlock()
conn, err := c.conn()
if err != nil {
return nil, err
}
err = c.ExecReqs(reqs, conn)
if err != nil {
c.ConnPool.Remove(conn)
return nil, err
}
c.ConnPool.Add(conn)
return reqs, nil
}
func (c *Client) ExecReqs(reqs []Req, conn *Conn) error {
multiReq := make([]byte, 0, 1024)
multiReq = append(multiReq, PackReq([]string{"MULTI"})...)
for _, req := range reqs {
multiReq = append(multiReq, req.Req()...)
}
multiReq = append(multiReq, PackReq([]string{"EXEC"})...)
err := c.WriteReq(multiReq, conn)
if err != nil {
return err
}
statusReq := NewStatusReq()
// Parse MULTI command reply.
_, err = statusReq.ParseReply(conn.Rd)
if err != nil {
return err
}
// Parse queued replies.
for _ = range reqs {
_, err = statusReq.ParseReply(conn.Rd)
if err != nil {
return err
}
}
// Parse number of replies.
line, err := readLine(conn.Rd)
if err != nil {
return err
}
if line[0] != '*' {
buf, _ := conn.Rd.Peek(conn.Rd.Buffered())
return fmt.Errorf("Expected '*', but got line %q of %q.", line, buf)
}
// Parse replies.
for i := 0; i < len(reqs); i++ {
req := reqs[i]
val, err := req.ParseReply(conn.Rd)
if err != nil {
req.SetErr(err)
} else {
req.SetVal(val)
}
}
return nil
}

View File

@ -18,7 +18,7 @@ const redisAddr = ":8888"
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type RedisTest struct { type RedisTest struct {
client, multiClient *redis.Client client *redis.Client
} }
var _ = Suite(&RedisTest{}) var _ = Suite(&RedisTest{})
@ -30,8 +30,6 @@ func Test(t *testing.T) { TestingT(t) }
func (t *RedisTest) SetUpTest(c *C) { func (t *RedisTest) SetUpTest(c *C) {
t.client = redis.NewTCPClient(redisAddr, "", -1) t.client = redis.NewTCPClient(redisAddr, "", -1)
c.Check(t.client.Flushdb().Err(), IsNil) c.Check(t.client.Flushdb().Err(), IsNil)
t.multiClient = t.client.Multi()
} }
func (t *RedisTest) TearDownTest(c *C) { func (t *RedisTest) TearDownTest(c *C) {
@ -1611,6 +1609,7 @@ func (t *RedisTest) TestZUnionStore(c *C) {
func (t *RedisTest) TestPatternPubSub(c *C) { func (t *RedisTest) TestPatternPubSub(c *C) {
pubsub, err := t.client.PubSubClient() pubsub, err := t.client.PubSubClient()
c.Check(err, IsNil) c.Check(err, IsNil)
defer pubsub.Close()
ch, err := pubsub.PSubscribe("mychannel*") ch, err := pubsub.PSubscribe("mychannel*")
c.Check(err, IsNil) c.Check(err, IsNil)
@ -1658,6 +1657,7 @@ func (t *RedisTest) TestPatternPubSub(c *C) {
func (t *RedisTest) TestPubSub(c *C) { func (t *RedisTest) TestPubSub(c *C) {
pubsub, err := t.client.PubSubClient() pubsub, err := t.client.PubSubClient()
c.Check(err, IsNil) c.Check(err, IsNil)
defer pubsub.Close()
ch, err := pubsub.Subscribe("mychannel") ch, err := pubsub.Subscribe("mychannel")
c.Check(err, IsNil) c.Check(err, IsNil)
@ -1749,10 +1749,15 @@ func (t *RedisTest) TestPipelining(c *C) {
c.Check(set.Err(), IsNil) c.Check(set.Err(), IsNil)
c.Check(set.Val(), Equals, "OK") c.Check(set.Val(), Equals, "OK")
setReq := t.multiClient.Set("foo1", "bar1") multi, err := t.client.MultiClient()
getReq := t.multiClient.Get("foo2") c.Check(err, IsNil)
reqs, err := t.multiClient.RunQueued() multi.Multi()
setReq := multi.Set("foo1", "bar1")
getReq := multi.Get("foo2")
reqs, err := multi.RunQueued()
c.Check(err, IsNil) c.Check(err, IsNil)
c.Check(reqs, HasLen, 2) c.Check(reqs, HasLen, 2)
@ -1769,16 +1774,74 @@ func (t *RedisTest) TestRunQueuedOnEmptyQueue(c *C) {
c.Check(reqs, HasLen, 0) c.Check(reqs, HasLen, 0)
} }
func (t *RedisTest) TestIncrPipeliningFromGoroutines(c *C) {
multi, err := t.client.PipelineClient()
c.Check(err, IsNil)
defer multi.Close()
wg := &sync.WaitGroup{}
for i := int64(0); i < 20000; i++ {
wg.Add(1)
go func() {
multi.Incr("TestIncrPipeliningFromGoroutinesKey")
wg.Done()
}()
}
wg.Wait()
reqs, err := multi.RunQueued()
c.Check(err, IsNil)
c.Check(reqs, HasLen, 20000)
for _, req := range reqs {
if req.Err() != nil {
c.Errorf("got %v, expected nil", req.Err())
}
}
get := t.client.Get("TestIncrPipeliningFromGoroutinesKey")
c.Check(get.Err(), IsNil)
c.Check(get.Val(), Equals, "20000")
}
func (t *RedisTest) TestPipeliningFromGoroutines(c *C) {
multi, err := t.client.PipelineClient()
c.Check(err, IsNil)
defer multi.Close()
for i := int64(0); i < 1000; i += 2 {
go func() {
msg1 := "echo" + strconv.FormatInt(i, 10)
msg2 := "echo" + strconv.FormatInt(i+1, 10)
echo1Req := multi.Echo(msg1)
echo2Req := multi.Echo(msg2)
reqs, err := multi.RunQueued()
c.Check(reqs, HasLen, 2)
c.Check(err, IsNil)
c.Check(echo1Req.Err(), IsNil)
c.Check(echo1Req.Val(), Equals, msg1)
c.Check(echo2Req.Err(), IsNil)
c.Check(echo2Req.Val(), Equals, msg2)
}()
}
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func (t *RedisTest) TestDiscard(c *C) { func (t *RedisTest) TestDiscard(c *C) {
multiC := t.client.Multi() multi, err := t.client.MultiClient()
c.Check(err, IsNil)
multiC.Set("foo1", "bar1") multi.Multi()
multiC.Discard()
multiC.Set("foo2", "bar2")
reqs, err := multiC.Exec() multi.Set("foo1", "bar1")
multi.Discard()
multi.Set("foo2", "bar2")
reqs, err := multi.Exec()
c.Check(err, IsNil) c.Check(err, IsNil)
c.Check(reqs, HasLen, 1) c.Check(reqs, HasLen, 1)
@ -1792,12 +1855,15 @@ func (t *RedisTest) TestDiscard(c *C) {
} }
func (t *RedisTest) TestMultiExec(c *C) { func (t *RedisTest) TestMultiExec(c *C) {
multiC := t.client.Multi() multi, err := t.client.MultiClient()
c.Check(err, IsNil)
setR := multiC.Set("foo", "bar") multi.Multi()
getR := multiC.Get("foo")
reqs, err := multiC.Exec() setR := multi.Set("foo", "bar")
getR := multi.Get("foo")
reqs, err := multi.Exec()
c.Check(err, IsNil) c.Check(err, IsNil)
c.Check(reqs, HasLen, 2) c.Check(reqs, HasLen, 2)
@ -1827,30 +1893,6 @@ func (t *RedisTest) TestEchoFromGoroutines(c *C) {
} }
} }
func (t *RedisTest) TestPipeliningFromGoroutines(c *C) {
multiClient := t.client.Multi()
for i := int64(0); i < 1000; i += 2 {
go func() {
msg1 := "echo" + strconv.FormatInt(i, 10)
msg2 := "echo" + strconv.FormatInt(i+1, 10)
echo1Req := multiClient.Echo(msg1)
echo2Req := multiClient.Echo(msg2)
reqs, err := multiClient.RunQueued()
c.Check(reqs, HasLen, 2)
c.Check(err, IsNil)
c.Check(echo1Req.Err(), IsNil)
c.Check(echo1Req.Val(), Equals, msg1)
c.Check(echo2Req.Err(), IsNil)
c.Check(echo2Req.Val(), Equals, msg2)
}()
}
}
func (t *RedisTest) TestIncrFromGoroutines(c *C) { func (t *RedisTest) TestIncrFromGoroutines(c *C) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
for i := int64(0); i < 20000; i++ { for i := int64(0); i < 20000; i++ {
@ -1868,47 +1910,23 @@ func (t *RedisTest) TestIncrFromGoroutines(c *C) {
c.Check(get.Val(), Equals, "20000") c.Check(get.Val(), Equals, "20000")
} }
func (t *RedisTest) TestIncrPipeliningFromGoroutines(c *C) {
multiClient := t.client.Multi()
wg := &sync.WaitGroup{}
for i := int64(0); i < 20000; i++ {
wg.Add(1)
go func() {
multiClient.Incr("TestIncrPipeliningFromGoroutinesKey")
wg.Done()
}()
}
wg.Wait()
reqs, err := multiClient.RunQueued()
c.Check(err, IsNil)
c.Check(reqs, HasLen, 20000)
for _, req := range reqs {
if req.Err() != nil {
c.Errorf("got %v, expected nil", req.Err())
}
}
get := t.client.Get("TestIncrPipeliningFromGoroutinesKey")
c.Check(get.Err(), IsNil)
c.Check(get.Val(), Equals, "20000")
}
func (t *RedisTest) TestIncrTransaction(c *C) { func (t *RedisTest) TestIncrTransaction(c *C) {
multiClient := t.client.Multi() multi, err := t.client.MultiClient()
c.Check(err, IsNil)
multi.Multi()
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
for i := int64(0); i < 20000; i++ { for i := int64(0); i < 20000; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
multiClient.Incr("TestIncrTransactionKey") multi.Incr("TestIncrTransactionKey")
wg.Done() wg.Done()
}() }()
} }
wg.Wait() wg.Wait()
reqs, err := multiClient.Exec() reqs, err := multi.Exec()
c.Check(err, IsNil) c.Check(err, IsNil)
c.Check(reqs, HasLen, 20000) c.Check(reqs, HasLen, 20000)
for _, req := range reqs { for _, req := range reqs {
@ -1922,6 +1940,55 @@ func (t *RedisTest) TestIncrTransaction(c *C) {
c.Check(get.Val(), Equals, "20000") c.Check(get.Val(), Equals, "20000")
} }
func (t *RedisTest) transactionalIncr(c *C, wg *sync.WaitGroup) {
multi, err := t.client.MultiClient()
c.Check(err, IsNil)
defer multi.Close()
watch := multi.Watch("foo")
c.Check(watch.Err(), IsNil)
c.Check(watch.Val(), Equals, "OK")
get := multi.Get("foo")
c.Check(get.Err(), IsNil)
c.Check(get.Val(), Not(Equals), redis.Nil)
v, err := strconv.ParseInt(get.Val(), 10, 64)
c.Check(err, IsNil)
multi.Multi()
set := multi.Set("foo", strconv.FormatInt(v+1, 10))
reqs, err := multi.Exec()
if err == redis.Nil {
t.transactionalIncr(c, wg)
return
}
c.Check(reqs, HasLen, 1)
c.Check(err, IsNil)
c.Check(set.Err(), IsNil)
c.Check(set.Val(), Equals, "OK")
wg.Done()
}
func (t *RedisTest) TestWatchUnwatch(c *C) {
set := t.client.Set("foo", "0")
c.Check(set.Err(), IsNil)
c.Check(set.Val(), Equals, "OK")
wg := &sync.WaitGroup{}
for i := 0; i < 1000; i++ {
wg.Add(1)
go t.transactionalIncr(c, wg)
}
wg.Wait()
get := t.client.Get("foo")
c.Check(get.Err(), IsNil)
c.Check(get.Val(), Equals, "1000")
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func (t *RedisTest) BenchmarkRedisPing(c *C) { func (t *RedisTest) BenchmarkRedisPing(c *C) {

View File

@ -20,8 +20,12 @@ func isEmpty(line []byte) bool {
return len(line) == 2 && line[0] == '$' && line[1] == '0' return len(line) == 2 && line[0] == '$' && line[1] == '0'
} }
func isNilReplies(line []byte) bool {
return len(line) == 3 && line[0] == '*' && line[1] == '-' && line[2] == '1'
}
func isNoReplies(line []byte) bool { func isNoReplies(line []byte) bool {
return len(line) >= 2 && line[1] == '*' && line[1] == '0' return len(line) == 2 && line[1] == '*' && line[1] == '0'
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -401,7 +405,7 @@ func (r *MultiBulkReq) ParseReply(rd ReadLiner) (interface{}, error) {
return nil, errors.New(string(line[1:])) return nil, errors.New(string(line[1:]))
} else if line[0] != '*' { } else if line[0] != '*' {
return nil, fmt.Errorf("Expected '*', but got line %q", line) return nil, fmt.Errorf("Expected '*', but got line %q", line)
} else if isNil(line) { } else if isNilReplies(line) {
return nil, Nil return nil, Nil
} }
@ -420,14 +424,15 @@ func (r *MultiBulkReq) ParseReply(rd ReadLiner) (interface{}, error) {
return nil, err return nil, err
} }
if line[0] == ':' { switch line[0] {
case ':':
var n int64 var n int64
n, err = strconv.ParseInt(string(line[1:]), 10, 64) n, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil { if err != nil {
return nil, err return nil, err
} }
val = append(val, n) val = append(val, n)
} else if line[0] == '$' { case '$':
if isEmpty(line) { if isEmpty(line) {
val = append(val, "") val = append(val, "")
} else if isNil(line) { } else if isNil(line) {
@ -439,7 +444,7 @@ func (r *MultiBulkReq) ParseReply(rd ReadLiner) (interface{}, error) {
} }
val = append(val, string(line)) val = append(val, string(line))
} }
} else { default:
return nil, fmt.Errorf("Expected '$', but got line %q", line) return nil, fmt.Errorf("Expected '$', but got line %q", line)
} }
} }