diff --git a/cluster.go b/cluster.go index 1ae6082..217cadd 100644 --- a/cluster.go +++ b/cluster.go @@ -7,6 +7,7 @@ import ( "time" "gopkg.in/redis.v4/internal" + "gopkg.in/redis.v4/internal/errors" "gopkg.in/redis.v4/internal/hashtag" "gopkg.in/redis.v4/internal/pool" ) @@ -291,14 +292,14 @@ func (c *ClusterClient) Process(cmd Cmder) error { } // On network errors try random node. - if shouldRetry(err) { + if errors.IsRetryable(err) { node, err = c.randomNode() continue } var moved bool var addr string - moved, ask, addr = isMovedError(err) + moved, ask, addr = errors.IsMoved(err) if moved || ask { master, _ := c.slotMasterNode(slot) if moved && (master == nil || master.Addr != addr) { @@ -549,11 +550,11 @@ func (c *ClusterClient) execClusterCmds( if err == nil { continue } - if isNetworkError(err) { + if errors.IsNetwork(err) { cmd.reset() failedCmds[nil] = append(failedCmds[nil], cmds[i:]...) break - } else if moved, ask, addr := isMovedError(err); moved { + } else if moved, ask, addr := errors.IsMoved(err); moved { c.lazyReloadSlots() cmd.reset() node, err := c.nodeByAddr(addr) diff --git a/cluster_test.go b/cluster_test.go index 6889d4e..17d5113 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -393,6 +393,7 @@ var _ = Describe("ClusterClient", func() { for i := 0; i < 100; i++ { wg.Add(1) go func() { + defer GinkgoRecover() defer wg.Done() err := incr("key") diff --git a/command.go b/command.go index c562582..9cb7058 100644 --- a/command.go +++ b/command.go @@ -8,6 +8,7 @@ import ( "time" "gopkg.in/redis.v4/internal/pool" + "gopkg.in/redis.v4/internal/proto" ) var ( @@ -55,16 +56,14 @@ func resetCmds(cmds []Cmder) { } func writeCmd(cn *pool.Conn, cmds ...Cmder) error { - cn.Buf = cn.Buf[:0] + cn.Wb.Reset() for _, cmd := range cmds { - var err error - cn.Buf, err = appendArgs(cn.Buf, cmd.args()) - if err != nil { + if err := cn.Wb.Append(cmd.args()); err != nil { return err } } - _, err := cn.Write(cn.Buf) + _, err := cn.Write(cn.Wb.Bytes()) return err } @@ -166,7 +165,7 @@ func (cmd *Cmd) String() string { } func (cmd *Cmd) readReply(cn *pool.Conn) error { - val, err := readReply(cn, sliceParser) + val, err := cn.Rd.ReadReply(sliceParser) if err != nil { cmd.err = err return cmd.err @@ -211,7 +210,7 @@ func (cmd *SliceCmd) String() string { } func (cmd *SliceCmd) readReply(cn *pool.Conn) error { - v, err := readArrayReply(cn, sliceParser) + v, err := cn.Rd.ReadArrayReply(sliceParser) if err != nil { cmd.err = err return err @@ -251,7 +250,7 @@ func (cmd *StatusCmd) String() string { } func (cmd *StatusCmd) readReply(cn *pool.Conn) error { - cmd.val, cmd.err = readStringReply(cn) + cmd.val, cmd.err = cn.Rd.ReadStringReply() return cmd.err } @@ -286,7 +285,7 @@ func (cmd *IntCmd) String() string { } func (cmd *IntCmd) readReply(cn *pool.Conn) error { - cmd.val, cmd.err = readIntReply(cn) + cmd.val, cmd.err = cn.Rd.ReadIntReply() return cmd.err } @@ -325,7 +324,7 @@ func (cmd *DurationCmd) String() string { } func (cmd *DurationCmd) readReply(cn *pool.Conn) error { - n, err := readIntReply(cn) + n, err := cn.Rd.ReadIntReply() if err != nil { cmd.err = err return err @@ -367,7 +366,7 @@ func (cmd *BoolCmd) String() string { var ok = []byte("OK") func (cmd *BoolCmd) readReply(cn *pool.Conn) error { - v, err := readReply(cn, nil) + v, err := cn.Rd.ReadReply(nil) // `SET key value NX` returns nil when key already exists. But // `SETNX key value` returns bool (0/1). So convert nil to bool. // TODO: is this okay? @@ -410,7 +409,7 @@ func (cmd *StringCmd) reset() { } func (cmd *StringCmd) Val() string { - return bytesToString(cmd.val) + return string(cmd.val) } func (cmd *StringCmd) Result() (string, error) { @@ -446,7 +445,7 @@ func (cmd *StringCmd) Scan(val interface{}) error { if cmd.err != nil { return cmd.err } - return scan(cmd.val, val) + return proto.Scan(cmd.val, val) } func (cmd *StringCmd) String() string { @@ -454,7 +453,7 @@ func (cmd *StringCmd) String() string { } func (cmd *StringCmd) readReply(cn *pool.Conn) error { - b, err := readBytesReply(cn) + b, err := cn.Rd.ReadBytesReply() if err != nil { cmd.err = err return err @@ -498,7 +497,7 @@ func (cmd *FloatCmd) String() string { } func (cmd *FloatCmd) readReply(cn *pool.Conn) error { - cmd.val, cmd.err = readFloatReply(cn) + cmd.val, cmd.err = cn.Rd.ReadFloatReply() return cmd.err } @@ -533,7 +532,7 @@ func (cmd *StringSliceCmd) String() string { } func (cmd *StringSliceCmd) readReply(cn *pool.Conn) error { - v, err := readArrayReply(cn, stringSliceParser) + v, err := cn.Rd.ReadArrayReply(stringSliceParser) if err != nil { cmd.err = err return err @@ -573,7 +572,7 @@ func (cmd *BoolSliceCmd) String() string { } func (cmd *BoolSliceCmd) readReply(cn *pool.Conn) error { - v, err := readArrayReply(cn, boolSliceParser) + v, err := cn.Rd.ReadArrayReply(boolSliceParser) if err != nil { cmd.err = err return err @@ -613,7 +612,7 @@ func (cmd *StringStringMapCmd) String() string { } func (cmd *StringStringMapCmd) readReply(cn *pool.Conn) error { - v, err := readArrayReply(cn, stringStringMapParser) + v, err := cn.Rd.ReadArrayReply(stringStringMapParser) if err != nil { cmd.err = err return err @@ -653,7 +652,7 @@ func (cmd *StringIntMapCmd) reset() { } func (cmd *StringIntMapCmd) readReply(cn *pool.Conn) error { - v, err := readArrayReply(cn, stringIntMapParser) + v, err := cn.Rd.ReadArrayReply(stringIntMapParser) if err != nil { cmd.err = err return err @@ -693,7 +692,7 @@ func (cmd *ZSliceCmd) String() string { } func (cmd *ZSliceCmd) readReply(cn *pool.Conn) error { - v, err := readArrayReply(cn, zSliceParser) + v, err := cn.Rd.ReadArrayReply(zSliceParser) if err != nil { cmd.err = err return err @@ -737,7 +736,7 @@ func (cmd *ScanCmd) String() string { } func (cmd *ScanCmd) readReply(cn *pool.Conn) error { - page, cursor, err := readScanReply(cn) + page, cursor, err := cn.Rd.ReadScanReply() if err != nil { cmd.err = err return cmd.err @@ -789,7 +788,7 @@ func (cmd *ClusterSlotsCmd) reset() { } func (cmd *ClusterSlotsCmd) readReply(cn *pool.Conn) error { - v, err := readArrayReply(cn, clusterSlotsParser) + v, err := cn.Rd.ReadArrayReply(clusterSlotsParser) if err != nil { cmd.err = err return err @@ -874,7 +873,7 @@ func (cmd *GeoLocationCmd) String() string { } func (cmd *GeoLocationCmd) readReply(cn *pool.Conn) error { - reply, err := readArrayReply(cn, newGeoLocationSliceParser(cmd.q)) + reply, err := cn.Rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) if err != nil { cmd.err = err return err @@ -924,7 +923,7 @@ func (cmd *CommandsInfoCmd) reset() { } func (cmd *CommandsInfoCmd) readReply(cn *pool.Conn) error { - v, err := readArrayReply(cn, commandInfoSliceParser) + v, err := cn.Rd.ReadArrayReply(commandInfoSliceParser) if err != nil { cmd.err = err return err diff --git a/commands.go b/commands.go index f84311c..412adae 100644 --- a/commands.go +++ b/commands.go @@ -6,20 +6,9 @@ import ( "time" "gopkg.in/redis.v4/internal" + "gopkg.in/redis.v4/internal/errors" ) -func formatInt(i int64) string { - return strconv.FormatInt(i, 10) -} - -func formatUint(i uint64) string { - return strconv.FormatUint(i, 10) -} - -func formatFloat(f float64) string { - return strconv.FormatFloat(f, 'f', -1, 64) -} - func readTimeout(timeout time.Duration) time.Duration { if timeout == 0 { return 0 @@ -38,7 +27,7 @@ func formatMs(dur time.Duration) string { dur, time.Millisecond, ) } - return formatInt(int64(dur / time.Millisecond)) + return strconv.FormatInt(int64(dur/time.Millisecond), 10) } func formatSec(dur time.Duration) string { @@ -48,7 +37,7 @@ func formatSec(dur time.Duration) string { dur, time.Second, ) } - return formatInt(int64(dur / time.Second)) + return strconv.FormatInt(int64(dur/time.Second), 10) } type cmdable struct { @@ -1515,7 +1504,7 @@ func (c *cmdable) shutdown(modifier string) *StatusCmd { } } else { // Server did not quit. String reply contains the reason. - cmd.err = errorf(cmd.val) + cmd.err = errors.RedisError(cmd.val) cmd.val = "" } return cmd diff --git a/error.go b/error.go deleted file mode 100644 index f0b27a8..0000000 --- a/error.go +++ /dev/null @@ -1,81 +0,0 @@ -package redis - -import ( - "fmt" - "io" - "net" - "strings" -) - -// Redis nil reply, .e.g. when key does not exist. -var Nil = errorf("redis: nil") - -// Redis transaction failed. -var TxFailedErr = errorf("redis: transaction failed") - -type redisError struct { - s string -} - -func errorf(s string, args ...interface{}) redisError { - return redisError{s: fmt.Sprintf(s, args...)} -} - -func (err redisError) Error() string { - return err.s -} - -func isInternalError(err error) bool { - _, ok := err.(redisError) - return ok -} - -func isNetworkError(err error) bool { - if err == io.EOF { - return true - } - _, ok := err.(net.Error) - return ok -} - -func isBadConn(err error, allowTimeout bool) bool { - if err == nil { - return false - } - if isInternalError(err) { - return false - } - if allowTimeout { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - return false - } - } - return true -} - -func isMovedError(err error) (moved bool, ask bool, addr string) { - if _, ok := err.(redisError); !ok { - return - } - - s := err.Error() - if strings.HasPrefix(s, "MOVED ") { - moved = true - } else if strings.HasPrefix(s, "ASK ") { - ask = true - } else { - return - } - - ind := strings.LastIndexByte(s, ' ') - if ind == -1 { - return false, false, "" - } - addr = s[ind+1:] - return -} - -// shouldRetry reports whether failed command should be retried. -func shouldRetry(err error) bool { - return isNetworkError(err) -} diff --git a/internal/errors/errors.go b/internal/errors/errors.go new file mode 100644 index 0000000..6d664dd --- /dev/null +++ b/internal/errors/errors.go @@ -0,0 +1,67 @@ +package errors + +import ( + "io" + "net" + "strings" +) + +const Nil = RedisError("redis: nil") + +type RedisError string + +func (e RedisError) Error() string { return string(e) } + +func IsRetryable(err error) bool { + return IsNetwork(err) +} + +func IsInternal(err error) bool { + _, ok := err.(RedisError) + return ok +} + +func IsNetwork(err error) bool { + if err == io.EOF { + return true + } + _, ok := err.(net.Error) + return ok +} + +func IsBadConn(err error, allowTimeout bool) bool { + if err == nil { + return false + } + if IsInternal(err) { + return false + } + if allowTimeout { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return false + } + } + return true +} + +func IsMoved(err error) (moved bool, ask bool, addr string) { + if !IsInternal(err) { + return + } + + s := err.Error() + if strings.HasPrefix(s, "MOVED ") { + moved = true + } else if strings.HasPrefix(s, "ASK ") { + ask = true + } else { + return + } + + ind := strings.LastIndexByte(s, ' ') + if ind == -1 { + return false, false, "" + } + addr = s[ind+1:] + return +} diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 497fd4e..bb0922f 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -1,10 +1,10 @@ package pool import ( - "bufio" - "io" "net" "time" + + "gopkg.in/redis.v4/internal/proto" ) const defaultBufSize = 4096 @@ -13,8 +13,8 @@ var noDeadline = time.Time{} type Conn struct { NetConn net.Conn - Rd *bufio.Reader - Buf []byte + Rd *proto.Reader + Wb *proto.WriteBuffer Inited bool UsedAt time.Time @@ -26,11 +26,11 @@ type Conn struct { func NewConn(netConn net.Conn) *Conn { cn := &Conn{ NetConn: netConn, - Buf: make([]byte, defaultBufSize), + Wb: proto.NewWriteBuffer(), UsedAt: time.Now(), } - cn.Rd = bufio.NewReader(cn) + cn.Rd = proto.NewReader(cn) return cn } @@ -62,17 +62,6 @@ func (cn *Conn) RemoteAddr() net.Addr { return cn.NetConn.RemoteAddr() } -func (cn *Conn) ReadN(n int) ([]byte, error) { - if d := n - cap(cn.Buf); d > 0 { - cn.Buf = cn.Buf[:cap(cn.Buf)] - cn.Buf = append(cn.Buf, make([]byte, d)...) - } else { - cn.Buf = cn.Buf[:n] - } - _, err := io.ReadFull(cn.Rd, cn.Buf) - return cn.Buf, err -} - func (cn *Conn) Close() error { return cn.NetConn.Close() } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 330767c..da1e381 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -205,9 +205,8 @@ func (p *ConnPool) Get() (*Conn, error) { } func (p *ConnPool) Put(cn *Conn) error { - if cn.Rd.Buffered() != 0 { - b, _ := cn.Rd.Peek(cn.Rd.Buffered()) - err := fmt.Errorf("connection has unread data: %q", b) + if data := cn.Rd.PeekBuffered(); data != nil { + err := fmt.Errorf("connection has unread data: %q", data) internal.Logf(err.Error()) return p.Remove(cn, err) } diff --git a/internal/proto/proto.go b/internal/proto/proto.go new file mode 100644 index 0000000..c63caaa --- /dev/null +++ b/internal/proto/proto.go @@ -0,0 +1,122 @@ +package proto + +import ( + "encoding" + "fmt" + "strconv" + + "gopkg.in/redis.v4/internal/errors" +) + +const ( + ErrorReply = '-' + StatusReply = '+' + IntReply = ':' + StringReply = '$' + ArrayReply = '*' +) + +const defaultBufSize = 4096 + +var errScanNil = errors.RedisError("redis: Scan(nil)") + +func Scan(b []byte, val interface{}) error { + switch v := val.(type) { + case nil: + return errScanNil + case *string: + *v = string(b) + return nil + case *[]byte: + *v = b + return nil + case *int: + var err error + *v, err = strconv.Atoi(string(b)) + return err + case *int8: + n, err := strconv.ParseInt(string(b), 10, 8) + if err != nil { + return err + } + *v = int8(n) + return nil + case *int16: + n, err := strconv.ParseInt(string(b), 10, 16) + if err != nil { + return err + } + *v = int16(n) + return nil + case *int32: + n, err := strconv.ParseInt(string(b), 10, 32) + if err != nil { + return err + } + *v = int32(n) + return nil + case *int64: + n, err := strconv.ParseInt(string(b), 10, 64) + if err != nil { + return err + } + *v = n + return nil + case *uint: + n, err := strconv.ParseUint(string(b), 10, 64) + if err != nil { + return err + } + *v = uint(n) + return nil + case *uint8: + n, err := strconv.ParseUint(string(b), 10, 8) + if err != nil { + return err + } + *v = uint8(n) + return nil + case *uint16: + n, err := strconv.ParseUint(string(b), 10, 16) + if err != nil { + return err + } + *v = uint16(n) + return nil + case *uint32: + n, err := strconv.ParseUint(string(b), 10, 32) + if err != nil { + return err + } + *v = uint32(n) + return nil + case *uint64: + n, err := strconv.ParseUint(string(b), 10, 64) + if err != nil { + return err + } + *v = n + return nil + case *float32: + n, err := strconv.ParseFloat(string(b), 32) + if err != nil { + return err + } + *v = float32(n) + return err + case *float64: + var err error + *v, err = strconv.ParseFloat(string(b), 64) + return err + case *bool: + *v = len(b) == 1 && b[0] == '1' + return nil + default: + if bu, ok := val.(encoding.BinaryUnmarshaler); ok { + return bu.UnmarshalBinary(b) + } + err := fmt.Errorf( + "redis: can't unmarshal %T (consider implementing BinaryUnmarshaler)", val) + return err + } +} diff --git a/internal/proto/proto_test.go b/internal/proto/proto_test.go new file mode 100644 index 0000000..c9a820e --- /dev/null +++ b/internal/proto/proto_test.go @@ -0,0 +1,13 @@ +package proto_test + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestGinkgoSuite(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "proto") +} diff --git a/internal/proto/reader.go b/internal/proto/reader.go new file mode 100644 index 0000000..9d2f51a --- /dev/null +++ b/internal/proto/reader.go @@ -0,0 +1,264 @@ +package proto + +import ( + "bufio" + "errors" + "fmt" + "io" + "strconv" + + ierrors "gopkg.in/redis.v4/internal/errors" +) + +type MultiBulkParse func(*Reader, int64) (interface{}, error) + +var errEmptyReply = errors.New("redis: reply is empty") + +type Reader struct { + src *bufio.Reader + buf []byte +} + +func NewReader(rd io.Reader) *Reader { + return &Reader{ + src: bufio.NewReader(rd), + buf: make([]byte, 0, defaultBufSize), + } +} + +func (p *Reader) PeekBuffered() []byte { + if n := p.src.Buffered(); n != 0 { + b, _ := p.src.Peek(n) + return b + } + return nil +} + +func (p *Reader) ReadN(n int) ([]byte, error) { + // grow internal buffer, if necessary + if d := n - cap(p.buf); d > 0 { + p.buf = p.buf[:cap(p.buf)] + p.buf = append(p.buf, make([]byte, d)...) + } else { + p.buf = p.buf[:n] + } + + _, err := io.ReadFull(p.src, p.buf) + return p.buf, err +} + +func (p *Reader) ReadLine() ([]byte, error) { + line, isPrefix, err := p.src.ReadLine() + if err != nil { + return nil, err + } + if isPrefix { + return nil, bufio.ErrBufferFull + } + if len(line) == 0 { + return nil, errEmptyReply + } + if isNilReply(line) { + return nil, ierrors.Nil + } + return line, nil +} + +func (p *Reader) ReadReply(m MultiBulkParse) (interface{}, error) { + line, err := p.ReadLine() + if err != nil { + return nil, err + } + + switch line[0] { + case ErrorReply: + return nil, parseErrorValue(line) + case StatusReply: + return parseStatusValue(line) + case IntReply: + return parseIntValue(line) + case StringReply: + return p.parseBytesValue(line) + case ArrayReply: + n, err := parseArrayLen(line) + if err != nil { + return nil, err + } + return m(p, n) + } + return nil, fmt.Errorf("redis: can't parse %.100q", line) +} + +func (p *Reader) ReadIntReply() (int64, error) { + line, err := p.ReadLine() + if err != nil { + return 0, err + } + switch line[0] { + case ErrorReply: + return 0, parseErrorValue(line) + case IntReply: + return parseIntValue(line) + default: + return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line) + } +} + +func (p *Reader) ReadBytesReply() ([]byte, error) { + line, err := p.ReadLine() + if err != nil { + return nil, err + } + switch line[0] { + case ErrorReply: + return nil, parseErrorValue(line) + case StringReply: + return p.parseBytesValue(line) + case StatusReply: + return parseStatusValue(line) + default: + return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) + } +} + +func (p *Reader) ReadStringReply() (string, error) { + b, err := p.ReadBytesReply() + if err != nil { + return "", err + } + return string(b), nil +} + +func (p *Reader) ReadFloatReply() (float64, error) { + s, err := p.ReadStringReply() + if err != nil { + return 0, err + } + return strconv.ParseFloat(s, 64) +} + +func (p *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { + line, err := p.ReadLine() + if err != nil { + return nil, err + } + switch line[0] { + case ErrorReply: + return nil, parseErrorValue(line) + case ArrayReply: + n, err := parseArrayLen(line) + if err != nil { + return nil, err + } + return m(p, n) + default: + return nil, fmt.Errorf("redis: can't parse array reply: %.100q", line) + } +} + +func (p *Reader) ReadArrayLen() (int64, error) { + line, err := p.ReadLine() + if err != nil { + return 0, err + } + switch line[0] { + case ErrorReply: + return 0, parseErrorValue(line) + case ArrayReply: + return parseArrayLen(line) + default: + return 0, fmt.Errorf("redis: can't parse array reply: %.100q", line) + } +} + +func (p *Reader) ReadScanReply() ([]string, uint64, error) { + n, err := p.ReadArrayLen() + if err != nil { + return nil, 0, err + } + if n != 2 { + return nil, 0, fmt.Errorf("redis: got %d elements in scan reply, expected 2", n) + } + + s, err := p.ReadStringReply() + if err != nil { + return nil, 0, err + } + + cursor, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return nil, 0, err + } + + n, err = p.ReadArrayLen() + if err != nil { + return nil, 0, err + } + + keys := make([]string, n) + for i := int64(0); i < n; i++ { + key, err := p.ReadStringReply() + if err != nil { + return nil, 0, err + } + keys[i] = key + } + + return keys, cursor, err +} + +func (p *Reader) parseBytesValue(line []byte) ([]byte, error) { + if isNilReply(line) { + return nil, ierrors.Nil + } + + replyLen, err := strconv.Atoi(string(line[1:])) + if err != nil { + return nil, err + } + + b, err := p.ReadN(replyLen + 2) + if err != nil { + return nil, err + } + return b[:replyLen], nil +} + +// -------------------------------------------------------------------- + +func formatInt(n int64) string { + return strconv.FormatInt(n, 10) +} + +func formatUint(u uint64) string { + return strconv.FormatUint(u, 10) +} + +func formatFloat(f float64) string { + return strconv.FormatFloat(f, 'f', -1, 64) +} + +func isNilReply(b []byte) bool { + return len(b) == 3 && + (b[0] == StringReply || b[0] == ArrayReply) && + b[1] == '-' && b[2] == '1' +} + +func parseErrorValue(line []byte) error { + return ierrors.RedisError(string(line[1:])) +} + +func parseStatusValue(line []byte) ([]byte, error) { + return line[1:], nil +} + +func parseIntValue(line []byte) (int64, error) { + return strconv.ParseInt(string(line[1:]), 10, 64) +} + +func parseArrayLen(line []byte) (int64, error) { + if isNilReply(line) { + return 0, ierrors.Nil + } + return parseIntValue(line) +} diff --git a/internal/proto/reader_test.go b/internal/proto/reader_test.go new file mode 100644 index 0000000..5169262 --- /dev/null +++ b/internal/proto/reader_test.go @@ -0,0 +1,86 @@ +package proto_test + +import ( + "bytes" + "strings" + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "gopkg.in/redis.v4/internal/proto" +) + +var _ = Describe("Reader", func() { + + It("should read n bytes", func() { + data, err := proto.NewReader(strings.NewReader("ABCDEFGHIJKLMNO")).ReadN(10) + Expect(err).NotTo(HaveOccurred()) + Expect(len(data)).To(Equal(10)) + Expect(string(data)).To(Equal("ABCDEFGHIJ")) + + data, err = proto.NewReader(strings.NewReader(strings.Repeat("x", 8192))).ReadN(6000) + Expect(err).NotTo(HaveOccurred()) + Expect(len(data)).To(Equal(6000)) + }) + + It("should read lines", func() { + p := proto.NewReader(strings.NewReader("$5\r\nhello\r\n")) + + data, err := p.ReadLine() + Expect(err).NotTo(HaveOccurred()) + Expect(string(data)).To(Equal("$5")) + + data, err = p.ReadLine() + Expect(err).NotTo(HaveOccurred()) + Expect(string(data)).To(Equal("hello")) + }) + +}) + +func BenchmarkReader_ParseReply_Status(b *testing.B) { + benchmarkParseReply(b, "+OK\r\n", nil, false) +} + +func BenchmarkReader_ParseReply_Int(b *testing.B) { + benchmarkParseReply(b, ":1\r\n", nil, false) +} + +func BenchmarkReader_ParseReply_Error(b *testing.B) { + benchmarkParseReply(b, "-Error message\r\n", nil, true) +} + +func BenchmarkReader_ParseReply_String(b *testing.B) { + benchmarkParseReply(b, "$5\r\nhello\r\n", nil, false) +} + +func BenchmarkReader_ParseReply_Slice(b *testing.B) { + benchmarkParseReply(b, "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", multiBulkParse, false) +} + +func benchmarkParseReply(b *testing.B, reply string, m proto.MultiBulkParse, wanterr bool) { + buf := &bytes.Buffer{} + for i := 0; i < b.N; i++ { + buf.WriteString(reply) + } + p := proto.NewReader(buf) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := p.ReadReply(m) + if !wanterr && err != nil { + b.Fatal(err) + } + } +} + +func multiBulkParse(p *proto.Reader, n int64) (interface{}, error) { + vv := make([]interface{}, 0, n) + for i := int64(0); i < n; i++ { + v, err := p.ReadReply(multiBulkParse) + if err != nil { + return nil, err + } + vv = append(vv, v) + } + return vv, nil +} diff --git a/internal/proto/writebuffer.go b/internal/proto/writebuffer.go new file mode 100644 index 0000000..8164f7f --- /dev/null +++ b/internal/proto/writebuffer.go @@ -0,0 +1,101 @@ +package proto + +import ( + "encoding" + "fmt" + "strconv" +) + +type WriteBuffer struct{ b []byte } + +func NewWriteBuffer() *WriteBuffer { + return &WriteBuffer{ + b: make([]byte, 0, defaultBufSize), + } +} + +func (w *WriteBuffer) Len() int { return len(w.b) } +func (w *WriteBuffer) Bytes() []byte { return w.b } +func (w *WriteBuffer) Reset() { w.b = w.b[:0] } + +func (w *WriteBuffer) Append(args []interface{}) error { + w.b = append(w.b, ArrayReply) + w.b = strconv.AppendUint(w.b, uint64(len(args)), 10) + w.b = append(w.b, '\r', '\n') + + for _, arg := range args { + if err := w.append(arg); err != nil { + return err + } + } + return nil +} + +func (w *WriteBuffer) append(val interface{}) error { + switch v := val.(type) { + case nil: + w.AppendString("") + case string: + w.AppendString(v) + case []byte: + w.AppendBytes(v) + case int: + w.AppendString(formatInt(int64(v))) + case int8: + w.AppendString(formatInt(int64(v))) + case int16: + w.AppendString(formatInt(int64(v))) + case int32: + w.AppendString(formatInt(int64(v))) + case int64: + w.AppendString(formatInt(v)) + case uint: + w.AppendString(formatUint(uint64(v))) + case uint8: + w.AppendString(formatUint(uint64(v))) + case uint16: + w.AppendString(formatUint(uint64(v))) + case uint32: + w.AppendString(formatUint(uint64(v))) + case uint64: + w.AppendString(formatUint(v)) + case float32: + w.AppendString(formatFloat(float64(v))) + case float64: + w.AppendString(formatFloat(v)) + case bool: + if v { + w.AppendString("1") + } else { + w.AppendString("0") + } + default: + if bm, ok := val.(encoding.BinaryMarshaler); ok { + bb, err := bm.MarshalBinary() + if err != nil { + return err + } + w.AppendBytes(bb) + } else { + return fmt.Errorf( + "redis: can't marshal %T (consider implementing encoding.BinaryMarshaler)", val) + } + } + return nil +} + +func (w *WriteBuffer) AppendString(s string) { + w.b = append(w.b, StringReply) + w.b = strconv.AppendUint(w.b, uint64(len(s)), 10) + w.b = append(w.b, '\r', '\n') + w.b = append(w.b, s...) + w.b = append(w.b, '\r', '\n') +} + +func (w *WriteBuffer) AppendBytes(p []byte) { + w.b = append(w.b, StringReply) + w.b = strconv.AppendUint(w.b, uint64(len(p)), 10) + w.b = append(w.b, '\r', '\n') + w.b = append(w.b, p...) + w.b = append(w.b, '\r', '\n') +} diff --git a/internal/proto/writebuffer_test.go b/internal/proto/writebuffer_test.go new file mode 100644 index 0000000..6316ded --- /dev/null +++ b/internal/proto/writebuffer_test.go @@ -0,0 +1,62 @@ +package proto_test + +import ( + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "gopkg.in/redis.v4/internal/proto" +) + +var _ = Describe("WriteBuffer", func() { + var buf *proto.WriteBuffer + + BeforeEach(func() { + buf = proto.NewWriteBuffer() + }) + + It("should reset", func() { + buf.AppendString("string") + Expect(buf.Len()).To(Equal(12)) + buf.Reset() + Expect(buf.Len()).To(Equal(0)) + }) + + It("should append args", func() { + err := buf.Append([]interface{}{ + "string", + 12, + 34.56, + []byte{'b', 'y', 't', 'e', 's'}, + true, + nil, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(buf.Bytes()).To(Equal([]byte("*6\r\n" + + "$6\r\nstring\r\n" + + "$2\r\n12\r\n" + + "$5\r\n34.56\r\n" + + "$5\r\nbytes\r\n" + + "$1\r\n1\r\n" + + "$0\r\n" + + "\r\n"))) + }) + + It("should append marshalable args", func() { + err := buf.Append([]interface{}{time.Unix(1414141414, 0)}) + Expect(err).NotTo(HaveOccurred()) + Expect(buf.Len()).To(Equal(26)) + }) + +}) + +func BenchmarkWriteBuffer_Append(b *testing.B) { + buf := proto.NewWriteBuffer() + args := []interface{}{"hello", "world", "foo", "bar"} + + for i := 0; i < b.N; i++ { + buf.Append(args) + buf.Reset() + } +} diff --git a/main_test.go b/main_test.go index b0277e8..5208903 100644 --- a/main_test.go +++ b/main_test.go @@ -136,6 +136,8 @@ func eventually(fn func() error, timeout time.Duration) error { done := make(chan struct{}) go func() { + defer GinkgoRecover() + for atomic.LoadInt32(&exit) == 0 { err := fn() if err == nil { diff --git a/parser.go b/parser.go index f6e5ca6..95a4297 100644 --- a/parser.go +++ b/parser.go @@ -1,446 +1,18 @@ package redis import ( - "bufio" - "errors" "fmt" "net" "strconv" - "gopkg.in/redis.v4/internal/pool" + "gopkg.in/redis.v4/internal/proto" ) -const ( - errorReply = '-' - statusReply = '+' - intReply = ':' - stringReply = '$' - arrayReply = '*' -) - -type multiBulkParser func(cn *pool.Conn, n int64) (interface{}, error) - -var errEmptyReply = errors.New("redis: reply is empty") - -//------------------------------------------------------------------------------ - -// Copy of encoding.BinaryMarshaler. -type binaryMarshaler interface { - MarshalBinary() (data []byte, err error) -} - -// Copy of encoding.BinaryUnmarshaler. -type binaryUnmarshaler interface { - UnmarshalBinary(data []byte) error -} - -func appendString(b []byte, s string) []byte { - b = append(b, '$') - b = strconv.AppendUint(b, uint64(len(s)), 10) - b = append(b, '\r', '\n') - b = append(b, s...) - b = append(b, '\r', '\n') - return b -} - -func appendBytes(b, bb []byte) []byte { - b = append(b, '$') - b = strconv.AppendUint(b, uint64(len(bb)), 10) - b = append(b, '\r', '\n') - b = append(b, bb...) - b = append(b, '\r', '\n') - return b -} - -func appendArg(b []byte, val interface{}) ([]byte, error) { - switch v := val.(type) { - case nil: - b = appendString(b, "") - case string: - b = appendString(b, v) - case []byte: - b = appendBytes(b, v) - case int: - b = appendString(b, formatInt(int64(v))) - case int8: - b = appendString(b, formatInt(int64(v))) - case int16: - b = appendString(b, formatInt(int64(v))) - case int32: - b = appendString(b, formatInt(int64(v))) - case int64: - b = appendString(b, formatInt(v)) - case uint: - b = appendString(b, formatUint(uint64(v))) - case uint8: - b = appendString(b, formatUint(uint64(v))) - case uint16: - b = appendString(b, formatUint(uint64(v))) - case uint32: - b = appendString(b, formatUint(uint64(v))) - case uint64: - b = appendString(b, formatUint(v)) - case float32: - b = appendString(b, formatFloat(float64(v))) - case float64: - b = appendString(b, formatFloat(v)) - case bool: - if v { - b = appendString(b, "1") - } else { - b = appendString(b, "0") - } - default: - if bm, ok := val.(binaryMarshaler); ok { - bb, err := bm.MarshalBinary() - if err != nil { - return nil, err - } - b = appendBytes(b, bb) - } else { - err := fmt.Errorf( - "redis: can't marshal %T (consider implementing BinaryMarshaler)", val) - return nil, err - } - } - return b, nil -} - -func appendArgs(b []byte, args []interface{}) ([]byte, error) { - b = append(b, arrayReply) - b = strconv.AppendUint(b, uint64(len(args)), 10) - b = append(b, '\r', '\n') - for _, arg := range args { - var err error - b, err = appendArg(b, arg) - if err != nil { - return nil, err - } - } - return b, nil -} - -func scan(b []byte, val interface{}) error { - switch v := val.(type) { - case nil: - return errorf("redis: Scan(nil)") - case *string: - *v = bytesToString(b) - return nil - case *[]byte: - *v = b - return nil - case *int: - var err error - *v, err = strconv.Atoi(bytesToString(b)) - return err - case *int8: - n, err := strconv.ParseInt(bytesToString(b), 10, 8) - if err != nil { - return err - } - *v = int8(n) - return nil - case *int16: - n, err := strconv.ParseInt(bytesToString(b), 10, 16) - if err != nil { - return err - } - *v = int16(n) - return nil - case *int32: - n, err := strconv.ParseInt(bytesToString(b), 10, 16) - if err != nil { - return err - } - *v = int32(n) - return nil - case *int64: - n, err := strconv.ParseInt(bytesToString(b), 10, 64) - if err != nil { - return err - } - *v = n - return nil - case *uint: - n, err := strconv.ParseUint(bytesToString(b), 10, 64) - if err != nil { - return err - } - *v = uint(n) - return nil - case *uint8: - n, err := strconv.ParseUint(bytesToString(b), 10, 8) - if err != nil { - return err - } - *v = uint8(n) - return nil - case *uint16: - n, err := strconv.ParseUint(bytesToString(b), 10, 16) - if err != nil { - return err - } - *v = uint16(n) - return nil - case *uint32: - n, err := strconv.ParseUint(bytesToString(b), 10, 32) - if err != nil { - return err - } - *v = uint32(n) - return nil - case *uint64: - n, err := strconv.ParseUint(bytesToString(b), 10, 64) - if err != nil { - return err - } - *v = n - return nil - case *float32: - n, err := strconv.ParseFloat(bytesToString(b), 32) - if err != nil { - return err - } - *v = float32(n) - return err - case *float64: - var err error - *v, err = strconv.ParseFloat(bytesToString(b), 64) - return err - case *bool: - *v = len(b) == 1 && b[0] == '1' - return nil - default: - if bu, ok := val.(binaryUnmarshaler); ok { - return bu.UnmarshalBinary(b) - } - err := fmt.Errorf( - "redis: can't unmarshal %T (consider implementing BinaryUnmarshaler)", val) - return err - } -} - -//------------------------------------------------------------------------------ - -func readLine(cn *pool.Conn) ([]byte, error) { - line, isPrefix, err := cn.Rd.ReadLine() - if err != nil { - return nil, err - } - if isPrefix { - return nil, bufio.ErrBufferFull - } - if len(line) == 0 { - return nil, errEmptyReply - } - if isNilReply(line) { - return nil, Nil - } - return line, nil -} - -func isNilReply(b []byte) bool { - return len(b) == 3 && - (b[0] == stringReply || b[0] == arrayReply) && - b[1] == '-' && b[2] == '1' -} - -//------------------------------------------------------------------------------ - -func parseErrorReply(cn *pool.Conn, line []byte) error { - return errorf(string(line[1:])) -} - -func parseStatusReply(cn *pool.Conn, line []byte) ([]byte, error) { - return line[1:], nil -} - -func parseIntReply(cn *pool.Conn, line []byte) (int64, error) { - n, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) - if err != nil { - return 0, err - } - return n, nil -} - -func readIntReply(cn *pool.Conn) (int64, error) { - line, err := readLine(cn) - if err != nil { - return 0, err - } - switch line[0] { - case errorReply: - return 0, parseErrorReply(cn, line) - case intReply: - return parseIntReply(cn, line) - default: - return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line) - } -} - -func parseBytesReply(cn *pool.Conn, line []byte) ([]byte, error) { - if isNilReply(line) { - return nil, Nil - } - - replyLen, err := strconv.Atoi(bytesToString(line[1:])) - if err != nil { - return nil, err - } - - b, err := cn.ReadN(replyLen + 2) - if err != nil { - return nil, err - } - - return b[:replyLen], nil -} - -func readBytesReply(cn *pool.Conn) ([]byte, error) { - line, err := readLine(cn) - if err != nil { - return nil, err - } - switch line[0] { - case errorReply: - return nil, parseErrorReply(cn, line) - case stringReply: - return parseBytesReply(cn, line) - case statusReply: - return parseStatusReply(cn, line) - default: - return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) - } -} - -func readStringReply(cn *pool.Conn) (string, error) { - b, err := readBytesReply(cn) - if err != nil { - return "", err - } - return string(b), nil -} - -func readFloatReply(cn *pool.Conn) (float64, error) { - b, err := readBytesReply(cn) - if err != nil { - return 0, err - } - return strconv.ParseFloat(bytesToString(b), 64) -} - -func parseArrayHeader(cn *pool.Conn, line []byte) (int64, error) { - if isNilReply(line) { - return 0, Nil - } - - n, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) - if err != nil { - return 0, err - } - return n, nil -} - -func parseArrayReply(cn *pool.Conn, p multiBulkParser, line []byte) (interface{}, error) { - n, err := parseArrayHeader(cn, line) - if err != nil { - return nil, err - } - return p(cn, n) -} - -func readArrayHeader(cn *pool.Conn) (int64, error) { - line, err := readLine(cn) - if err != nil { - return 0, err - } - switch line[0] { - case errorReply: - return 0, parseErrorReply(cn, line) - case arrayReply: - return parseArrayHeader(cn, line) - default: - return 0, fmt.Errorf("redis: can't parse array reply: %.100q", line) - } -} - -func readArrayReply(cn *pool.Conn, p multiBulkParser) (interface{}, error) { - line, err := readLine(cn) - if err != nil { - return nil, err - } - switch line[0] { - case errorReply: - return nil, parseErrorReply(cn, line) - case arrayReply: - return parseArrayReply(cn, p, line) - default: - return nil, fmt.Errorf("redis: can't parse array reply: %.100q", line) - } -} - -func readReply(cn *pool.Conn, p multiBulkParser) (interface{}, error) { - line, err := readLine(cn) - if err != nil { - return nil, err - } - - switch line[0] { - case errorReply: - return nil, parseErrorReply(cn, line) - case statusReply: - return parseStatusReply(cn, line) - case intReply: - return parseIntReply(cn, line) - case stringReply: - return parseBytesReply(cn, line) - case arrayReply: - return parseArrayReply(cn, p, line) - } - return nil, fmt.Errorf("redis: can't parse %.100q", line) -} - -func readScanReply(cn *pool.Conn) ([]string, uint64, error) { - n, err := readArrayHeader(cn) - if err != nil { - return nil, 0, err - } - if n != 2 { - return nil, 0, fmt.Errorf("redis: got %d elements in scan reply, expected 2", n) - } - - b, err := readBytesReply(cn) - if err != nil { - return nil, 0, err - } - - cursor, err := strconv.ParseUint(bytesToString(b), 10, 64) - if err != nil { - return nil, 0, err - } - - n, err = readArrayHeader(cn) - if err != nil { - return nil, 0, err - } - - keys := make([]string, n) - for i := int64(0); i < n; i++ { - key, err := readStringReply(cn) - if err != nil { - return nil, 0, err - } - keys[i] = key - } - - return keys, cursor, err -} - -func sliceParser(cn *pool.Conn, n int64) (interface{}, error) { +// Implements proto.MultiBulkParse +func sliceParser(rd *proto.Reader, n int64) (interface{}, error) { vals := make([]interface{}, 0, n) for i := int64(0); i < n; i++ { - v, err := readReply(cn, sliceParser) + v, err := rd.ReadReply(sliceParser) if err == Nil { vals = append(vals, nil) } else if err != nil { @@ -457,10 +29,11 @@ func sliceParser(cn *pool.Conn, n int64) (interface{}, error) { return vals, nil } -func intSliceParser(cn *pool.Conn, n int64) (interface{}, error) { +// Implements proto.MultiBulkParse +func intSliceParser(rd *proto.Reader, n int64) (interface{}, error) { ints := make([]int64, 0, n) for i := int64(0); i < n; i++ { - n, err := readIntReply(cn) + n, err := rd.ReadIntReply() if err != nil { return nil, err } @@ -469,10 +42,11 @@ func intSliceParser(cn *pool.Conn, n int64) (interface{}, error) { return ints, nil } -func boolSliceParser(cn *pool.Conn, n int64) (interface{}, error) { +// Implements proto.MultiBulkParse +func boolSliceParser(rd *proto.Reader, n int64) (interface{}, error) { bools := make([]bool, 0, n) for i := int64(0); i < n; i++ { - n, err := readIntReply(cn) + n, err := rd.ReadIntReply() if err != nil { return nil, err } @@ -481,10 +55,11 @@ func boolSliceParser(cn *pool.Conn, n int64) (interface{}, error) { return bools, nil } -func stringSliceParser(cn *pool.Conn, n int64) (interface{}, error) { +// Implements proto.MultiBulkParse +func stringSliceParser(rd *proto.Reader, n int64) (interface{}, error) { ss := make([]string, 0, n) for i := int64(0); i < n; i++ { - s, err := readStringReply(cn) + s, err := rd.ReadStringReply() if err == Nil { ss = append(ss, "") } else if err != nil { @@ -496,10 +71,11 @@ func stringSliceParser(cn *pool.Conn, n int64) (interface{}, error) { return ss, nil } -func floatSliceParser(cn *pool.Conn, n int64) (interface{}, error) { +// Implements proto.MultiBulkParse +func floatSliceParser(rd *proto.Reader, n int64) (interface{}, error) { nn := make([]float64, 0, n) for i := int64(0); i < n; i++ { - n, err := readFloatReply(cn) + n, err := rd.ReadFloatReply() if err != nil { return nil, err } @@ -508,15 +84,16 @@ func floatSliceParser(cn *pool.Conn, n int64) (interface{}, error) { return nn, nil } -func stringStringMapParser(cn *pool.Conn, n int64) (interface{}, error) { +// Implements proto.MultiBulkParse +func stringStringMapParser(rd *proto.Reader, n int64) (interface{}, error) { m := make(map[string]string, n/2) for i := int64(0); i < n; i += 2 { - key, err := readStringReply(cn) + key, err := rd.ReadStringReply() if err != nil { return nil, err } - value, err := readStringReply(cn) + value, err := rd.ReadStringReply() if err != nil { return nil, err } @@ -526,15 +103,16 @@ func stringStringMapParser(cn *pool.Conn, n int64) (interface{}, error) { return m, nil } -func stringIntMapParser(cn *pool.Conn, n int64) (interface{}, error) { +// Implements proto.MultiBulkParse +func stringIntMapParser(rd *proto.Reader, n int64) (interface{}, error) { m := make(map[string]int64, n/2) for i := int64(0); i < n; i += 2 { - key, err := readStringReply(cn) + key, err := rd.ReadStringReply() if err != nil { return nil, err } - n, err := readIntReply(cn) + n, err := rd.ReadIntReply() if err != nil { return nil, err } @@ -544,19 +122,20 @@ func stringIntMapParser(cn *pool.Conn, n int64) (interface{}, error) { return m, nil } -func zSliceParser(cn *pool.Conn, n int64) (interface{}, error) { +// Implements proto.MultiBulkParse +func zSliceParser(rd *proto.Reader, n int64) (interface{}, error) { zz := make([]Z, n/2) for i := int64(0); i < n; i += 2 { var err error z := &zz[i/2] - z.Member, err = readStringReply(cn) + z.Member, err = rd.ReadStringReply() if err != nil { return nil, err } - z.Score, err = readFloatReply(cn) + z.Score, err = rd.ReadFloatReply() if err != nil { return nil, err } @@ -564,10 +143,11 @@ func zSliceParser(cn *pool.Conn, n int64) (interface{}, error) { return zz, nil } -func clusterSlotsParser(cn *pool.Conn, slotNum int64) (interface{}, error) { - slots := make([]ClusterSlot, slotNum) - for slotInd := 0; slotInd < len(slots); slotInd++ { - n, err := readArrayHeader(cn) +// Implements proto.MultiBulkParse +func clusterSlotsParser(rd *proto.Reader, n int64) (interface{}, error) { + slots := make([]ClusterSlot, n) + for i := 0; i < len(slots); i++ { + n, err := rd.ReadArrayLen() if err != nil { return nil, err } @@ -576,19 +156,19 @@ func clusterSlotsParser(cn *pool.Conn, slotNum int64) (interface{}, error) { return nil, err } - start, err := readIntReply(cn) + start, err := rd.ReadIntReply() if err != nil { return nil, err } - end, err := readIntReply(cn) + end, err := rd.ReadIntReply() if err != nil { return nil, err } nodes := make([]ClusterNode, n-2) - for nodeInd := 0; nodeInd < len(nodes); nodeInd++ { - n, err := readArrayHeader(cn) + for j := 0; j < len(nodes); j++ { + n, err := rd.ReadArrayLen() if err != nil { return nil, err } @@ -597,27 +177,27 @@ func clusterSlotsParser(cn *pool.Conn, slotNum int64) (interface{}, error) { return nil, err } - ip, err := readStringReply(cn) + ip, err := rd.ReadStringReply() if err != nil { return nil, err } - port, err := readIntReply(cn) + port, err := rd.ReadIntReply() if err != nil { return nil, err } - nodes[nodeInd].Addr = net.JoinHostPort(ip, strconv.FormatInt(port, 10)) + nodes[j].Addr = net.JoinHostPort(ip, strconv.FormatInt(port, 10)) if n == 3 { - id, err := readStringReply(cn) + id, err := rd.ReadStringReply() if err != nil { return nil, err } - nodes[nodeInd].Id = id + nodes[j].Id = id } } - slots[slotInd] = ClusterSlot{ + slots[i] = ClusterSlot{ Start: int(start), End: int(end), Nodes: nodes, @@ -626,29 +206,29 @@ func clusterSlotsParser(cn *pool.Conn, slotNum int64) (interface{}, error) { return slots, nil } -func newGeoLocationParser(q *GeoRadiusQuery) multiBulkParser { - return func(cn *pool.Conn, n int64) (interface{}, error) { +func newGeoLocationParser(q *GeoRadiusQuery) proto.MultiBulkParse { + return func(rd *proto.Reader, n int64) (interface{}, error) { var loc GeoLocation var err error - loc.Name, err = readStringReply(cn) + loc.Name, err = rd.ReadStringReply() if err != nil { return nil, err } if q.WithDist { - loc.Dist, err = readFloatReply(cn) + loc.Dist, err = rd.ReadFloatReply() if err != nil { return nil, err } } if q.WithGeoHash { - loc.GeoHash, err = readIntReply(cn) + loc.GeoHash, err = rd.ReadIntReply() if err != nil { return nil, err } } if q.WithCoord { - n, err := readArrayHeader(cn) + n, err := rd.ReadArrayLen() if err != nil { return nil, err } @@ -656,11 +236,11 @@ func newGeoLocationParser(q *GeoRadiusQuery) multiBulkParser { return nil, fmt.Errorf("got %d coordinates, expected 2", n) } - loc.Longitude, err = readFloatReply(cn) + loc.Longitude, err = rd.ReadFloatReply() if err != nil { return nil, err } - loc.Latitude, err = readFloatReply(cn) + loc.Latitude, err = rd.ReadFloatReply() if err != nil { return nil, err } @@ -670,11 +250,11 @@ func newGeoLocationParser(q *GeoRadiusQuery) multiBulkParser { } } -func newGeoLocationSliceParser(q *GeoRadiusQuery) multiBulkParser { - return func(cn *pool.Conn, n int64) (interface{}, error) { +func newGeoLocationSliceParser(q *GeoRadiusQuery) proto.MultiBulkParse { + return func(rd *proto.Reader, n int64) (interface{}, error) { locs := make([]GeoLocation, 0, n) for i := int64(0); i < n; i++ { - v, err := readReply(cn, newGeoLocationParser(q)) + v, err := rd.ReadReply(newGeoLocationParser(q)) if err != nil { return nil, err } @@ -693,7 +273,7 @@ func newGeoLocationSliceParser(q *GeoRadiusQuery) multiBulkParser { } } -func commandInfoParser(cn *pool.Conn, n int64) (interface{}, error) { +func commandInfoParser(rd *proto.Reader, n int64) (interface{}, error) { var cmd CommandInfo var err error @@ -701,36 +281,36 @@ func commandInfoParser(cn *pool.Conn, n int64) (interface{}, error) { return nil, fmt.Errorf("redis: got %d elements in COMMAND reply, wanted 6") } - cmd.Name, err = readStringReply(cn) + cmd.Name, err = rd.ReadStringReply() if err != nil { return nil, err } - arity, err := readIntReply(cn) + arity, err := rd.ReadIntReply() if err != nil { return nil, err } cmd.Arity = int8(arity) - flags, err := readReply(cn, stringSliceParser) + flags, err := rd.ReadReply(stringSliceParser) if err != nil { return nil, err } cmd.Flags = flags.([]string) - firstKeyPos, err := readIntReply(cn) + firstKeyPos, err := rd.ReadIntReply() if err != nil { return nil, err } cmd.FirstKeyPos = int8(firstKeyPos) - lastKeyPos, err := readIntReply(cn) + lastKeyPos, err := rd.ReadIntReply() if err != nil { return nil, err } cmd.LastKeyPos = int8(lastKeyPos) - stepCount, err := readIntReply(cn) + stepCount, err := rd.ReadIntReply() if err != nil { return nil, err } @@ -746,10 +326,10 @@ func commandInfoParser(cn *pool.Conn, n int64) (interface{}, error) { return &cmd, nil } -func commandInfoSliceParser(cn *pool.Conn, n int64) (interface{}, error) { +func commandInfoSliceParser(rd *proto.Reader, n int64) (interface{}, error) { m := make(map[string]*CommandInfo, n) for i := int64(0); i < n; i++ { - v, err := readReply(cn, commandInfoParser) + v, err := rd.ReadReply(commandInfoParser) if err != nil { return nil, err } diff --git a/parser_test.go b/parser_test.go deleted file mode 100644 index d8fa360..0000000 --- a/parser_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package redis - -import ( - "bufio" - "bytes" - "testing" - - "gopkg.in/redis.v4/internal/pool" -) - -func BenchmarkParseReplyStatus(b *testing.B) { - benchmarkParseReply(b, "+OK\r\n", nil, false) -} - -func BenchmarkParseReplyInt(b *testing.B) { - benchmarkParseReply(b, ":1\r\n", nil, false) -} - -func BenchmarkParseReplyError(b *testing.B) { - benchmarkParseReply(b, "-Error message\r\n", nil, true) -} - -func BenchmarkParseReplyString(b *testing.B) { - benchmarkParseReply(b, "$5\r\nhello\r\n", nil, false) -} - -func BenchmarkParseReplySlice(b *testing.B) { - benchmarkParseReply(b, "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", sliceParser, false) -} - -func benchmarkParseReply(b *testing.B, reply string, p multiBulkParser, wanterr bool) { - buf := &bytes.Buffer{} - for i := 0; i < b.N; i++ { - buf.WriteString(reply) - } - cn := &pool.Conn{ - Rd: bufio.NewReader(buf), - Buf: make([]byte, 4096), - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, err := readReply(cn, p) - if !wanterr && err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkAppendArgs(b *testing.B) { - buf := make([]byte, 0, 64) - args := []interface{}{"hello", "world", "foo", "bar"} - for i := 0; i < b.N; i++ { - appendArgs(buf, args) - } -} diff --git a/pipeline.go b/pipeline.go index fed6e0c..bf80d40 100644 --- a/pipeline.go +++ b/pipeline.go @@ -4,6 +4,7 @@ import ( "sync" "sync/atomic" + "gopkg.in/redis.v4/internal/errors" "gopkg.in/redis.v4/internal/pool" ) @@ -99,7 +100,7 @@ func execCmds(cn *pool.Conn, cmds []Cmder) ([]Cmder, error) { if firstCmdErr == nil { firstCmdErr = err } - if shouldRetry(err) { + if errors.IsRetryable(err) { failedCmds = append(failedCmds, cmd) } } diff --git a/pipeline_test.go b/pipeline_test.go index 939c8b8..c3a37f6 100644 --- a/pipeline_test.go +++ b/pipeline_test.go @@ -156,6 +156,8 @@ var _ = Describe("Pipelining", func() { wg.Add(N) for i := 0; i < N; i++ { go func() { + defer GinkgoRecover() + pipeline.Ping() wg.Done() }() diff --git a/pubsub.go b/pubsub.go index 5bbda59..6c72b8f 100644 --- a/pubsub.go +++ b/pubsub.go @@ -6,6 +6,7 @@ import ( "time" "gopkg.in/redis.v4/internal" + "gopkg.in/redis.v4/internal/errors" "gopkg.in/redis.v4/internal/pool" ) @@ -248,7 +249,7 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) { for { msgi, err := c.ReceiveTimeout(timeout) if err != nil { - if !isNetworkError(err) { + if !errors.IsNetwork(err) { return nil, err } diff --git a/redis.go b/redis.go index 5c93db0..b67d422 100644 --- a/redis.go +++ b/redis.go @@ -5,9 +5,13 @@ import ( "log" "gopkg.in/redis.v4/internal" + "gopkg.in/redis.v4/internal/errors" "gopkg.in/redis.v4/internal/pool" ) +// Redis nil reply, .e.g. when key does not exist. +const Nil = errors.Nil + func SetLogger(logger *log.Logger) { internal.Logger = logger } @@ -38,7 +42,7 @@ func (c *baseClient) conn() (*pool.Conn, error) { } func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { - if isBadConn(err, allowTimeout) { + if errors.IsBadConn(err, allowTimeout) { _ = c.connPool.Remove(cn, err) return false } @@ -97,7 +101,7 @@ func (c *baseClient) Process(cmd Cmder) error { if err := writeCmd(cn, cmd); err != nil { c.putConn(cn, err, false) cmd.setErr(err) - if err != nil && shouldRetry(err) { + if err != nil && errors.IsRetryable(err) { continue } return err @@ -105,7 +109,7 @@ func (c *baseClient) Process(cmd Cmder) error { err = cmd.readReply(cn) c.putConn(cn, err, readTimeout != nil) - if err != nil && shouldRetry(err) { + if err != nil && errors.IsRetryable(err) { continue } diff --git a/safe.go b/safe.go deleted file mode 100644 index d66dc56..0000000 --- a/safe.go +++ /dev/null @@ -1,7 +0,0 @@ -// +build appengine - -package redis - -func bytesToString(b []byte) string { - return string(b) -} diff --git a/tx.go b/tx.go index 5ed223c..ca425eb 100644 --- a/tx.go +++ b/tx.go @@ -5,9 +5,14 @@ import ( "fmt" "gopkg.in/redis.v4/internal" + ierrors "gopkg.in/redis.v4/internal/errors" "gopkg.in/redis.v4/internal/pool" + "gopkg.in/redis.v4/internal/proto" ) +// Redis transaction failed. +const TxFailedErr = ierrors.RedisError("redis: transaction failed") + var errDiscard = errors.New("redis: Discard can be used only inside Exec") // Tx implements Redis transactions as described in @@ -166,7 +171,7 @@ func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error { } // Parse number of replies. - line, err := readLine(cn) + line, err := cn.Rd.ReadLine() if err != nil { if err == Nil { err = TxFailedErr @@ -174,7 +179,7 @@ func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error { setCmdsErr(cmds[1:len(cmds)-1], err) return err } - if line[0] != '*' { + if line[0] != proto.ArrayReply { err := fmt.Errorf("redis: expected '*', but got line %q", line) setCmdsErr(cmds[1:len(cmds)-1], err) return err diff --git a/unsafe.go b/unsafe.go deleted file mode 100644 index 3cd8d1c..0000000 --- a/unsafe.go +++ /dev/null @@ -1,14 +0,0 @@ -// +build !appengine - -package redis - -import ( - "reflect" - "unsafe" -) - -func bytesToString(b []byte) string { - bytesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&b)) - strHeader := reflect.StringHeader{bytesHeader.Data, bytesHeader.Len} - return *(*string)(unsafe.Pointer(&strHeader)) -}