diff --git a/v2/example_test.go b/v2/example_test.go index 190a6e5..e2a3658 100644 --- a/v2/example_test.go +++ b/v2/example_test.go @@ -55,32 +55,32 @@ func ExamplePipeline() { var set *redis.StatusReq var get *redis.StringReq - reqs, err := client.Pipelined(func(c *redis.PipelineClient) { + reqs, err := client.Pipelined(func(c *redis.Pipeline) { set = c.Set("key1", "hello1") get = c.Get("key2") }) fmt.Println(err, reqs) fmt.Println(set) fmt.Println(get) - // Output: [SET key1 hello1: OK GET key2: (nil)] + // Output: (nil) [SET key1 hello1: OK GET key2: (nil)] // SET key1 hello1: OK // GET key2: (nil) } -func transaction(multi *redis.MultiClient) ([]redis.Req, error) { - get := multi.Get("key") +func incr(tx *redis.Multi) ([]redis.Req, error) { + get := tx.Get("key") if err := get.Err(); err != nil && err != redis.Nil { return nil, err } val, _ := strconv.ParseInt(get.Val(), 10, 64) - reqs, err := multi.Exec(func() { - multi.Set("key", strconv.FormatInt(val+1, 10)) + reqs, err := tx.Exec(func() { + tx.Set("key", strconv.FormatInt(val+1, 10)) }) // Transaction failed. Repeat. if err == redis.Nil { - return transaction(multi) + return incr(tx) } return reqs, err } @@ -93,14 +93,13 @@ func ExampleTransaction() { client.Del("key") - multi, err := client.MultiClient() - _ = err - defer multi.Close() + tx := client.Multi() + defer tx.Close() - watch := multi.Watch("key") + watch := tx.Watch("key") _ = watch.Err() - reqs, err := transaction(multi) + reqs, err := incr(tx) fmt.Println(err, reqs) // Output: [SET key 1: OK] diff --git a/v2/multi.go b/v2/multi.go index c2bf9b8..737a3a9 100644 --- a/v2/multi.go +++ b/v2/multi.go @@ -1,67 +1,62 @@ package redis import ( + "errors" "fmt" - "sync" ) -type MultiClient struct { +var errDiscard = errors.New("redis: Discard can be used only inside Exec") + +// Not thread-safe. +type Multi struct { *Client - execMtx sync.Mutex } -func (c *Client) MultiClient() (*MultiClient, error) { - return &MultiClient{ +func (c *Client) Multi() *Multi { + return &Multi{ Client: &Client{ baseClient: &baseClient{ opt: c.opt, connPool: newSingleConnPool(c.connPool, nil, true), }, }, - }, nil + } } -func (c *MultiClient) Close() error { +func (c *Multi) Close() error { c.Unwatch() return c.Client.Close() } -func (c *MultiClient) Watch(keys ...string) *StatusReq { +func (c *Multi) Watch(keys ...string) *StatusReq { args := append([]string{"WATCH"}, keys...) req := NewStatusReq(args...) c.Process(req) return req } -func (c *MultiClient) Unwatch(keys ...string) *StatusReq { +func (c *Multi) Unwatch(keys ...string) *StatusReq { args := append([]string{"UNWATCH"}, keys...) req := NewStatusReq(args...) c.Process(req) return req } -func (c *MultiClient) Discard() { - c.reqsMtx.Lock() +func (c *Multi) Discard() error { if c.reqs == nil { - panic("Discard can be used only inside Exec") + return errDiscard } c.reqs = c.reqs[:1] - c.reqsMtx.Unlock() + return nil } -func (c *MultiClient) Exec(do func()) ([]Req, error) { - c.reqsMtx.Lock() +func (c *Multi) Exec(f func()) ([]Req, error) { c.reqs = []Req{NewStatusReq("MULTI")} - c.reqsMtx.Unlock() + f() + c.reqs = append(c.reqs, NewIfaceSliceReq("EXEC")) - do() - - c.queue(NewIfaceSliceReq("EXEC")) - - c.reqsMtx.Lock() reqs := c.reqs c.reqs = nil - c.reqsMtx.Unlock() if len(reqs) == 2 { return []Req{}, nil @@ -73,9 +68,7 @@ func (c *MultiClient) Exec(do func()) ([]Req, error) { } // Synchronize writes and reads to the connection using mutex. - c.execMtx.Lock() err = c.execReqs(reqs, cn) - c.execMtx.Unlock() if err != nil { c.removeConn(cn) return nil, err @@ -85,7 +78,7 @@ func (c *MultiClient) Exec(do func()) ([]Req, error) { return reqs[1 : len(reqs)-1], nil } -func (c *MultiClient) execReqs(reqs []Req, cn *conn) error { +func (c *Multi) execReqs(reqs []Req, cn *conn) error { err := c.writeReq(cn, reqs...) if err != nil { return err @@ -110,7 +103,7 @@ func (c *MultiClient) execReqs(reqs []Req, cn *conn) error { return err } if line[0] != '*' { - return fmt.Errorf("Expected '*', but got line %q", line) + return fmt.Errorf("redis: expected '*', but got line %q", line) } if len(line) == 3 && line[1] == '-' && line[2] == '1' { return Nil diff --git a/v2/pipeline.go b/v2/pipeline.go index 6a9bafe..6acec64 100644 --- a/v2/pipeline.go +++ b/v2/pipeline.go @@ -1,13 +1,12 @@ package redis -type PipelineClient struct { +// Not thread-safe. +type Pipeline struct { *Client } -// TODO: rename to Pipeline -// TODO: return just *PipelineClient -func (c *Client) PipelineClient() (*PipelineClient, error) { - return &PipelineClient{ +func (c *Client) Pipeline() *Pipeline { + return &Pipeline{ Client: &Client{ baseClient: &baseClient{ opt: c.opt, @@ -16,38 +15,31 @@ func (c *Client) PipelineClient() (*PipelineClient, error) { reqs: make([]Req, 0), }, }, - }, nil -} - -func (c *Client) Pipelined(do func(*PipelineClient)) ([]Req, error) { - pc, err := c.PipelineClient() - if err != nil { - return nil, err } - defer pc.Close() - - do(pc) - - return pc.RunQueued() } -func (c *PipelineClient) Close() error { +func (c *Client) Pipelined(f func(*Pipeline)) ([]Req, error) { + pc := c.Pipeline() + f(pc) + reqs, err := pc.Exec() + pc.Close() + return reqs, err +} + +func (c *Pipeline) Close() error { return nil } -func (c *PipelineClient) DiscardQueued() { - c.reqsMtx.Lock() +func (c *Pipeline) Discard() error { c.reqs = c.reqs[:0] - c.reqsMtx.Unlock() + return nil } -// TODO: rename to Run or ... -// TODO: should return error if one of the commands failed -func (c *PipelineClient) RunQueued() ([]Req, error) { - c.reqsMtx.Lock() +// Always returns list of commands and error of the first failed +// command if any. +func (c *Pipeline) Exec() ([]Req, error) { reqs := c.reqs c.reqs = make([]Req, 0) - c.reqsMtx.Unlock() if len(reqs) == 0 { return []Req{}, nil @@ -55,34 +47,39 @@ func (c *PipelineClient) RunQueued() ([]Req, error) { cn, err := c.conn() if err != nil { - return nil, err + return reqs, err } - if err := c.runReqs(reqs, cn); err != nil { - c.removeConn(cn) - return nil, err + if err := c.execReqs(reqs, cn); err != nil { + c.freeConn(cn, err) + return reqs, err } c.putConn(cn) return reqs, nil } -func (c *PipelineClient) runReqs(reqs []Req, cn *conn) error { +func (c *Pipeline) execReqs(reqs []Req, cn *conn) error { err := c.writeReq(cn, reqs...) if err != nil { + for _, req := range reqs { + req.SetErr(err) + } return err } - reqsLen := len(reqs) - for i := 0; i < reqsLen; i++ { - req := reqs[i] + var firstReqErr error + for _, req := range reqs { val, err := req.ParseReply(cn.Rd) if err != nil { req.SetErr(err) + if err != nil { + firstReqErr = err + } } else { req.SetVal(val) } } - return nil + return firstReqErr } diff --git a/v2/redis.go b/v2/redis.go index 6230f1f..f259b1c 100644 --- a/v2/redis.go +++ b/v2/redis.go @@ -5,7 +5,6 @@ import ( "log" "net" "os" - "sync" "time" ) @@ -19,8 +18,7 @@ type baseClient struct { opt *Options - reqs []Req - reqsMtx sync.Mutex + reqs []Req } func (c *baseClient) writeReq(cn *conn, reqs ...Req) error { @@ -75,6 +73,14 @@ func (c *baseClient) init(cn *conn, password string, db int64) error { return nil } +func (c *baseClient) freeConn(cn *conn, err error) { + if err == Nil { + c.putConn(cn) + } else { + c.removeConn(cn) + } +} + func (c *baseClient) removeConn(cn *conn) { if err := c.connPool.Remove(cn); err != nil { Logger.Printf("connPool.Remove error: %v", err) @@ -91,7 +97,7 @@ func (c *baseClient) Process(req Req) { if c.reqs == nil { c.run(req) } else { - c.queue(req) + c.reqs = append(c.reqs, req) } } @@ -120,11 +126,7 @@ func (c *baseClient) run(req Req) { val, err := req.ParseReply(cn.Rd) if err != nil { - if err == Nil { - c.putConn(cn) - } else { - c.removeConn(cn) - } + c.freeConn(cn, err) req.SetErr(err) return } @@ -133,13 +135,6 @@ func (c *baseClient) run(req Req) { req.SetVal(val) } -// Queues request to be executed later. -func (c *baseClient) queue(req Req) { - c.reqsMtx.Lock() - c.reqs = append(c.reqs, req) - c.reqsMtx.Unlock() -} - func (c *baseClient) Close() error { return c.connPool.Close() } diff --git a/v2/redis_test.go b/v2/redis_test.go index 88a2f78..d70e70d 100644 --- a/v2/redis_test.go +++ b/v2/redis_test.go @@ -2393,8 +2393,7 @@ func (t *RedisTest) TestPipeline(c *C) { c.Assert(set.Err(), IsNil) c.Assert(set.Val(), Equals, "OK") - pipeline, err := t.client.PipelineClient() - c.Assert(err, IsNil) + pipeline := t.client.Pipeline() defer func() { c.Assert(pipeline.Close(), IsNil) }() @@ -2404,8 +2403,8 @@ func (t *RedisTest) TestPipeline(c *C) { incr := pipeline.Incr("key3") getNil := pipeline.Get("key4") - reqs, err := pipeline.RunQueued() - c.Assert(err, IsNil) + reqs, err := pipeline.Exec() + c.Assert(err, Equals, redis.Nil) c.Assert(reqs, HasLen, 4) c.Assert(set.Err(), IsNil) @@ -2422,33 +2421,31 @@ func (t *RedisTest) TestPipeline(c *C) { } func (t *RedisTest) TestPipelineDiscardQueued(c *C) { - pipeline, err := t.client.PipelineClient() - c.Assert(err, IsNil) + pipeline := t.client.Pipeline() defer func() { c.Assert(pipeline.Close(), IsNil) }() pipeline.Get("key") - pipeline.DiscardQueued() - reqs, err := pipeline.RunQueued() + pipeline.Discard() + reqs, err := pipeline.Exec() c.Assert(err, IsNil) c.Assert(reqs, HasLen, 0) } func (t *RedisTest) TestPipelineFunc(c *C) { var get *redis.StringReq - reqs, err := t.client.Pipelined(func(c *redis.PipelineClient) { + reqs, err := t.client.Pipelined(func(c *redis.Pipeline) { get = c.Get("foo") }) - c.Assert(err, IsNil) + c.Assert(err, Equals, redis.Nil) c.Assert(reqs, HasLen, 1) c.Assert(get.Err(), Equals, redis.Nil) c.Assert(get.Val(), Equals, "") } func (t *RedisTest) TestPipelineErrValNotSet(c *C) { - pipeline, err := t.client.PipelineClient() - c.Assert(err, IsNil) + pipeline := t.client.Pipeline() defer func() { c.Assert(pipeline.Close(), IsNil) }() @@ -2458,20 +2455,18 @@ func (t *RedisTest) TestPipelineErrValNotSet(c *C) { } func (t *RedisTest) TestPipelineRunQueuedOnEmptyQueue(c *C) { - pipeline, err := t.client.PipelineClient() - c.Assert(err, IsNil) + pipeline := t.client.Pipeline() defer func() { c.Assert(pipeline.Close(), IsNil) }() - reqs, err := pipeline.RunQueued() + reqs, err := pipeline.Exec() c.Assert(err, IsNil) c.Assert(reqs, HasLen, 0) } func (t *RedisTest) TestPipelineIncrFromGoroutines(c *C) { - pipeline, err := t.client.PipelineClient() - c.Assert(err, IsNil) + pipeline := t.client.Pipeline() defer func() { c.Assert(pipeline.Close(), IsNil) }() @@ -2486,7 +2481,7 @@ func (t *RedisTest) TestPipelineIncrFromGoroutines(c *C) { } wg.Wait() - reqs, err := pipeline.RunQueued() + reqs, err := pipeline.Exec() c.Assert(err, IsNil) c.Assert(reqs, HasLen, 20000) for _, req := range reqs { @@ -2501,8 +2496,7 @@ func (t *RedisTest) TestPipelineIncrFromGoroutines(c *C) { } func (t *RedisTest) TestPipelineEchoFromGoroutines(c *C) { - pipeline, err := t.client.PipelineClient() - c.Assert(err, IsNil) + pipeline := t.client.Pipeline() defer func() { c.Assert(pipeline.Close(), IsNil) }() @@ -2517,7 +2511,7 @@ func (t *RedisTest) TestPipelineEchoFromGoroutines(c *C) { echo1 := pipeline.Echo(msg1) echo2 := pipeline.Echo(msg2) - reqs, err := pipeline.RunQueued() + reqs, err := pipeline.Exec() c.Assert(err, IsNil) c.Assert(reqs, HasLen, 2) @@ -2536,8 +2530,7 @@ func (t *RedisTest) TestPipelineEchoFromGoroutines(c *C) { //------------------------------------------------------------------------------ func (t *RedisTest) TestMultiExec(c *C) { - multi, err := t.client.MultiClient() - c.Assert(err, IsNil) + multi := t.client.Multi() defer func() { c.Assert(multi.Close(), IsNil) }() @@ -2561,8 +2554,7 @@ func (t *RedisTest) TestMultiExec(c *C) { } func (t *RedisTest) TestMultiExecDiscard(c *C) { - multi, err := t.client.MultiClient() - c.Assert(err, IsNil) + multi := t.client.Multi() defer func() { c.Assert(multi.Close(), IsNil) }() @@ -2585,8 +2577,7 @@ func (t *RedisTest) TestMultiExecDiscard(c *C) { } func (t *RedisTest) TestMultiExecEmpty(c *C) { - multi, err := t.client.MultiClient() - c.Assert(err, IsNil) + multi := t.client.Multi() defer func() { c.Assert(multi.Close(), IsNil) }() @@ -2601,8 +2592,7 @@ func (t *RedisTest) TestMultiExecEmpty(c *C) { } func (t *RedisTest) TestMultiExecOnEmptyQueue(c *C) { - multi, err := t.client.MultiClient() - c.Assert(err, IsNil) + multi := t.client.Multi() defer func() { c.Assert(multi.Close(), IsNil) }() @@ -2612,16 +2602,15 @@ func (t *RedisTest) TestMultiExecOnEmptyQueue(c *C) { c.Assert(reqs, HasLen, 0) } -func (t *RedisTest) TestMultiExecIncrTransaction(c *C) { - multi, err := t.client.MultiClient() - c.Assert(err, IsNil) +func (t *RedisTest) TestMultiExecIncr(c *C) { + multi := t.client.Multi() defer func() { c.Assert(multi.Close(), IsNil) }() reqs, err := multi.Exec(func() { for i := int64(0); i < 20000; i++ { - multi.Incr("TestIncrTransactionKey") + multi.Incr("key") } }) c.Assert(err, IsNil) @@ -2632,14 +2621,13 @@ func (t *RedisTest) TestMultiExecIncrTransaction(c *C) { } } - get := t.client.Get("TestIncrTransactionKey") + get := t.client.Get("key") c.Assert(get.Err(), IsNil) c.Assert(get.Val(), Equals, "20000") } func (t *RedisTest) transactionalIncr(c *C) ([]redis.Req, error) { - multi, err := t.client.MultiClient() - c.Assert(err, IsNil) + multi := t.client.Multi() defer func() { c.Assert(multi.Close(), IsNil) }()