diff --git a/connpool.go b/connpool.go index 4a28808..d4bdc07 100644 --- a/connpool.go +++ b/connpool.go @@ -10,14 +10,16 @@ import ( ) type Conn struct { - RW io.ReadWriteCloser - Rd *bufio.Reader + RW io.ReadWriteCloser + Rd *bufio.Reader + ReqBuf []byte } func NewConn(rw io.ReadWriteCloser) *Conn { return &Conn{ - RW: rw, - Rd: bufio.NewReaderSize(rw, 1024), + RW: rw, + Rd: bufio.NewReaderSize(rw, 1024), + ReqBuf: make([]byte, 0, 1024), } } diff --git a/multi.go b/multi.go index ab57f3d..8e28b8a 100644 --- a/multi.go +++ b/multi.go @@ -35,21 +35,20 @@ func (c *MultiClient) Unwatch(keys ...string) *StatusReq { func (c *MultiClient) Discard() { c.mtx.Lock() - c.reqs = c.reqs[:0] + c.reqs = []Req{NewStatusReq("MULTI")} c.mtx.Unlock() } func (c *MultiClient) Exec(do func()) ([]Req, error) { - c.mtx.Lock() - c.reqs = make([]Req, 0) - c.mtx.Unlock() + c.Discard() do() c.mtx.Lock() - if len(c.reqs) == 0 { + c.reqs = append(c.reqs, NewMultiBulkReq("EXEC")) + if len(c.reqs) == 2 { c.mtx.Unlock() - return c.reqs, nil + return []Req{}, nil } reqs := c.reqs c.reqs = nil @@ -67,18 +66,11 @@ func (c *MultiClient) Exec(do func()) ([]Req, error) { } c.ConnPool.Add(conn) - return reqs, nil + return reqs[1 : len(reqs)-1], nil } func (c *MultiClient) ExecReqs(reqs []Req, conn *Conn) error { - multiReq := make([]byte, 0, 1024) - multiReq = append(multiReq, PackReq([]string{"MULTI"})...) - for _, req := range reqs { - multiReq = append(multiReq, req.Req()...) - } - multiReq = append(multiReq, PackReq([]string{"EXEC"})...) - - err := c.WriteReq(multiReq, conn) + err := c.WriteReq(conn, reqs...) if err != nil { return err } @@ -92,7 +84,7 @@ func (c *MultiClient) ExecReqs(reqs []Req, conn *Conn) error { } // Parse queued replies. - for _ = range reqs { + for i := 1; i < len(reqs)-1; i++ { _, err = statusReq.ParseReply(conn.Rd) if err != nil { return err @@ -112,7 +104,7 @@ func (c *MultiClient) ExecReqs(reqs []Req, conn *Conn) error { } // Parse replies. - for i := 0; i < len(reqs); i++ { + for i := 1; i < len(reqs)-1; i++ { req := reqs[i] val, err := req.ParseReply(conn.Rd) if err != nil { diff --git a/parser.go b/parser.go index dd9a2d2..b65512a 100644 --- a/parser.go +++ b/parser.go @@ -21,8 +21,7 @@ var ( //------------------------------------------------------------------------------ -func PackReq(args []string) []byte { - buf := make([]byte, 0, 1024) +func AppendReq(buf []byte, args []string) []byte { buf = append(buf, '*') buf = strconv.AppendUint(buf, uint64(len(args)), 10) buf = append(buf, '\r', '\n') diff --git a/pipeline.go b/pipeline.go index 89ed0f2..9271bb7 100644 --- a/pipeline.go +++ b/pipeline.go @@ -46,17 +46,7 @@ func (c *PipelineClient) RunQueued() ([]Req, error) { } func (c *PipelineClient) RunReqs(reqs []Req, conn *Conn) error { - var multiReq []byte - if len(reqs) == 1 { - multiReq = reqs[0].Req() - } else { - multiReq = make([]byte, 0, 1024) - for _, req := range reqs { - multiReq = append(multiReq, req.Req()...) - } - } - - err := c.WriteReq(multiReq, conn) + err := c.WriteReq(conn, reqs...) if err != nil { return err } diff --git a/pubsub.go b/pubsub.go index 11d2d5f..884f3f4 100644 --- a/pubsub.go +++ b/pubsub.go @@ -89,7 +89,7 @@ func (c *PubSubClient) subscribe(cmd string, channels ...string) (chan *Message, return nil, err } - if err := c.WriteReq(req.Req(), conn); err != nil { + if err := c.WriteReq(conn, req); err != nil { return nil, err } @@ -117,7 +117,7 @@ func (c *PubSubClient) unsubscribe(cmd string, channels ...string) error { return err } - return c.WriteReq(req.Req(), conn) + return c.WriteReq(conn, req) } func (c *PubSubClient) Unsubscribe(channels ...string) error { diff --git a/redis.go b/redis.go index 3643fa5..85ad0f0 100644 --- a/redis.go +++ b/redis.go @@ -56,8 +56,13 @@ type BaseClient struct { reqs []Req } -func (c *BaseClient) WriteReq(buf []byte, conn *Conn) error { - _, err := conn.RW.Write(buf) +func (c *BaseClient) WriteReq(conn *Conn, reqs ...Req) error { + conn.ReqBuf = conn.ReqBuf[:0] + for _, req := range reqs { + conn.ReqBuf = AppendReq(conn.ReqBuf, req.Args()) + } + + _, err := conn.RW.Write(conn.ReqBuf) return err } @@ -96,7 +101,7 @@ func (c *BaseClient) Run(req Req) { return } - err = c.WriteReq(req.Req(), conn) + err = c.WriteReq(conn, req) if err != nil { c.ConnPool.Remove(conn) req.SetErr(err) diff --git a/redis_test.go b/redis_test.go index 80b59a0..cfacef1 100644 --- a/redis_test.go +++ b/redis_test.go @@ -2412,7 +2412,7 @@ func (t *RedisTest) BenchmarkRedisWriteRead(c *C) { c.Assert(err, IsNil) for i := 0; i < 10; i++ { - err := t.client.WriteReq([]byte("PING\r\n"), conn) + err := t.client.WriteReq(conn, redis.NewStatusReq("PING")) c.Assert(err, IsNil) line, _, err := conn.Rd.ReadLine() @@ -2423,7 +2423,7 @@ func (t *RedisTest) BenchmarkRedisWriteRead(c *C) { c.StartTimer() for i := 0; i < c.N; i++ { - t.client.WriteReq([]byte("PING\r\n"), conn) + t.client.WriteReq(conn, redis.NewStatusReq("PING")) conn.Rd.ReadLine() } diff --git a/request.go b/request.go index 43ec169..74fa50c 100644 --- a/request.go +++ b/request.go @@ -5,7 +5,7 @@ import ( ) type Req interface { - Req() []byte + Args() []string ParseReply(ReadLiner) (interface{}, error) SetErr(error) Err() error @@ -28,8 +28,8 @@ func NewBaseReq(args ...string) *BaseReq { } } -func (r *BaseReq) Req() []byte { - return PackReq(r.args) +func (r *BaseReq) Args() []string { + return r.args } func (r *BaseReq) SetErr(err error) {