diff --git a/redis.go b/redis.go index 3d8f0c5..ebf1a27 100644 --- a/redis.go +++ b/redis.go @@ -6,6 +6,7 @@ import ( "io" "net" "sync" + "time" "github.com/vmihailenco/bufreader" ) @@ -81,6 +82,13 @@ func (c *Client) WriteReq(buf []byte, conn *Conn) error { } func (c *Client) ReadReply(conn *Conn) error { + if false { + err := conn.RW.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Second)) + if err != nil { + return err + } + } + _, err := conn.Rd.ReadFrom(conn.RW) if err != nil { return err @@ -114,7 +122,6 @@ func (c *Client) Queue(req Req) { func (c *Client) Run(req Req) { conn, _, err := c.ConnPool.Get() if err != nil { - c.ConnPool.Remove(conn) req.SetErr(err) return } @@ -128,7 +135,7 @@ func (c *Client) Run(req Req) { val, err := req.ParseReply(conn.Rd) if err != nil { - c.ConnPool.Remove(conn) + c.ConnPool.Add(conn) req.SetErr(err) return } @@ -147,33 +154,47 @@ func (c *Client) RunQueued() ([]Req, error) { c.reqs = make([]Req, 0) c.mtx.Unlock() - return c.RunReqs(reqs) + conn, _, err := c.ConnPool.Get() + if err != nil { + return nil, err + } + + err = c.RunReqs(reqs, conn) + if err != nil { + c.ConnPool.Remove(conn) + return nil, err + } + + // c.ConnPool.Add(conn) + return reqs, nil } -func (c *Client) RunReqs(reqs []Req) ([]Req, error) { +func (c *Client) RunReqs(reqs []Req, conn *Conn) error { var multiReq []byte if len(reqs) == 1 { multiReq = reqs[0].Req() } else { + // TODO: split req to chunks multiReq = make([]byte, 0, 1024) for _, req := range reqs { multiReq = append(multiReq, req.Req()...) } } - conn, _, err := c.ConnPool.Get() + err := c.WriteRead(multiReq, conn) if err != nil { - return nil, err - } - - err = c.WriteRead(multiReq, conn) - if err != nil { - return nil, err + return err } for i := 0; i < len(reqs); i++ { - req := reqs[i] + if !conn.Rd.HasUnread() { + _, err := conn.Rd.ReadFrom(conn.RW) + if err != err { + return err + } + } + req := reqs[i] val, err := req.ParseReply(conn.Rd) if err != nil { req.SetErr(err) @@ -182,7 +203,7 @@ func (c *Client) RunReqs(reqs []Req) ([]Req, error) { } } - return reqs, nil + return nil } //------------------------------------------------------------------------------ @@ -203,6 +224,23 @@ func (c *Client) Exec() ([]Req, error) { c.reqs = make([]Req, 0) c.mtx.Unlock() + conn, _, err := c.ConnPool.Get() + if err != nil { + return nil, err + } + + err = c.ExecReqs(reqs, conn) + if err != nil { + c.ConnPool.Remove(conn) + return nil, err + } + + c.ConnPool.Add(conn) + return reqs, nil +} + +func (c *Client) ExecReqs(reqs []Req, conn *Conn) error { + // TODO: split req to chunks multiReq := make([]byte, 0, 1024) multiReq = append(multiReq, PackReq([]string{"MULTI"})...) for _, req := range reqs { @@ -210,14 +248,9 @@ func (c *Client) Exec() ([]Req, error) { } multiReq = append(multiReq, PackReq([]string{"EXEC"})...) - conn, _, err := c.ConnPool.Get() + err := c.WriteRead(multiReq, conn) if err != nil { - return nil, err - } - - err = c.WriteRead(multiReq, conn) - if err != nil { - return nil, err + return err } statusReq := NewStatusReq() @@ -225,24 +258,31 @@ func (c *Client) Exec() ([]Req, error) { // Parse MULTI command reply. _, err = statusReq.ParseReply(conn.Rd) if err != nil { - return nil, err + return err } // Parse queued replies. for _ = range reqs { + if !conn.Rd.HasUnread() { + _, err := conn.Rd.ReadFrom(conn.RW) + if err != err { + return err + } + } + _, err = statusReq.ParseReply(conn.Rd) if err != nil { - return nil, err + return err } } // Parse number of replies. line, err := conn.Rd.ReadLine('\n') if err != nil { - return nil, err + return err } if line[0] != '*' { - return nil, fmt.Errorf("Expected '*', but got line %q of %q.", line, conn.Rd.Bytes()) + return fmt.Errorf("Expected '*', but got line %q of %q.", line, conn.Rd.Bytes()) } // Parse replies. @@ -256,5 +296,5 @@ func (c *Client) Exec() ([]Req, error) { } } - return reqs, nil + return nil } diff --git a/redis_test.go b/redis_test.go index 495526c..9695272 100644 --- a/redis_test.go +++ b/redis_test.go @@ -1539,63 +1539,63 @@ func (t *RedisTest) TestPipeliningFromGoroutines(c *C) { func (t *RedisTest) TestIncrFromGoroutines(c *C) { wg := &sync.WaitGroup{} - for i := int64(0); i < 1000; i++ { + for i := int64(0); i < 20000; i++ { wg.Add(1) go func() { - _, err := t.client.Incr("key").Reply() + _, err := t.client.Incr("TestIncrFromGoroutinesKey").Reply() c.Check(err, IsNil) wg.Done() }() } wg.Wait() - n, err := t.client.Get("key").Reply() + n, err := t.client.Get("TestIncrFromGoroutinesKey").Reply() c.Check(err, IsNil) - c.Check(n, Equals, "1000") + c.Check(n, Equals, "20000") } func (t *RedisTest) TestIncrPipeliningFromGoroutines(c *C) { - c.Skip("conn pool required") + multiClient := t.client.Multi() wg := &sync.WaitGroup{} - for i := int64(0); i < 10000; i++ { + for i := int64(0); i < 20000; i++ { wg.Add(1) go func() { - t.client.Incr("key") + multiClient.Incr("TestIncrPipeliningFromGoroutinesKey") wg.Done() }() } wg.Wait() - reqs, err := t.client.RunQueued() + reqs, err := multiClient.RunQueued() c.Check(err, IsNil) - c.Check(reqs, HasLen, 10000) + c.Check(reqs, HasLen, 20000) - n, err := t.client.Get("key").Reply() + n, err := t.client.Get("TestIncrPipeliningFromGoroutinesKey").Reply() c.Check(err, IsNil) - c.Check(n, Equals, "10000") + c.Check(n, Equals, "20000") } func (t *RedisTest) TestIncrTransaction(c *C) { - c.Skip("conn pool required") + multiClient := t.client.Multi() wg := &sync.WaitGroup{} - for i := int64(0); i < 10000; i++ { + for i := int64(0); i < 20000; i++ { wg.Add(1) go func() { - t.client.Incr("key") + multiClient.Incr("TestIncrTransactionKey") wg.Done() }() } wg.Wait() - reqs, err := t.client.Exec() + reqs, err := multiClient.Exec() c.Check(err, IsNil) - c.Check(reqs, HasLen, 10000) + c.Check(reqs, HasLen, 20000) - n, err := t.client.Get("key").Reply() + n, err := t.client.Get("TestIncrTransactionKey").Reply() c.Check(err, IsNil) - c.Check(n, Equals, "10000") + c.Check(n, Equals, "20000") } //------------------------------------------------------------------------------ @@ -1632,7 +1632,7 @@ func (t *RedisTest) BenchmarkRedisSet(c *C) { } } -func (t *RedisTest) BenchmarkRedisGet(c *C) { +func (t *RedisTest) BenchmarkRedisGetNil(c *C) { c.StopTimer() for i := 0; i < 10; i++ { @@ -1648,6 +1648,25 @@ func (t *RedisTest) BenchmarkRedisGet(c *C) { } } +func (t *RedisTest) BenchmarkRedisGet(c *C) { + c.StopTimer() + + _, err := t.client.Set("foo", "bar").Reply() + c.Check(err, IsNil) + + for i := 0; i < 10; i++ { + v, err := t.client.Get("foo").Reply() + c.Check(err, IsNil) + c.Check(v, Equals, "bar") + } + + c.StartTimer() + + for i := 0; i < c.N; i++ { + t.client.Get("foo").Reply() + } +} + func (t *RedisTest) BenchmarkRedisWriteRead(c *C) { c.StopTimer() @@ -1666,4 +1685,8 @@ func (t *RedisTest) BenchmarkRedisWriteRead(c *C) { for i := 0; i < c.N; i++ { t.client.WriteRead(req, conn) } + + c.StopTimer() + t.client.ConnPool.Add(conn) + c.StartTimer() } diff --git a/request.go b/request.go index ca9b5e5..5224611 100644 --- a/request.go +++ b/request.go @@ -15,12 +15,16 @@ var errResultMissing = errors.New("Request was not run properly.") //------------------------------------------------------------------------------ -func isNil(buf []byte) bool { - return len(buf) == 3 && buf[0] == '$' && buf[1] == '-' && buf[2] == '1' +func isNil(line []byte) bool { + return len(line) == 3 && line[0] == '$' && line[1] == '-' && line[2] == '1' } -func isEmpty(buf []byte) bool { - return len(buf) == 2 && buf[0] == '$' && buf[1] == '0' +func isEmpty(line []byte) bool { + return len(line) == 2 && line[0] == '$' && line[1] == '0' +} + +func isNoReplies(line []byte) bool { + return len(line) >= 2 && line[1] == '*' && line[1] == '0' } //------------------------------------------------------------------------------ @@ -384,21 +388,25 @@ func (r *MultiBulkReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { return nil, errors.New(string(line[1:])) } else if line[0] != '*' { return nil, fmt.Errorf("Expected '*', but got line %q of %q.", line, rd.Bytes()) - } - - val := make([]interface{}, 0) - - if len(line) >= 2 && line[1] == '0' { - return val, nil } else if isNil(line) { return nil, Nil } - line, err = rd.ReadLine('\n') + val := make([]interface{}, 0) + if isNoReplies(line) { + return val, nil + } + numReplies, err := strconv.ParseInt(string(line[1:]), 10, 64) if err != nil { return nil, err } - for { + + for i := int64(0); i < numReplies; i++ { + line, err = rd.ReadLine('\n') + if err != nil { + return nil, err + } + if line[0] == ':' { var n int64 n, err = strconv.ParseInt(string(line[1:]), 10, 64) @@ -421,20 +429,6 @@ func (r *MultiBulkReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { } else { return nil, fmt.Errorf("Expected '$', but got line %q of %q.", line, rd.Bytes()) } - - line, err = rd.ReadLine('\n') - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - // Check for the header of another reply. - if line[0] == '*' { - rd.UnreadLine('\n') - break - } } return val, nil