diff --git a/README.md b/README.md index f1edbea..ec3bc1d 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,22 @@ Multi/Exec Example: + func transaction(multi *redis.MultiClient) ([]redis.Req, error) { + get := multiClient.Get("foo") + if get.Err() != nil { + panic(get.Err()) + } + + reqs, err = multiClient.Exec(func() { + multi.Set("foo", get.Val() + "1") + }) + if err == redis.Nil { + return transaction() + } + + return reqs, err + } + multiClient, err := redisClient.MultiClient() if err != nil { panic(err) @@ -142,29 +158,14 @@ Example: panic(watch.Err()) } - get := multiClient.Get("foo") - if get.Err() != nil { - panic(get.Err()) - } - - // Start transaction. - multiClient.Multi() - - set := multiClient.Set("foo", get.Val() + "1") - - // Commit transaction. - reqs, err := multiClient.Exec() - if err == redis.Nil { - // Repeat transaction. - } else if err != nil { + reqs, err := transaction(multiClient) + if err != nil { panic(err) } for _, req := range reqs { // ... } - ok := set.Val() - Pub/sub ------- diff --git a/multi.go b/multi.go index 4bec71a..ab57f3d 100644 --- a/multi.go +++ b/multi.go @@ -19,10 +19,6 @@ func (c *Client) MultiClient() (*MultiClient, error) { }, nil } -func (c *MultiClient) Multi() { - c.reqs = make([]Req, 0) -} - func (c *MultiClient) Watch(keys ...string) *StatusReq { args := append([]string{"WATCH"}, keys...) req := NewStatusReq(args...) @@ -43,7 +39,13 @@ func (c *MultiClient) Discard() { c.mtx.Unlock() } -func (c *MultiClient) Exec() ([]Req, error) { +func (c *MultiClient) Exec(do func()) ([]Req, error) { + c.mtx.Lock() + c.reqs = make([]Req, 0) + c.mtx.Unlock() + + do() + c.mtx.Lock() if len(c.reqs) == 0 { c.mtx.Unlock() diff --git a/parser.go b/parser.go index acf4c95..dd9a2d2 100644 --- a/parser.go +++ b/parser.go @@ -40,7 +40,7 @@ func PackReq(args []string) []byte { type ReadLiner interface { ReadLine() ([]byte, bool, error) - Peek(n int) ([]byte, error) + Read([]byte) (int, error) ReadN(n int) ([]byte, error) } @@ -121,13 +121,14 @@ func ParseReply(rd ReadLiner) (interface{}, error) { line, err = rd.ReadN(replyLen) if err == bufio.ErrBufferFull { buf := make([]byte, replyLen) - r := 0 + r := copy(buf, line) - r += copy(buf, line) - - for err == bufio.ErrBufferFull { - line, err = rd.ReadN(replyLen - r) - r += copy(buf[r:], line) + for r < replyLen { + n, err := rd.Read(buf[r:]) + if err != nil { + return "", err + } + r += n } line = buf diff --git a/redis_test.go b/redis_test.go index d590ade..80b59a0 100644 --- a/redis_test.go +++ b/redis_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net" + "runtime" "strconv" "sync" "testing" @@ -41,6 +42,7 @@ func (t *RedisTest) SetUpTest(c *C) { return nil } t.client = redis.NewClient(openConn, closeConn, nil) + t.client.ConnPool.(*redis.MultiConnPool).MaxCap = 10 c.Assert(t.client.FlushDb().Err(), IsNil) } @@ -164,10 +166,10 @@ func (t *RedisTest) TestConnPoolMaxCapOnMultiClient(c *C) { multi, err := t.client.MultiClient() c.Assert(err, IsNil) - multi.Multi() - - ping := multi.Ping() - reqs, err := multi.Exec() + var ping *redis.StatusReq + reqs, err := multi.Exec(func() { + ping = multi.Ping() + }) c.Assert(err, IsNil) c.Assert(reqs, HasLen, 1) c.Assert(ping.Err(), IsNil) @@ -2025,20 +2027,43 @@ func (t *RedisTest) TestPipelineEchoFromGoroutines(c *C) { //------------------------------------------------------------------------------ -func (t *RedisTest) TestMultiDiscard(c *C) { +func (t *RedisTest) TestMultiExec(c *C) { multi, err := t.client.MultiClient() c.Assert(err, IsNil) defer func() { c.Assert(multi.Close(), IsNil) }() - multi.Multi() + var ( + set *redis.StatusReq + get *redis.BulkReq + ) + reqs, err := multi.Exec(func() { + set = multi.Set("foo", "bar") + get = multi.Get("foo") + }) + c.Assert(err, IsNil) + c.Assert(reqs, HasLen, 2) - multi.Set("foo1", "bar1") - multi.Discard() - multi.Set("foo2", "bar2") + c.Assert(set.Err(), IsNil) + c.Assert(set.Val(), Equals, "OK") - reqs, err := multi.Exec() + c.Assert(get.Err(), IsNil) + c.Assert(get.Val(), Equals, "bar") +} + +func (t *RedisTest) TestMultiExecDiscard(c *C) { + multi, err := t.client.MultiClient() + c.Assert(err, IsNil) + defer func() { + c.Assert(multi.Close(), IsNil) + }() + + reqs, err := multi.Exec(func() { + multi.Set("foo1", "bar1") + multi.Discard() + multi.Set("foo2", "bar2") + }) c.Assert(err, IsNil) c.Assert(reqs, HasLen, 1) @@ -2051,29 +2076,6 @@ func (t *RedisTest) TestMultiDiscard(c *C) { c.Assert(get.Val(), Equals, "bar2") } -func (t *RedisTest) TestMultiExec(c *C) { - multi, err := t.client.MultiClient() - c.Assert(err, IsNil) - defer func() { - c.Assert(multi.Close(), IsNil) - }() - - multi.Multi() - - setR := multi.Set("foo", "bar") - getR := multi.Get("foo") - - reqs, err := multi.Exec() - c.Assert(err, IsNil) - c.Assert(reqs, HasLen, 2) - - c.Assert(setR.Err(), IsNil) - c.Assert(setR.Val(), Equals, "OK") - - c.Assert(getR.Err(), IsNil) - c.Assert(getR.Val(), Equals, "bar") -} - func (t *RedisTest) TestMultiExecOnEmptyQueue(c *C) { multi, err := t.client.MultiClient() c.Assert(err, IsNil) @@ -2081,11 +2083,86 @@ func (t *RedisTest) TestMultiExecOnEmptyQueue(c *C) { c.Assert(multi.Close(), IsNil) }() - reqs, err := multi.Exec() + reqs, err := multi.Exec(func() {}) c.Assert(err, IsNil) c.Assert(reqs, HasLen, 0) } +func (t *RedisTest) TestMultiExecIncrTransaction(c *C) { + multi, err := t.client.MultiClient() + c.Assert(err, IsNil) + defer func() { + c.Assert(multi.Close(), IsNil) + }() + + reqs, err := multi.Exec(func() { + for i := int64(0); i < 20000; i++ { + multi.Incr("TestIncrTransactionKey") + } + }) + c.Assert(err, IsNil) + c.Assert(reqs, HasLen, 20000) + for _, req := range reqs { + if req.Err() != nil { + c.Errorf("got %v, expected nil", req.Err()) + } + } + + get := t.client.Get("TestIncrTransactionKey") + c.Assert(get.Err(), IsNil) + c.Assert(get.Val(), Equals, "20000") +} + +func (t *RedisTest) transactionalIncr(c *C) ([]redis.Req, error) { + multi, err := t.client.MultiClient() + c.Assert(err, IsNil) + defer func() { + c.Assert(multi.Close(), IsNil) + }() + + watch := multi.Watch("foo") + c.Assert(watch.Err(), IsNil) + c.Assert(watch.Val(), Equals, "OK") + + get := multi.Get("foo") + c.Assert(get.Err(), IsNil) + c.Assert(get.Val(), Not(Equals), redis.Nil) + + v, err := strconv.ParseInt(get.Val(), 10, 64) + c.Assert(err, IsNil) + + reqs, err := multi.Exec(func() { + multi.Set("foo", strconv.FormatInt(v+1, 10)) + }) + if err == redis.Nil { + return t.transactionalIncr(c) + } + return reqs, err +} + +func (t *RedisTest) TestWatchUnwatch(c *C) { + set := t.client.Set("foo", "0") + c.Assert(set.Err(), IsNil) + c.Assert(set.Val(), Equals, "OK") + + wg := &sync.WaitGroup{} + for i := 0; i < 1000; i++ { + wg.Add(1) + go func() { + reqs, err := t.transactionalIncr(c) + c.Assert(reqs, HasLen, 1) + c.Assert(err, IsNil) + c.Assert(reqs[0].Err(), IsNil) + wg.Done() + }() + } + wg.Wait() + + get := t.client.Get("foo") + c.Assert(get.Err(), IsNil) + c.Assert(get.Val(), Equals, "1000") +} + //------------------------------------------------------------------------------ func (t *RedisTest) TestSyncEchoFromGoroutines(c *C) { @@ -2120,90 +2197,6 @@ func (t *RedisTest) TestIncrFromGoroutines(c *C) { c.Assert(get.Val(), Equals, "20000") } -func (t *RedisTest) TestIncrTransaction(c *C) { - multi, err := t.client.MultiClient() - c.Assert(err, IsNil) - defer func() { - c.Assert(multi.Close(), IsNil) - }() - - multi.Multi() - - wg := &sync.WaitGroup{} - for i := int64(0); i < 20000; i++ { - wg.Add(1) - go func() { - multi.Incr("TestIncrTransactionKey") - wg.Done() - }() - } - wg.Wait() - - reqs, err := multi.Exec() - c.Assert(err, IsNil) - c.Assert(reqs, HasLen, 20000) - for _, req := range reqs { - if req.Err() != nil { - c.Errorf("got %v, expected nil", req.Err()) - } - } - - get := t.client.Get("TestIncrTransactionKey") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "20000") -} - -func (t *RedisTest) transactionalIncr(c *C, wg *sync.WaitGroup) { - multi, err := t.client.MultiClient() - c.Assert(err, IsNil) - defer func() { - c.Assert(multi.Close(), IsNil) - }() - - watch := multi.Watch("foo") - c.Assert(watch.Err(), IsNil) - c.Assert(watch.Val(), Equals, "OK") - - get := multi.Get("foo") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Not(Equals), redis.Nil) - - v, err := strconv.ParseInt(get.Val(), 10, 64) - c.Assert(err, IsNil) - - multi.Multi() - set := multi.Set("foo", strconv.FormatInt(v+1, 10)) - reqs, err := multi.Exec() - if err == redis.Nil { - t.transactionalIncr(c, wg) - return - } - c.Assert(reqs, HasLen, 1) - c.Assert(err, IsNil) - - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - wg.Done() -} - -func (t *RedisTest) TestWatchUnwatch(c *C) { - set := t.client.Set("foo", "0") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - wg := &sync.WaitGroup{} - for i := 0; i < 1000; i++ { - wg.Add(1) - go t.transactionalIncr(c, wg) - } - wg.Wait() - - get := t.client.Get("foo") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "1000") -} - //------------------------------------------------------------------------------ func (t *RedisTest) TestCmdBgRewriteAOF(c *C) { @@ -2312,6 +2305,9 @@ func (t *RedisTest) TestTime(c *C) { func (t *RedisTest) BenchmarkRedisPing(c *C) { c.StopTimer() + runtime.LockOSThread() + t.client.ConnPool.(*redis.MultiConnPool).MaxCap = 1 + for i := 0; i < 10; i++ { ping := t.client.Ping() c.Assert(ping.Err(), IsNil) @@ -2328,6 +2324,9 @@ func (t *RedisTest) BenchmarkRedisPing(c *C) { func (t *RedisTest) BenchmarkRedisSet(c *C) { c.StopTimer() + runtime.LockOSThread() + t.client.ConnPool.(*redis.MultiConnPool).MaxCap = 1 + for i := 0; i < 10; i++ { set := t.client.Set("foo", "bar") c.Assert(set.Err(), IsNil) @@ -2344,6 +2343,9 @@ func (t *RedisTest) BenchmarkRedisSet(c *C) { func (t *RedisTest) BenchmarkRedisGetNil(c *C) { c.StopTimer() + runtime.LockOSThread() + t.client.ConnPool.(*redis.MultiConnPool).MaxCap = 1 + for i := 0; i < 10; i++ { get := t.client.Get("foo") c.Assert(get.Err(), Equals, redis.Nil) @@ -2360,6 +2362,9 @@ func (t *RedisTest) BenchmarkRedisGetNil(c *C) { func (t *RedisTest) BenchmarkRedisGet(c *C) { c.StopTimer() + runtime.LockOSThread() + t.client.ConnPool.(*redis.MultiConnPool).MaxCap = 1 + set := t.client.Set("foo", "bar") c.Assert(set.Err(), IsNil) @@ -2379,6 +2384,9 @@ func (t *RedisTest) BenchmarkRedisGet(c *C) { func (t *RedisTest) BenchmarkRedisMGet(c *C) { c.StopTimer() + runtime.LockOSThread() + t.client.ConnPool.(*redis.MultiConnPool).MaxCap = 1 + mSet := t.client.MSet("foo1", "bar1", "foo2", "bar2") c.Assert(mSet.Err(), IsNil) @@ -2398,6 +2406,8 @@ func (t *RedisTest) BenchmarkRedisMGet(c *C) { func (t *RedisTest) BenchmarkRedisWriteRead(c *C) { c.StopTimer() + runtime.LockOSThread() + conn, _, err := t.client.ConnPool.Get() c.Assert(err, IsNil)