Small API changes.

This commit is contained in:
Vladimir Mihailenco 2013-09-29 10:17:39 +03:00
parent e90db6f494
commit 7f11168689
5 changed files with 97 additions and 125 deletions

View File

@ -55,32 +55,32 @@ func ExamplePipeline() {
var set *redis.StatusReq var set *redis.StatusReq
var get *redis.StringReq var get *redis.StringReq
reqs, err := client.Pipelined(func(c *redis.PipelineClient) { reqs, err := client.Pipelined(func(c *redis.Pipeline) {
set = c.Set("key1", "hello1") set = c.Set("key1", "hello1")
get = c.Get("key2") get = c.Get("key2")
}) })
fmt.Println(err, reqs) fmt.Println(err, reqs)
fmt.Println(set) fmt.Println(set)
fmt.Println(get) fmt.Println(get)
// Output: <nil> [SET key1 hello1: OK GET key2: (nil)] // Output: (nil) [SET key1 hello1: OK GET key2: (nil)]
// SET key1 hello1: OK // SET key1 hello1: OK
// GET key2: (nil) // GET key2: (nil)
} }
func transaction(multi *redis.MultiClient) ([]redis.Req, error) { func incr(tx *redis.Multi) ([]redis.Req, error) {
get := multi.Get("key") get := tx.Get("key")
if err := get.Err(); err != nil && err != redis.Nil { if err := get.Err(); err != nil && err != redis.Nil {
return nil, err return nil, err
} }
val, _ := strconv.ParseInt(get.Val(), 10, 64) val, _ := strconv.ParseInt(get.Val(), 10, 64)
reqs, err := multi.Exec(func() { reqs, err := tx.Exec(func() {
multi.Set("key", strconv.FormatInt(val+1, 10)) tx.Set("key", strconv.FormatInt(val+1, 10))
}) })
// Transaction failed. Repeat. // Transaction failed. Repeat.
if err == redis.Nil { if err == redis.Nil {
return transaction(multi) return incr(tx)
} }
return reqs, err return reqs, err
} }
@ -93,14 +93,13 @@ func ExampleTransaction() {
client.Del("key") client.Del("key")
multi, err := client.MultiClient() tx := client.Multi()
_ = err defer tx.Close()
defer multi.Close()
watch := multi.Watch("key") watch := tx.Watch("key")
_ = watch.Err() _ = watch.Err()
reqs, err := transaction(multi) reqs, err := incr(tx)
fmt.Println(err, reqs) fmt.Println(err, reqs)
// Output: <nil> [SET key 1: OK] // Output: <nil> [SET key 1: OK]

View File

@ -1,67 +1,62 @@
package redis package redis
import ( import (
"errors"
"fmt" "fmt"
"sync"
) )
type MultiClient struct { var errDiscard = errors.New("redis: Discard can be used only inside Exec")
// Not thread-safe.
type Multi struct {
*Client *Client
execMtx sync.Mutex
} }
func (c *Client) MultiClient() (*MultiClient, error) { func (c *Client) Multi() *Multi {
return &MultiClient{ return &Multi{
Client: &Client{ Client: &Client{
baseClient: &baseClient{ baseClient: &baseClient{
opt: c.opt, opt: c.opt,
connPool: newSingleConnPool(c.connPool, nil, true), connPool: newSingleConnPool(c.connPool, nil, true),
}, },
}, },
}, nil }
} }
func (c *MultiClient) Close() error { func (c *Multi) Close() error {
c.Unwatch() c.Unwatch()
return c.Client.Close() return c.Client.Close()
} }
func (c *MultiClient) Watch(keys ...string) *StatusReq { func (c *Multi) Watch(keys ...string) *StatusReq {
args := append([]string{"WATCH"}, keys...) args := append([]string{"WATCH"}, keys...)
req := NewStatusReq(args...) req := NewStatusReq(args...)
c.Process(req) c.Process(req)
return req return req
} }
func (c *MultiClient) Unwatch(keys ...string) *StatusReq { func (c *Multi) Unwatch(keys ...string) *StatusReq {
args := append([]string{"UNWATCH"}, keys...) args := append([]string{"UNWATCH"}, keys...)
req := NewStatusReq(args...) req := NewStatusReq(args...)
c.Process(req) c.Process(req)
return req return req
} }
func (c *MultiClient) Discard() { func (c *Multi) Discard() error {
c.reqsMtx.Lock()
if c.reqs == nil { if c.reqs == nil {
panic("Discard can be used only inside Exec") return errDiscard
} }
c.reqs = c.reqs[:1] c.reqs = c.reqs[:1]
c.reqsMtx.Unlock() return nil
} }
func (c *MultiClient) Exec(do func()) ([]Req, error) { func (c *Multi) Exec(f func()) ([]Req, error) {
c.reqsMtx.Lock()
c.reqs = []Req{NewStatusReq("MULTI")} c.reqs = []Req{NewStatusReq("MULTI")}
c.reqsMtx.Unlock() f()
c.reqs = append(c.reqs, NewIfaceSliceReq("EXEC"))
do()
c.queue(NewIfaceSliceReq("EXEC"))
c.reqsMtx.Lock()
reqs := c.reqs reqs := c.reqs
c.reqs = nil c.reqs = nil
c.reqsMtx.Unlock()
if len(reqs) == 2 { if len(reqs) == 2 {
return []Req{}, nil return []Req{}, nil
@ -73,9 +68,7 @@ func (c *MultiClient) Exec(do func()) ([]Req, error) {
} }
// Synchronize writes and reads to the connection using mutex. // Synchronize writes and reads to the connection using mutex.
c.execMtx.Lock()
err = c.execReqs(reqs, cn) err = c.execReqs(reqs, cn)
c.execMtx.Unlock()
if err != nil { if err != nil {
c.removeConn(cn) c.removeConn(cn)
return nil, err return nil, err
@ -85,7 +78,7 @@ func (c *MultiClient) Exec(do func()) ([]Req, error) {
return reqs[1 : len(reqs)-1], nil return reqs[1 : len(reqs)-1], nil
} }
func (c *MultiClient) execReqs(reqs []Req, cn *conn) error { func (c *Multi) execReqs(reqs []Req, cn *conn) error {
err := c.writeReq(cn, reqs...) err := c.writeReq(cn, reqs...)
if err != nil { if err != nil {
return err return err
@ -110,7 +103,7 @@ func (c *MultiClient) execReqs(reqs []Req, cn *conn) error {
return err return err
} }
if line[0] != '*' { if line[0] != '*' {
return fmt.Errorf("Expected '*', but got line %q", line) return fmt.Errorf("redis: expected '*', but got line %q", line)
} }
if len(line) == 3 && line[1] == '-' && line[2] == '1' { if len(line) == 3 && line[1] == '-' && line[2] == '1' {
return Nil return Nil

View File

@ -1,13 +1,12 @@
package redis package redis
type PipelineClient struct { // Not thread-safe.
type Pipeline struct {
*Client *Client
} }
// TODO: rename to Pipeline func (c *Client) Pipeline() *Pipeline {
// TODO: return just *PipelineClient return &Pipeline{
func (c *Client) PipelineClient() (*PipelineClient, error) {
return &PipelineClient{
Client: &Client{ Client: &Client{
baseClient: &baseClient{ baseClient: &baseClient{
opt: c.opt, opt: c.opt,
@ -16,38 +15,31 @@ func (c *Client) PipelineClient() (*PipelineClient, error) {
reqs: make([]Req, 0), reqs: make([]Req, 0),
}, },
}, },
}, nil
}
func (c *Client) Pipelined(do func(*PipelineClient)) ([]Req, error) {
pc, err := c.PipelineClient()
if err != nil {
return nil, err
} }
defer pc.Close()
do(pc)
return pc.RunQueued()
} }
func (c *PipelineClient) Close() error { func (c *Client) Pipelined(f func(*Pipeline)) ([]Req, error) {
pc := c.Pipeline()
f(pc)
reqs, err := pc.Exec()
pc.Close()
return reqs, err
}
func (c *Pipeline) Close() error {
return nil return nil
} }
func (c *PipelineClient) DiscardQueued() { func (c *Pipeline) Discard() error {
c.reqsMtx.Lock()
c.reqs = c.reqs[:0] c.reqs = c.reqs[:0]
c.reqsMtx.Unlock() return nil
} }
// TODO: rename to Run or ... // Always returns list of commands and error of the first failed
// TODO: should return error if one of the commands failed // command if any.
func (c *PipelineClient) RunQueued() ([]Req, error) { func (c *Pipeline) Exec() ([]Req, error) {
c.reqsMtx.Lock()
reqs := c.reqs reqs := c.reqs
c.reqs = make([]Req, 0) c.reqs = make([]Req, 0)
c.reqsMtx.Unlock()
if len(reqs) == 0 { if len(reqs) == 0 {
return []Req{}, nil return []Req{}, nil
@ -55,34 +47,39 @@ func (c *PipelineClient) RunQueued() ([]Req, error) {
cn, err := c.conn() cn, err := c.conn()
if err != nil { if err != nil {
return nil, err return reqs, err
} }
if err := c.runReqs(reqs, cn); err != nil { if err := c.execReqs(reqs, cn); err != nil {
c.removeConn(cn) c.freeConn(cn, err)
return nil, err return reqs, err
} }
c.putConn(cn) c.putConn(cn)
return reqs, nil return reqs, nil
} }
func (c *PipelineClient) runReqs(reqs []Req, cn *conn) error { func (c *Pipeline) execReqs(reqs []Req, cn *conn) error {
err := c.writeReq(cn, reqs...) err := c.writeReq(cn, reqs...)
if err != nil { if err != nil {
for _, req := range reqs {
req.SetErr(err)
}
return err return err
} }
reqsLen := len(reqs) var firstReqErr error
for i := 0; i < reqsLen; i++ { for _, req := range reqs {
req := reqs[i]
val, err := req.ParseReply(cn.Rd) val, err := req.ParseReply(cn.Rd)
if err != nil { if err != nil {
req.SetErr(err) req.SetErr(err)
if err != nil {
firstReqErr = err
}
} else { } else {
req.SetVal(val) req.SetVal(val)
} }
} }
return nil return firstReqErr
} }

View File

@ -5,7 +5,6 @@ import (
"log" "log"
"net" "net"
"os" "os"
"sync"
"time" "time"
) )
@ -20,7 +19,6 @@ type baseClient struct {
opt *Options opt *Options
reqs []Req reqs []Req
reqsMtx sync.Mutex
} }
func (c *baseClient) writeReq(cn *conn, reqs ...Req) error { func (c *baseClient) writeReq(cn *conn, reqs ...Req) error {
@ -75,6 +73,14 @@ func (c *baseClient) init(cn *conn, password string, db int64) error {
return nil return nil
} }
func (c *baseClient) freeConn(cn *conn, err error) {
if err == Nil {
c.putConn(cn)
} else {
c.removeConn(cn)
}
}
func (c *baseClient) removeConn(cn *conn) { func (c *baseClient) removeConn(cn *conn) {
if err := c.connPool.Remove(cn); err != nil { if err := c.connPool.Remove(cn); err != nil {
Logger.Printf("connPool.Remove error: %v", err) Logger.Printf("connPool.Remove error: %v", err)
@ -91,7 +97,7 @@ func (c *baseClient) Process(req Req) {
if c.reqs == nil { if c.reqs == nil {
c.run(req) c.run(req)
} else { } else {
c.queue(req) c.reqs = append(c.reqs, req)
} }
} }
@ -120,11 +126,7 @@ func (c *baseClient) run(req Req) {
val, err := req.ParseReply(cn.Rd) val, err := req.ParseReply(cn.Rd)
if err != nil { if err != nil {
if err == Nil { c.freeConn(cn, err)
c.putConn(cn)
} else {
c.removeConn(cn)
}
req.SetErr(err) req.SetErr(err)
return return
} }
@ -133,13 +135,6 @@ func (c *baseClient) run(req Req) {
req.SetVal(val) req.SetVal(val)
} }
// Queues request to be executed later.
func (c *baseClient) queue(req Req) {
c.reqsMtx.Lock()
c.reqs = append(c.reqs, req)
c.reqsMtx.Unlock()
}
func (c *baseClient) Close() error { func (c *baseClient) Close() error {
return c.connPool.Close() return c.connPool.Close()
} }

View File

@ -2393,8 +2393,7 @@ func (t *RedisTest) TestPipeline(c *C) {
c.Assert(set.Err(), IsNil) c.Assert(set.Err(), IsNil)
c.Assert(set.Val(), Equals, "OK") c.Assert(set.Val(), Equals, "OK")
pipeline, err := t.client.PipelineClient() pipeline := t.client.Pipeline()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(pipeline.Close(), IsNil) c.Assert(pipeline.Close(), IsNil)
}() }()
@ -2404,8 +2403,8 @@ func (t *RedisTest) TestPipeline(c *C) {
incr := pipeline.Incr("key3") incr := pipeline.Incr("key3")
getNil := pipeline.Get("key4") getNil := pipeline.Get("key4")
reqs, err := pipeline.RunQueued() reqs, err := pipeline.Exec()
c.Assert(err, IsNil) c.Assert(err, Equals, redis.Nil)
c.Assert(reqs, HasLen, 4) c.Assert(reqs, HasLen, 4)
c.Assert(set.Err(), IsNil) c.Assert(set.Err(), IsNil)
@ -2422,33 +2421,31 @@ func (t *RedisTest) TestPipeline(c *C) {
} }
func (t *RedisTest) TestPipelineDiscardQueued(c *C) { func (t *RedisTest) TestPipelineDiscardQueued(c *C) {
pipeline, err := t.client.PipelineClient() pipeline := t.client.Pipeline()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(pipeline.Close(), IsNil) c.Assert(pipeline.Close(), IsNil)
}() }()
pipeline.Get("key") pipeline.Get("key")
pipeline.DiscardQueued() pipeline.Discard()
reqs, err := pipeline.RunQueued() reqs, err := pipeline.Exec()
c.Assert(err, IsNil) c.Assert(err, IsNil)
c.Assert(reqs, HasLen, 0) c.Assert(reqs, HasLen, 0)
} }
func (t *RedisTest) TestPipelineFunc(c *C) { func (t *RedisTest) TestPipelineFunc(c *C) {
var get *redis.StringReq var get *redis.StringReq
reqs, err := t.client.Pipelined(func(c *redis.PipelineClient) { reqs, err := t.client.Pipelined(func(c *redis.Pipeline) {
get = c.Get("foo") get = c.Get("foo")
}) })
c.Assert(err, IsNil) c.Assert(err, Equals, redis.Nil)
c.Assert(reqs, HasLen, 1) c.Assert(reqs, HasLen, 1)
c.Assert(get.Err(), Equals, redis.Nil) c.Assert(get.Err(), Equals, redis.Nil)
c.Assert(get.Val(), Equals, "") c.Assert(get.Val(), Equals, "")
} }
func (t *RedisTest) TestPipelineErrValNotSet(c *C) { func (t *RedisTest) TestPipelineErrValNotSet(c *C) {
pipeline, err := t.client.PipelineClient() pipeline := t.client.Pipeline()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(pipeline.Close(), IsNil) c.Assert(pipeline.Close(), IsNil)
}() }()
@ -2458,20 +2455,18 @@ func (t *RedisTest) TestPipelineErrValNotSet(c *C) {
} }
func (t *RedisTest) TestPipelineRunQueuedOnEmptyQueue(c *C) { func (t *RedisTest) TestPipelineRunQueuedOnEmptyQueue(c *C) {
pipeline, err := t.client.PipelineClient() pipeline := t.client.Pipeline()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(pipeline.Close(), IsNil) c.Assert(pipeline.Close(), IsNil)
}() }()
reqs, err := pipeline.RunQueued() reqs, err := pipeline.Exec()
c.Assert(err, IsNil) c.Assert(err, IsNil)
c.Assert(reqs, HasLen, 0) c.Assert(reqs, HasLen, 0)
} }
func (t *RedisTest) TestPipelineIncrFromGoroutines(c *C) { func (t *RedisTest) TestPipelineIncrFromGoroutines(c *C) {
pipeline, err := t.client.PipelineClient() pipeline := t.client.Pipeline()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(pipeline.Close(), IsNil) c.Assert(pipeline.Close(), IsNil)
}() }()
@ -2486,7 +2481,7 @@ func (t *RedisTest) TestPipelineIncrFromGoroutines(c *C) {
} }
wg.Wait() wg.Wait()
reqs, err := pipeline.RunQueued() reqs, err := pipeline.Exec()
c.Assert(err, IsNil) c.Assert(err, IsNil)
c.Assert(reqs, HasLen, 20000) c.Assert(reqs, HasLen, 20000)
for _, req := range reqs { for _, req := range reqs {
@ -2501,8 +2496,7 @@ func (t *RedisTest) TestPipelineIncrFromGoroutines(c *C) {
} }
func (t *RedisTest) TestPipelineEchoFromGoroutines(c *C) { func (t *RedisTest) TestPipelineEchoFromGoroutines(c *C) {
pipeline, err := t.client.PipelineClient() pipeline := t.client.Pipeline()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(pipeline.Close(), IsNil) c.Assert(pipeline.Close(), IsNil)
}() }()
@ -2517,7 +2511,7 @@ func (t *RedisTest) TestPipelineEchoFromGoroutines(c *C) {
echo1 := pipeline.Echo(msg1) echo1 := pipeline.Echo(msg1)
echo2 := pipeline.Echo(msg2) echo2 := pipeline.Echo(msg2)
reqs, err := pipeline.RunQueued() reqs, err := pipeline.Exec()
c.Assert(err, IsNil) c.Assert(err, IsNil)
c.Assert(reqs, HasLen, 2) c.Assert(reqs, HasLen, 2)
@ -2536,8 +2530,7 @@ func (t *RedisTest) TestPipelineEchoFromGoroutines(c *C) {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func (t *RedisTest) TestMultiExec(c *C) { func (t *RedisTest) TestMultiExec(c *C) {
multi, err := t.client.MultiClient() multi := t.client.Multi()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(multi.Close(), IsNil) c.Assert(multi.Close(), IsNil)
}() }()
@ -2561,8 +2554,7 @@ func (t *RedisTest) TestMultiExec(c *C) {
} }
func (t *RedisTest) TestMultiExecDiscard(c *C) { func (t *RedisTest) TestMultiExecDiscard(c *C) {
multi, err := t.client.MultiClient() multi := t.client.Multi()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(multi.Close(), IsNil) c.Assert(multi.Close(), IsNil)
}() }()
@ -2585,8 +2577,7 @@ func (t *RedisTest) TestMultiExecDiscard(c *C) {
} }
func (t *RedisTest) TestMultiExecEmpty(c *C) { func (t *RedisTest) TestMultiExecEmpty(c *C) {
multi, err := t.client.MultiClient() multi := t.client.Multi()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(multi.Close(), IsNil) c.Assert(multi.Close(), IsNil)
}() }()
@ -2601,8 +2592,7 @@ func (t *RedisTest) TestMultiExecEmpty(c *C) {
} }
func (t *RedisTest) TestMultiExecOnEmptyQueue(c *C) { func (t *RedisTest) TestMultiExecOnEmptyQueue(c *C) {
multi, err := t.client.MultiClient() multi := t.client.Multi()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(multi.Close(), IsNil) c.Assert(multi.Close(), IsNil)
}() }()
@ -2612,16 +2602,15 @@ func (t *RedisTest) TestMultiExecOnEmptyQueue(c *C) {
c.Assert(reqs, HasLen, 0) c.Assert(reqs, HasLen, 0)
} }
func (t *RedisTest) TestMultiExecIncrTransaction(c *C) { func (t *RedisTest) TestMultiExecIncr(c *C) {
multi, err := t.client.MultiClient() multi := t.client.Multi()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(multi.Close(), IsNil) c.Assert(multi.Close(), IsNil)
}() }()
reqs, err := multi.Exec(func() { reqs, err := multi.Exec(func() {
for i := int64(0); i < 20000; i++ { for i := int64(0); i < 20000; i++ {
multi.Incr("TestIncrTransactionKey") multi.Incr("key")
} }
}) })
c.Assert(err, IsNil) c.Assert(err, IsNil)
@ -2632,14 +2621,13 @@ func (t *RedisTest) TestMultiExecIncrTransaction(c *C) {
} }
} }
get := t.client.Get("TestIncrTransactionKey") get := t.client.Get("key")
c.Assert(get.Err(), IsNil) c.Assert(get.Err(), IsNil)
c.Assert(get.Val(), Equals, "20000") c.Assert(get.Val(), Equals, "20000")
} }
func (t *RedisTest) transactionalIncr(c *C) ([]redis.Req, error) { func (t *RedisTest) transactionalIncr(c *C) ([]redis.Req, error) {
multi, err := t.client.MultiClient() multi := t.client.Multi()
c.Assert(err, IsNil)
defer func() { defer func() {
c.Assert(multi.Close(), IsNil) c.Assert(multi.Close(), IsNil)
}() }()