diff --git a/README.md b/README.md index 86821ace..5b7f48d9 100644 --- a/README.md +++ b/README.md @@ -16,9 +16,8 @@ Supports: Installation ------------ -Run: +Install: - go get github.com/vmihailenco/bufreader go get github.com/vmihailenco/redis Run tests: diff --git a/connpool.go b/connpool.go index 361187a8..9ec8fa52 100644 --- a/connpool.go +++ b/connpool.go @@ -1,23 +1,22 @@ package redis import ( + "bufio" "io" "log" "os" "sync" - - "github.com/vmihailenco/bufreader" ) type Conn struct { RW io.ReadWriter - Rd *bufreader.Reader + Rd *bufio.Reader } func NewConn(rw io.ReadWriter) *Conn { return &Conn{ RW: rw, - Rd: bufreader.NewSizedReader(8024), + Rd: bufio.NewReaderSize(rw, 1024), } } diff --git a/pubsub.go b/pubsub.go index e9f95d14..23f66616 100644 --- a/pubsub.go +++ b/pubsub.go @@ -42,17 +42,6 @@ func (c *PubSubClient) consumeMessages() { req := NewMultiBulkReq() for { - // Replies can arrive in batches. - // Read whole reply and parse messages one by one. - - err := c.ReadReply(conn) - if err != nil { - msg := &Message{} - msg.Err = err - c.ch <- msg - return - } - for { msg := &Message{} @@ -79,7 +68,7 @@ func (c *PubSubClient) consumeMessages() { } c.ch <- msg - if !conn.Rd.HasUnread() { + if conn.Rd.Buffered() <= 0 { break } } diff --git a/redis.go b/redis.go index 213367f4..f23147d9 100644 --- a/redis.go +++ b/redis.go @@ -2,12 +2,15 @@ package redis import ( "crypto/tls" + "errors" "fmt" "io" "net" "sync" +) - "github.com/vmihailenco/bufreader" +var ( + ErrReaderTooSmall = errors.New("redis: Reader is too small") ) type OpenConnFunc func() (io.ReadWriter, error) @@ -50,9 +53,7 @@ func AuthSelectFunc(password string, db int64) InitConnFunc { } } -func createReader() (*bufreader.Reader, error) { - return bufreader.NewSizedReader(8192), nil -} +//------------------------------------------------------------------------------ type Client struct { mtx sync.Mutex @@ -103,21 +104,6 @@ func (c *Client) WriteReq(buf []byte, conn *Conn) error { return err } -func (c *Client) ReadReply(conn *Conn) error { - _, err := conn.Rd.ReadFrom(conn.RW) - if err != nil { - return err - } - return nil -} - -func (c *Client) WriteRead(buf []byte, conn *Conn) error { - if err := c.WriteReq(buf, conn); err != nil { - return err - } - return c.ReadReply(conn) -} - func (c *Client) Process(req Req) { if c.reqs == nil { c.Run(req) @@ -139,7 +125,7 @@ func (c *Client) Run(req Req) { return } - err = c.WriteRead(req.Req(), conn) + err = c.WriteReq(req.Req(), conn) if err != nil { c.ConnPool.Remove(conn) req.SetErr(err) @@ -193,19 +179,12 @@ func (c *Client) RunReqs(reqs []Req, conn *Conn) error { } } - err := c.WriteRead(multiReq, conn) + err := c.WriteReq(multiReq, conn) if err != nil { return err } for i := 0; i < len(reqs); i++ { - if !conn.Rd.HasUnread() { - _, err := conn.Rd.ReadFrom(conn.RW) - if err != err { - return err - } - } - req := reqs[i] val, err := req.ParseReply(conn.Rd) if err != nil { @@ -259,7 +238,7 @@ func (c *Client) ExecReqs(reqs []Req, conn *Conn) error { } multiReq = append(multiReq, PackReq([]string{"EXEC"})...) - err := c.WriteRead(multiReq, conn) + err := c.WriteReq(multiReq, conn) if err != nil { return err } @@ -274,13 +253,6 @@ func (c *Client) ExecReqs(reqs []Req, conn *Conn) error { // Parse queued replies. for _ = range reqs { - if !conn.Rd.HasUnread() { - _, err := conn.Rd.ReadFrom(conn.RW) - if err != err { - return err - } - } - _, err = statusReq.ParseReply(conn.Rd) if err != nil { return err @@ -288,12 +260,13 @@ func (c *Client) ExecReqs(reqs []Req, conn *Conn) error { } // Parse number of replies. - line, err := conn.Rd.ReadLine('\n') + line, err := readLine(conn.Rd) if err != nil { return err } if line[0] != '*' { - return fmt.Errorf("Expected '*', but got line %q of %q.", line, conn.Rd.Bytes()) + buf, _ := conn.Rd.Peek(conn.Rd.Buffered()) + return fmt.Errorf("Expected '*', but got line %q of %q.", line, buf) } // Parse replies. diff --git a/redis_test.go b/redis_test.go index c7a0d6f2..24246aa0 100644 --- a/redis_test.go +++ b/redis_test.go @@ -1964,20 +1964,23 @@ func (t *RedisTest) BenchmarkRedisMGet(c *C) { func (t *RedisTest) BenchmarkRedisWriteRead(c *C) { c.StopTimer() - req := []byte("PING\r\n") conn, _, err := t.client.ConnPool.Get() c.Check(err, IsNil) for i := 0; i < 10; i++ { - err := t.client.WriteRead(req, conn) + err := t.client.WriteReq([]byte("PING\r\n"), conn) c.Check(err, IsNil) - c.Check(conn.Rd.Bytes(), DeepEquals, []byte("+PONG\r\n")) + + line, _, err := conn.Rd.ReadLine() + c.Check(err, IsNil) + c.Check(line, DeepEquals, []byte("+PONG")) } c.StartTimer() for i := 0; i < c.N; i++ { - t.client.WriteRead(req, conn) + t.client.WriteReq([]byte("PING\r\n"), conn) + conn.Rd.ReadLine() } c.StopTimer() diff --git a/request.go b/request.go index f44b1002..9e6cdf7f 100644 --- a/request.go +++ b/request.go @@ -1,11 +1,10 @@ package redis import ( + "bufio" "errors" "fmt" "strconv" - - "github.com/vmihailenco/bufreader" ) var Nil = errors.New("(nil)") @@ -28,8 +27,8 @@ func isNoReplies(line []byte) bool { //------------------------------------------------------------------------------ -func ParseReq(rd *bufreader.Reader) ([]string, error) { - line, err := rd.ReadLine('\n') +func ParseReq(rd *bufio.Reader) ([]string, error) { + line, err := readLine(rd) if err != nil { return nil, err } @@ -44,15 +43,19 @@ func ParseReq(rd *bufreader.Reader) ([]string, error) { args := make([]string, 0) for i := int64(0); i < numReplies; i++ { - line, err = rd.ReadLine('\n') + line, err = readLine(rd) if err != nil { return nil, err } if line[0] != '$' { - return nil, fmt.Errorf("Expected '$', but got %q of %q.", line, rd.Bytes()) + buf, _ := rd.Peek(rd.Buffered()) + return nil, fmt.Errorf("Expected '$', but got %q of %q.", line, buf) } - line, err = rd.ReadLine('\n') + line, err = readLine(rd) + if err != nil { + return nil, err + } args = append(args, string(line)) } return args, nil @@ -79,7 +82,7 @@ func PackReq(args []string) []byte { type Req interface { Req() []byte - ParseReply(*bufreader.Reader) (interface{}, error) + ParseReply(*bufio.Reader) (interface{}, error) SetErr(error) Err() error SetVal(interface{}) @@ -133,7 +136,7 @@ func (r *BaseReq) InterfaceVal() interface{} { return r.val } -func (r *BaseReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { +func (r *BaseReq) ParseReply(rd *bufio.Reader) (interface{}, error) { panic("abstract") } @@ -149,8 +152,8 @@ func NewStatusReq(args ...string) *StatusReq { } } -func (r *StatusReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { - line, err := rd.ReadLine('\n') +func (r *StatusReq) ParseReply(rd *bufio.Reader) (interface{}, error) { + line, err := readLine(rd) if err != nil { return nil, err } @@ -158,7 +161,8 @@ func (r *StatusReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { if line[0] == '-' { return nil, errors.New(string(line[1:])) } else if line[0] != '+' { - return nil, fmt.Errorf("Expected '+', but got %q of %q.", line, rd.Bytes()) + buf, _ := rd.Peek(rd.Buffered()) + return nil, fmt.Errorf("Expected '+', but got %q of %q.", line, buf) } return string(line[1:]), nil @@ -183,8 +187,8 @@ func NewIntReq(args ...string) *IntReq { } } -func (r *IntReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { - line, err := rd.ReadLine('\n') +func (r *IntReq) ParseReply(rd *bufio.Reader) (interface{}, error) { + line, err := readLine(rd) if err != nil { return nil, err } @@ -192,7 +196,8 @@ func (r *IntReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { if line[0] == '-' { return nil, errors.New(string(line[1:])) } else if line[0] != ':' { - return nil, fmt.Errorf("Expected ':', but got %q of %q.", line, rd.Bytes()) + buf, _ := rd.Peek(rd.Buffered()) + return nil, fmt.Errorf("Expected ':', but got %q of %q.", line, buf) } return strconv.ParseInt(string(line[1:]), 10, 64) @@ -217,8 +222,8 @@ func NewIntNilReq(args ...string) *IntNilReq { } } -func (r *IntNilReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { - line, err := rd.ReadLine('\n') +func (r *IntNilReq) ParseReply(rd *bufio.Reader) (interface{}, error) { + line, err := readLine(rd) if err != nil { return nil, err } @@ -231,7 +236,8 @@ func (r *IntNilReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { return nil, Nil } - return nil, fmt.Errorf("Expected ':', but got %q of %q.", line, rd.Bytes()) + buf, _ := rd.Peek(rd.Buffered()) + return nil, fmt.Errorf("Expected ':', but got %q of %q.", line, buf) } func (r *IntNilReq) Val() int64 { @@ -253,8 +259,8 @@ func NewBoolReq(args ...string) *BoolReq { } } -func (r *BoolReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { - line, err := rd.ReadLine('\n') +func (r *BoolReq) ParseReply(rd *bufio.Reader) (interface{}, error) { + line, err := readLine(rd) if err != nil { return nil, err } @@ -262,7 +268,8 @@ func (r *BoolReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { if line[0] == '-' { return nil, errors.New(string(line[1:])) } else if line[0] != ':' { - return nil, fmt.Errorf("Expected ':', but got %q of %q.", line, rd.Bytes()) + buf, _ := rd.Peek(rd.Buffered()) + return nil, fmt.Errorf("Expected ':', but got %q of %q.", line, buf) } return line[1] == '1', nil @@ -287,8 +294,8 @@ func NewBulkReq(args ...string) *BulkReq { } } -func (r *BulkReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { - line, err := rd.ReadLine('\n') +func (r *BulkReq) ParseReply(rd *bufio.Reader) (interface{}, error) { + line, err := readLine(rd) if err != nil { return nil, err } @@ -296,14 +303,15 @@ func (r *BulkReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { if line[0] == '-' { return nil, errors.New(string(line[1:])) } else if line[0] != '$' { - return nil, fmt.Errorf("Expected '$', but got %q of %q.", line, rd.Bytes()) + buf, _ := rd.Peek(rd.Buffered()) + return nil, fmt.Errorf("Expected '$', but got %q of %q.", line, buf) } if isNil(line) { return nil, Nil } - line, err = rd.ReadLine('\n') + line, err = readLine(rd) if err != nil { return nil, err } @@ -330,8 +338,8 @@ func NewFloatReq(args ...string) *FloatReq { } } -func (r *FloatReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { - line, err := rd.ReadLine('\n') +func (r *FloatReq) ParseReply(rd *bufio.Reader) (interface{}, error) { + line, err := readLine(rd) if err != nil { return nil, err } @@ -339,14 +347,15 @@ func (r *FloatReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { if line[0] == '-' { return nil, errors.New(string(line[1:])) } else if line[0] != '$' { - return nil, fmt.Errorf("Expected '$', but got %q of %q.", line, rd.Bytes()) + buf, _ := rd.Peek(rd.Buffered()) + return nil, fmt.Errorf("Expected '$', but got %q of %q.", line, buf) } if isNil(line) { return nil, Nil } - line, err = rd.ReadLine('\n') + line, err = readLine(rd) if err != nil { return nil, err } @@ -373,8 +382,8 @@ func NewMultiBulkReq(args ...string) *MultiBulkReq { } } -func (r *MultiBulkReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { - line, err := rd.ReadLine('\n') +func (r *MultiBulkReq) ParseReply(rd *bufio.Reader) (interface{}, error) { + line, err := readLine(rd) if err != nil { return nil, err } @@ -382,7 +391,8 @@ func (r *MultiBulkReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { if line[0] == '-' { return nil, errors.New(string(line[1:])) } else if line[0] != '*' { - return nil, fmt.Errorf("Expected '*', but got line %q of %q.", line, rd.Bytes()) + buf, _ := rd.Peek(rd.Buffered()) + return nil, fmt.Errorf("Expected '*', but got line %q of %q.", line, buf) } else if isNil(line) { return nil, Nil } @@ -397,7 +407,7 @@ func (r *MultiBulkReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { } for i := int64(0); i < numReplies; i++ { - line, err = rd.ReadLine('\n') + line, err = readLine(rd) if err != nil { return nil, err } @@ -415,14 +425,15 @@ func (r *MultiBulkReq) ParseReply(rd *bufreader.Reader) (interface{}, error) { } else if isNil(line) { val = append(val, nil) } else { - line, err = rd.ReadLine('\n') + line, err = readLine(rd) if err != nil { return nil, err } val = append(val, string(line)) } } else { - return nil, fmt.Errorf("Expected '$', but got line %q of %q.", line, rd.Bytes()) + buf, _ := rd.Peek(rd.Buffered()) + return nil, fmt.Errorf("Expected '$', but got line %q of %q.", line, buf) } } diff --git a/request_test.go b/request_test.go index 7fe27b3c..6f7938b0 100644 --- a/request_test.go +++ b/request_test.go @@ -1,7 +1,8 @@ package redis_test import ( - "github.com/vmihailenco/bufreader" + "bufio" + . "launchpad.net/gocheck" "github.com/vmihailenco/redis" @@ -9,6 +10,20 @@ import ( //------------------------------------------------------------------------------ +type LineReader struct { + line []byte +} + +func NewLineReader(line []byte) *LineReader { + return &LineReader{line: line} +} + +func (r *LineReader) Read(buf []byte) (int, error) { + return copy(buf, r.line), nil +} + +//------------------------------------------------------------------------------ + type RequestTest struct{} var _ = Suite(&RequestTest{}) @@ -24,12 +39,11 @@ func (t *RequestTest) TearDownTest(c *C) {} func (t *RequestTest) BenchmarkStatusReq(c *C) { c.StopTimer() - rd := bufreader.NewSizedReader(1024) - rd.Set([]byte("+OK\r\n")) + lineReader := NewLineReader([]byte("+OK\r\n")) + rd := bufio.NewReaderSize(lineReader, 1024) req := redis.NewStatusReq() for i := 0; i < 10; i++ { - rd.ResetPos() vI, err := req.ParseReply(rd) c.Check(err, IsNil) c.Check(vI, Equals, "OK") @@ -42,7 +56,6 @@ func (t *RequestTest) BenchmarkStatusReq(c *C) { c.StartTimer() for i := 0; i < c.N; i++ { - rd.ResetPos() v, _ := req.ParseReply(rd) req.SetVal(v) req.Err() diff --git a/utils.go b/utils.go new file mode 100644 index 00000000..a2459342 --- /dev/null +++ b/utils.go @@ -0,0 +1,16 @@ +package redis + +import ( + "bufio" +) + +func readLine(rd *bufio.Reader) ([]byte, error) { + line, isPrefix, err := rd.ReadLine() + if err != nil { + return line, err + } + if isPrefix { + return line, ErrReaderTooSmall + } + return line, nil +}