diff --git a/README.md b/README.md index 5b7f48d..730c444 100644 --- a/README.md +++ b/README.md @@ -93,14 +93,18 @@ Running commands Pipelining ---------- -Client has ability to run several commands with one read/write: +Client has ability to run commands in batches: - multiClient := redisClient.Multi() + pipeline, err := redisClient.PipelineClient() + if err != nil { + panic(err) + } + defer pipeline.Close() - setReq := multiClient.Set("foo1", "bar1") // queue command SET - getReq := multiClient.Get("foo2") // queue command GET + setReq := pipeline.Set("foo1", "bar1") // queue command SET + getReq := pipeline.Get("foo2") // queue command GET - reqs, err := multiClient.RunQueued() // run queued commands + reqs, err := pipeline.RunQueued() // run queued commands if err != nil { panic(err) } @@ -121,28 +125,33 @@ Multi/Exec Example: - multiClient := redisClient.Multi() - - get1 := multiClient.Get("foo1") - get2 := multiClient.Get("foo2") - reqs, err := multiClient.Exec() + multiClient, err := redisClient.MultiClient() if err != nil { panic(err) } + defer multiClient.Close() + + watch := mutliClient.Watch("foo") + if watch.Err() != nil { + panic(watch.Err()) + } + + // Start transaction. + multiClient.Multi() + + set := multiClient.Set("foo", watch.Val() + "1") + + reqs, err := multiClient.Exec() + if err == redis.Nil { + // Repeat transaction. + } else if err != nil { + panic(err) + } for _, req := range reqs { // ... } - - if get1.Err() != nil && get1.Err() != redis.Nil { - panic(get1.Err()) - } - val1 := get1.Val() - - if get2.Err() != nil && get2.Err() != redis.Nil { - panic(get2.Err()) - } - val2 := get2.Val() + ok := set.Val() Pub/sub ------- @@ -160,6 +169,7 @@ Subscribe: if err != nil { panic(err) } + defer pubsub.Close() ch, err := pubsub.Subscribe("mychannel") if err != nil { diff --git a/commands.go b/commands.go index 9002600..a179d7d 100644 --- a/commands.go +++ b/commands.go @@ -1,6 +1,7 @@ package redis import ( + "fmt" "strconv" ) @@ -857,11 +858,123 @@ func (c *Client) Publish(channel, message string) *IntReq { //------------------------------------------------------------------------------ -func (c *Client) Multi() *Client { +func (c *Client) PipelineClient() (*Client, error) { return &Client{ ConnPool: c.ConnPool, InitConn: c.InitConn, - - reqs: make([]Req, 0), - } + reqs: make([]Req, 0), + }, nil +} + +//------------------------------------------------------------------------------ + +func (c *Client) MultiClient() (*Client, error) { + return &Client{ + ConnPool: NewSingleConnPool(c.ConnPool), + InitConn: c.InitConn, + }, nil +} + +func (c *Client) Multi() { + c.reqs = make([]Req, 0) +} + +func (c *Client) Watch(keys ...string) *StatusReq { + args := append([]string{"WATCH"}, keys...) + req := NewStatusReq(args...) + c.Process(req) + return req +} + +func (c *Client) Unwatch(keys ...string) *StatusReq { + args := append([]string{"UNWATCH"}, keys...) + req := NewStatusReq(args...) + c.Process(req) + return req +} + +func (c *Client) Discard() { + c.mtx.Lock() + c.reqs = c.reqs[:0] + c.mtx.Unlock() +} + +func (c *Client) Exec() ([]Req, error) { + c.mtx.Lock() + if len(c.reqs) == 0 { + c.mtx.Unlock() + return c.reqs, nil + } + reqs := c.reqs + c.reqs = nil + c.mtx.Unlock() + + conn, err := c.conn() + if err != nil { + return nil, err + } + + err = c.ExecReqs(reqs, conn) + if err != nil { + c.ConnPool.Remove(conn) + return nil, err + } + + c.ConnPool.Add(conn) + return reqs, nil +} + +func (c *Client) ExecReqs(reqs []Req, conn *Conn) error { + multiReq := make([]byte, 0, 1024) + multiReq = append(multiReq, PackReq([]string{"MULTI"})...) + for _, req := range reqs { + multiReq = append(multiReq, req.Req()...) + } + multiReq = append(multiReq, PackReq([]string{"EXEC"})...) + + err := c.WriteReq(multiReq, conn) + if err != nil { + return err + } + + statusReq := NewStatusReq() + + // Parse MULTI command reply. + _, err = statusReq.ParseReply(conn.Rd) + if err != nil { + return err + } + + // Parse queued replies. + for _ = range reqs { + _, err = statusReq.ParseReply(conn.Rd) + if err != nil { + return err + } + } + + // Parse number of replies. + line, err := readLine(conn.Rd) + if err != nil { + return err + } + if line[0] != '*' { + return fmt.Errorf("Expected '*', but got line %q", line) + } + if isNilReplies(line) { + return Nil + } + + // Parse replies. + for i := 0; i < len(reqs); i++ { + req := reqs[i] + val, err := req.ParseReply(conn.Rd) + if err != nil { + req.SetErr(err) + } else { + req.SetVal(val) + } + } + + return nil } diff --git a/connpool.go b/connpool.go index 9ec8fa5..27cb320 100644 --- a/connpool.go +++ b/connpool.go @@ -25,6 +25,7 @@ type ConnPool interface { Add(*Conn) Remove(*Conn) Len() int + Close() } //------------------------------------------------------------------------------ @@ -102,18 +103,39 @@ func (p *MultiConnPool) Len() int { return len(p.conns) } +func (p *MultiConnPool) Close() {} + //------------------------------------------------------------------------------ type SingleConnPool struct { + mtx sync.Mutex + pool ConnPool conn *Conn } -func NewSingleConnPool(conn *Conn) *SingleConnPool { - return &SingleConnPool{conn: conn} +func NewSingleConnPoolConn(pool ConnPool, conn *Conn) *SingleConnPool { + return &SingleConnPool{ + pool: pool, + conn: conn, + } +} + +func NewSingleConnPool(pool ConnPool) *SingleConnPool { + return NewSingleConnPoolConn(pool, nil) } func (p *SingleConnPool) Get() (*Conn, bool, error) { - return p.conn, false, nil + p.mtx.Lock() + defer p.mtx.Unlock() + if p.conn != nil { + return p.conn, false, nil + } + conn, isNew, err := p.pool.Get() + if err != nil { + return nil, false, err + } + p.conn = conn + return p.conn, isNew, nil } func (p *SingleConnPool) Add(conn *Conn) {} @@ -123,3 +145,10 @@ func (p *SingleConnPool) Remove(conn *Conn) {} func (p *SingleConnPool) Len() int { return 1 } + +func (p *SingleConnPool) Close() { + p.mtx.Lock() + defer p.mtx.Unlock() + p.pool.Add(p.conn) + p.conn = nil +} diff --git a/pubsub.go b/pubsub.go index 93d5b38..8f738b9 100644 --- a/pubsub.go +++ b/pubsub.go @@ -12,15 +12,9 @@ type PubSubClient struct { } func newPubSubClient(client *Client) (*PubSubClient, error) { - pubSubConn, _, err := client.ConnPool.Get() - if err != nil { - return nil, err - } - client.ConnPool.Remove(pubSubConn) - c := &PubSubClient{ Client: &Client{ - ConnPool: NewSingleConnPool(pubSubConn), + ConnPool: NewSingleConnPool(client.ConnPool), }, ch: make(chan *Message), } diff --git a/redis.go b/redis.go index f23147d..a198af6 100644 --- a/redis.go +++ b/redis.go @@ -3,7 +3,6 @@ package redis import ( "crypto/tls" "errors" - "fmt" "io" "net" "sync" @@ -82,6 +81,10 @@ func NewTLSClient(addr string, tlsConfig *tls.Config, password string, db int64) ) } +func (c *Client) Close() { + c.ConnPool.Close() +} + func (c *Client) conn() (*Conn, error) { conn, isNew, err := c.ConnPool.Get() if err != nil { @@ -89,7 +92,7 @@ func (c *Client) conn() (*Conn, error) { } if isNew && c.InitConn != nil { client := &Client{ - ConnPool: NewSingleConnPool(conn), + ConnPool: NewSingleConnPoolConn(c.ConnPool, conn), } err = c.InitConn(client) if err != nil { @@ -196,89 +199,3 @@ func (c *Client) RunReqs(reqs []Req, conn *Conn) error { return nil } - -//------------------------------------------------------------------------------ - -func (c *Client) Discard() { - c.mtx.Lock() - c.reqs = c.reqs[:0] - c.mtx.Unlock() -} - -func (c *Client) Exec() ([]Req, error) { - c.mtx.Lock() - if len(c.reqs) == 0 { - c.mtx.Unlock() - return c.reqs, nil - } - reqs := c.reqs - c.reqs = make([]Req, 0) - c.mtx.Unlock() - - conn, err := c.conn() - if err != nil { - return nil, err - } - - err = c.ExecReqs(reqs, conn) - if err != nil { - c.ConnPool.Remove(conn) - return nil, err - } - - c.ConnPool.Add(conn) - return reqs, nil -} - -func (c *Client) ExecReqs(reqs []Req, conn *Conn) error { - multiReq := make([]byte, 0, 1024) - multiReq = append(multiReq, PackReq([]string{"MULTI"})...) - for _, req := range reqs { - multiReq = append(multiReq, req.Req()...) - } - multiReq = append(multiReq, PackReq([]string{"EXEC"})...) - - err := c.WriteReq(multiReq, conn) - if err != nil { - return err - } - - statusReq := NewStatusReq() - - // Parse MULTI command reply. - _, err = statusReq.ParseReply(conn.Rd) - if err != nil { - return err - } - - // Parse queued replies. - for _ = range reqs { - _, err = statusReq.ParseReply(conn.Rd) - if err != nil { - return err - } - } - - // Parse number of replies. - line, err := readLine(conn.Rd) - if err != nil { - return err - } - if line[0] != '*' { - buf, _ := conn.Rd.Peek(conn.Rd.Buffered()) - return fmt.Errorf("Expected '*', but got line %q of %q.", line, buf) - } - - // Parse replies. - for i := 0; i < len(reqs); i++ { - req := reqs[i] - val, err := req.ParseReply(conn.Rd) - if err != nil { - req.SetErr(err) - } else { - req.SetVal(val) - } - } - - return nil -} diff --git a/redis_test.go b/redis_test.go index 90abd96..7637e27 100644 --- a/redis_test.go +++ b/redis_test.go @@ -18,7 +18,7 @@ const redisAddr = ":8888" //------------------------------------------------------------------------------ type RedisTest struct { - client, multiClient *redis.Client + client *redis.Client } var _ = Suite(&RedisTest{}) @@ -30,8 +30,6 @@ func Test(t *testing.T) { TestingT(t) } func (t *RedisTest) SetUpTest(c *C) { t.client = redis.NewTCPClient(redisAddr, "", -1) c.Check(t.client.Flushdb().Err(), IsNil) - - t.multiClient = t.client.Multi() } func (t *RedisTest) TearDownTest(c *C) { @@ -1611,6 +1609,7 @@ func (t *RedisTest) TestZUnionStore(c *C) { func (t *RedisTest) TestPatternPubSub(c *C) { pubsub, err := t.client.PubSubClient() c.Check(err, IsNil) + defer pubsub.Close() ch, err := pubsub.PSubscribe("mychannel*") c.Check(err, IsNil) @@ -1658,6 +1657,7 @@ func (t *RedisTest) TestPatternPubSub(c *C) { func (t *RedisTest) TestPubSub(c *C) { pubsub, err := t.client.PubSubClient() c.Check(err, IsNil) + defer pubsub.Close() ch, err := pubsub.Subscribe("mychannel") c.Check(err, IsNil) @@ -1749,10 +1749,15 @@ func (t *RedisTest) TestPipelining(c *C) { c.Check(set.Err(), IsNil) c.Check(set.Val(), Equals, "OK") - setReq := t.multiClient.Set("foo1", "bar1") - getReq := t.multiClient.Get("foo2") + multi, err := t.client.MultiClient() + c.Check(err, IsNil) - reqs, err := t.multiClient.RunQueued() + multi.Multi() + + setReq := multi.Set("foo1", "bar1") + getReq := multi.Get("foo2") + + reqs, err := multi.RunQueued() c.Check(err, IsNil) c.Check(reqs, HasLen, 2) @@ -1769,16 +1774,74 @@ func (t *RedisTest) TestRunQueuedOnEmptyQueue(c *C) { c.Check(reqs, HasLen, 0) } +func (t *RedisTest) TestIncrPipeliningFromGoroutines(c *C) { + multi, err := t.client.PipelineClient() + c.Check(err, IsNil) + defer multi.Close() + + wg := &sync.WaitGroup{} + for i := int64(0); i < 20000; i++ { + wg.Add(1) + go func() { + multi.Incr("TestIncrPipeliningFromGoroutinesKey") + wg.Done() + }() + } + wg.Wait() + + reqs, err := multi.RunQueued() + c.Check(err, IsNil) + c.Check(reqs, HasLen, 20000) + for _, req := range reqs { + if req.Err() != nil { + c.Errorf("got %v, expected nil", req.Err()) + } + } + + get := t.client.Get("TestIncrPipeliningFromGoroutinesKey") + c.Check(get.Err(), IsNil) + c.Check(get.Val(), Equals, "20000") +} + +func (t *RedisTest) TestPipeliningFromGoroutines(c *C) { + multi, err := t.client.PipelineClient() + c.Check(err, IsNil) + defer multi.Close() + + for i := int64(0); i < 1000; i += 2 { + go func() { + msg1 := "echo" + strconv.FormatInt(i, 10) + msg2 := "echo" + strconv.FormatInt(i+1, 10) + + echo1Req := multi.Echo(msg1) + echo2Req := multi.Echo(msg2) + + reqs, err := multi.RunQueued() + c.Check(reqs, HasLen, 2) + c.Check(err, IsNil) + + c.Check(echo1Req.Err(), IsNil) + c.Check(echo1Req.Val(), Equals, msg1) + + c.Check(echo2Req.Err(), IsNil) + c.Check(echo2Req.Val(), Equals, msg2) + }() + } +} + //------------------------------------------------------------------------------ func (t *RedisTest) TestDiscard(c *C) { - multiC := t.client.Multi() + multi, err := t.client.MultiClient() + c.Check(err, IsNil) - multiC.Set("foo1", "bar1") - multiC.Discard() - multiC.Set("foo2", "bar2") + multi.Multi() - reqs, err := multiC.Exec() + multi.Set("foo1", "bar1") + multi.Discard() + multi.Set("foo2", "bar2") + + reqs, err := multi.Exec() c.Check(err, IsNil) c.Check(reqs, HasLen, 1) @@ -1792,12 +1855,15 @@ func (t *RedisTest) TestDiscard(c *C) { } func (t *RedisTest) TestMultiExec(c *C) { - multiC := t.client.Multi() + multi, err := t.client.MultiClient() + c.Check(err, IsNil) - setR := multiC.Set("foo", "bar") - getR := multiC.Get("foo") + multi.Multi() - reqs, err := multiC.Exec() + setR := multi.Set("foo", "bar") + getR := multi.Get("foo") + + reqs, err := multi.Exec() c.Check(err, IsNil) c.Check(reqs, HasLen, 2) @@ -1827,30 +1893,6 @@ func (t *RedisTest) TestEchoFromGoroutines(c *C) { } } -func (t *RedisTest) TestPipeliningFromGoroutines(c *C) { - multiClient := t.client.Multi() - - for i := int64(0); i < 1000; i += 2 { - go func() { - msg1 := "echo" + strconv.FormatInt(i, 10) - msg2 := "echo" + strconv.FormatInt(i+1, 10) - - echo1Req := multiClient.Echo(msg1) - echo2Req := multiClient.Echo(msg2) - - reqs, err := multiClient.RunQueued() - c.Check(reqs, HasLen, 2) - c.Check(err, IsNil) - - c.Check(echo1Req.Err(), IsNil) - c.Check(echo1Req.Val(), Equals, msg1) - - c.Check(echo2Req.Err(), IsNil) - c.Check(echo2Req.Val(), Equals, msg2) - }() - } -} - func (t *RedisTest) TestIncrFromGoroutines(c *C) { wg := &sync.WaitGroup{} for i := int64(0); i < 20000; i++ { @@ -1868,47 +1910,23 @@ func (t *RedisTest) TestIncrFromGoroutines(c *C) { c.Check(get.Val(), Equals, "20000") } -func (t *RedisTest) TestIncrPipeliningFromGoroutines(c *C) { - multiClient := t.client.Multi() - - wg := &sync.WaitGroup{} - for i := int64(0); i < 20000; i++ { - wg.Add(1) - go func() { - multiClient.Incr("TestIncrPipeliningFromGoroutinesKey") - wg.Done() - }() - } - wg.Wait() - - reqs, err := multiClient.RunQueued() - c.Check(err, IsNil) - c.Check(reqs, HasLen, 20000) - for _, req := range reqs { - if req.Err() != nil { - c.Errorf("got %v, expected nil", req.Err()) - } - } - - get := t.client.Get("TestIncrPipeliningFromGoroutinesKey") - c.Check(get.Err(), IsNil) - c.Check(get.Val(), Equals, "20000") -} - func (t *RedisTest) TestIncrTransaction(c *C) { - multiClient := t.client.Multi() + multi, err := t.client.MultiClient() + c.Check(err, IsNil) + + multi.Multi() wg := &sync.WaitGroup{} for i := int64(0); i < 20000; i++ { wg.Add(1) go func() { - multiClient.Incr("TestIncrTransactionKey") + multi.Incr("TestIncrTransactionKey") wg.Done() }() } wg.Wait() - reqs, err := multiClient.Exec() + reqs, err := multi.Exec() c.Check(err, IsNil) c.Check(reqs, HasLen, 20000) for _, req := range reqs { @@ -1922,6 +1940,55 @@ func (t *RedisTest) TestIncrTransaction(c *C) { c.Check(get.Val(), Equals, "20000") } +func (t *RedisTest) transactionalIncr(c *C, wg *sync.WaitGroup) { + multi, err := t.client.MultiClient() + c.Check(err, IsNil) + defer multi.Close() + + watch := multi.Watch("foo") + c.Check(watch.Err(), IsNil) + c.Check(watch.Val(), Equals, "OK") + + get := multi.Get("foo") + c.Check(get.Err(), IsNil) + c.Check(get.Val(), Not(Equals), redis.Nil) + + v, err := strconv.ParseInt(get.Val(), 10, 64) + c.Check(err, IsNil) + + multi.Multi() + set := multi.Set("foo", strconv.FormatInt(v+1, 10)) + reqs, err := multi.Exec() + if err == redis.Nil { + t.transactionalIncr(c, wg) + return + } + c.Check(reqs, HasLen, 1) + c.Check(err, IsNil) + + c.Check(set.Err(), IsNil) + c.Check(set.Val(), Equals, "OK") + + wg.Done() +} + +func (t *RedisTest) TestWatchUnwatch(c *C) { + set := t.client.Set("foo", "0") + c.Check(set.Err(), IsNil) + c.Check(set.Val(), Equals, "OK") + + wg := &sync.WaitGroup{} + for i := 0; i < 1000; i++ { + wg.Add(1) + go t.transactionalIncr(c, wg) + } + wg.Wait() + + get := t.client.Get("foo") + c.Check(get.Err(), IsNil) + c.Check(get.Val(), Equals, "1000") +} + //------------------------------------------------------------------------------ func (t *RedisTest) BenchmarkRedisPing(c *C) { diff --git a/request.go b/request.go index 6130317..4e78e74 100644 --- a/request.go +++ b/request.go @@ -20,8 +20,12 @@ func isEmpty(line []byte) bool { return len(line) == 2 && line[0] == '$' && line[1] == '0' } +func isNilReplies(line []byte) bool { + return len(line) == 3 && line[0] == '*' && line[1] == '-' && line[2] == '1' +} + func isNoReplies(line []byte) bool { - return len(line) >= 2 && line[1] == '*' && line[1] == '0' + return len(line) == 2 && line[1] == '*' && line[1] == '0' } //------------------------------------------------------------------------------ @@ -401,7 +405,7 @@ func (r *MultiBulkReq) ParseReply(rd ReadLiner) (interface{}, error) { return nil, errors.New(string(line[1:])) } else if line[0] != '*' { return nil, fmt.Errorf("Expected '*', but got line %q", line) - } else if isNil(line) { + } else if isNilReplies(line) { return nil, Nil } @@ -420,14 +424,15 @@ func (r *MultiBulkReq) ParseReply(rd ReadLiner) (interface{}, error) { return nil, err } - if line[0] == ':' { + switch line[0] { + case ':': var n int64 n, err = strconv.ParseInt(string(line[1:]), 10, 64) if err != nil { return nil, err } val = append(val, n) - } else if line[0] == '$' { + case '$': if isEmpty(line) { val = append(val, "") } else if isNil(line) { @@ -439,7 +444,7 @@ func (r *MultiBulkReq) ParseReply(rd ReadLiner) (interface{}, error) { } val = append(val, string(line)) } - } else { + default: return nil, fmt.Errorf("Expected '$', but got line %q", line) } }