diff --git a/command.go b/command.go index 49218572..f20d5ea0 100644 --- a/command.go +++ b/command.go @@ -1,7 +1,6 @@ package redis import ( - "bytes" "fmt" "net" "strconv" @@ -10,7 +9,6 @@ import ( "github.com/go-redis/redis/internal" "github.com/go-redis/redis/internal/proto" - "github.com/go-redis/redis/internal/util" ) type Cmder interface { @@ -237,14 +235,7 @@ func (cmd *Cmd) Bool() (bool, error) { func (cmd *Cmd) readReply(rd proto.Reader) error { cmd.val, cmd.err = rd.ReadReply(sliceParser) - if cmd.err != nil { - return cmd.err - } - if b, ok := cmd.val.([]byte); ok { - // Bytes must be copied, because underlying memory is reused. - cmd.val = string(b) - } - return nil + return cmd.err } // Implements proto.MultiBulkParse @@ -265,8 +256,8 @@ func sliceParser(rd proto.Reader, n int64) (interface{}, error) { } switch v := v.(type) { - case []byte: - vals = append(vals, string(v)) + case string: + vals = append(vals, v) default: vals = append(vals, v) } @@ -341,7 +332,7 @@ func (cmd *StatusCmd) String() string { } func (cmd *StatusCmd) readReply(rd proto.Reader) error { - cmd.val, cmd.err = rd.ReadStringReply() + cmd.val, cmd.err = rd.ReadString() return cmd.err } @@ -503,8 +494,6 @@ func (cmd *BoolCmd) String() string { return cmdString(cmd, cmd.val) } -var ok = []byte("OK") - func (cmd *BoolCmd) readReply(rd proto.Reader) error { var v interface{} v, cmd.err = rd.ReadReply(nil) @@ -523,8 +512,8 @@ func (cmd *BoolCmd) readReply(rd proto.Reader) error { case int64: cmd.val = v == 1 return nil - case []byte: - cmd.val = bytes.Equal(v, ok) + case string: + cmd.val = v == "OK" return nil default: cmd.err = fmt.Errorf("got %T, wanted int64 or string", v) @@ -537,7 +526,7 @@ func (cmd *BoolCmd) readReply(rd proto.Reader) error { type StringCmd struct { baseCmd - val []byte + val string } var _ Cmder = (*StringCmd)(nil) @@ -549,7 +538,7 @@ func NewStringCmd(args ...interface{}) *StringCmd { } func (cmd *StringCmd) Val() string { - return util.BytesToString(cmd.val) + return cmd.val } func (cmd *StringCmd) Result() (string, error) { @@ -557,7 +546,7 @@ func (cmd *StringCmd) Result() (string, error) { } func (cmd *StringCmd) Bytes() ([]byte, error) { - return cmd.val, cmd.err + return []byte(cmd.val), cmd.err } func (cmd *StringCmd) Int64() (int64, error) { @@ -585,7 +574,7 @@ func (cmd *StringCmd) Scan(val interface{}) error { if cmd.err != nil { return cmd.err } - return proto.Scan(cmd.val, val) + return proto.Scan([]byte(cmd.val), val) } func (cmd *StringCmd) String() string { @@ -593,7 +582,7 @@ func (cmd *StringCmd) String() string { } func (cmd *StringCmd) readReply(rd proto.Reader) error { - cmd.val, cmd.err = rd.ReadBytesReply() + cmd.val, cmd.err = rd.ReadString() return cmd.err } @@ -676,7 +665,7 @@ func (cmd *StringSliceCmd) readReply(rd proto.Reader) error { func stringSliceParser(rd proto.Reader, n int64) (interface{}, error) { ss := make([]string, 0, n) for i := int64(0); i < n; i++ { - s, err := rd.ReadStringReply() + s, err := rd.ReadString() if err == Nil { ss = append(ss, "") } else if err != nil { @@ -781,12 +770,12 @@ func (cmd *StringStringMapCmd) readReply(rd proto.Reader) error { func stringStringMapParser(rd proto.Reader, n int64) (interface{}, error) { m := make(map[string]string, n/2) for i := int64(0); i < n; i += 2 { - key, err := rd.ReadStringReply() + key, err := rd.ReadString() if err != nil { return nil, err } - value, err := rd.ReadStringReply() + value, err := rd.ReadString() if err != nil { return nil, err } @@ -838,7 +827,7 @@ func (cmd *StringIntMapCmd) readReply(rd proto.Reader) error { func stringIntMapParser(rd proto.Reader, n int64) (interface{}, error) { m := make(map[string]int64, n/2) for i := int64(0); i < n; i += 2 { - key, err := rd.ReadStringReply() + key, err := rd.ReadString() if err != nil { return nil, err } @@ -895,7 +884,7 @@ func (cmd *StringStructMapCmd) readReply(rd proto.Reader) error { func stringStructMapParser(rd proto.Reader, n int64) (interface{}, error) { m := make(map[string]struct{}, n) for i := int64(0); i < n; i++ { - key, err := rd.ReadStringReply() + key, err := rd.ReadString() if err != nil { return nil, err } @@ -953,7 +942,7 @@ func xMessageSliceParser(rd proto.Reader, n int64) (interface{}, error) { msgs := make([]XMessage, 0, n) for i := int64(0); i < n; i++ { _, err := rd.ReadArrayReply(func(rd proto.Reader, n int64) (interface{}, error) { - id, err := rd.ReadStringReply() + id, err := rd.ReadString() if err != nil { return nil, err } @@ -980,12 +969,12 @@ func xMessageSliceParser(rd proto.Reader, n int64) (interface{}, error) { func stringInterfaceMapParser(rd proto.Reader, n int64) (interface{}, error) { m := make(map[string]interface{}, n/2) for i := int64(0); i < n; i += 2 { - key, err := rd.ReadStringReply() + key, err := rd.ReadString() if err != nil { return nil, err } - value, err := rd.ReadStringReply() + value, err := rd.ReadString() if err != nil { return nil, err } @@ -1047,7 +1036,7 @@ func xStreamSliceParser(rd proto.Reader, n int64) (interface{}, error) { return nil, fmt.Errorf("got %d, wanted 2", n) } - stream, err := rd.ReadStringReply() + stream, err := rd.ReadString() if err != nil { return nil, err } @@ -1124,12 +1113,12 @@ func xPendingParser(rd proto.Reader, n int64) (interface{}, error) { return nil, err } - lower, err := rd.ReadStringReply() + lower, err := rd.ReadString() if err != nil && err != Nil { return nil, err } - higher, err := rd.ReadStringReply() + higher, err := rd.ReadString() if err != nil && err != Nil { return nil, err } @@ -1146,17 +1135,12 @@ func xPendingParser(rd proto.Reader, n int64) (interface{}, error) { return nil, fmt.Errorf("got %d, wanted 2", n) } - consumerName, err := rd.ReadStringReply() + consumerName, err := rd.ReadString() if err != nil { return nil, err } - consumerPendingStr, err := rd.ReadStringReply() - if err != nil { - return nil, err - } - - consumerPending, err := strconv.ParseInt(consumerPendingStr, 10, 64) + consumerPending, err := rd.ReadInt() if err != nil { return nil, err } @@ -1233,12 +1217,12 @@ func xPendingExtSliceParser(rd proto.Reader, n int64) (interface{}, error) { return nil, fmt.Errorf("got %d, wanted 4", n) } - id, err := rd.ReadStringReply() + id, err := rd.ReadString() if err != nil { return nil, err } - consumer, err := rd.ReadStringReply() + consumer, err := rd.ReadString() if err != nil && err != Nil { return nil, err } @@ -1316,7 +1300,7 @@ func zSliceParser(rd proto.Reader, n int64) (interface{}, error) { z := &zz[i/2] - z.Member, err = rd.ReadStringReply() + z.Member, err = rd.ReadString() if err != nil { return nil, err } @@ -1456,19 +1440,20 @@ func clusterSlotsParser(rd proto.Reader, n int64) (interface{}, error) { return nil, err } - ip, err := rd.ReadStringReply() + ip, err := rd.ReadString() if err != nil { return nil, err } - port, err := rd.ReadIntReply() + port, err := rd.ReadString() if err != nil { return nil, err } - nodes[j].Addr = net.JoinHostPort(ip, strconv.FormatInt(port, 10)) + + nodes[j].Addr = net.JoinHostPort(ip, port) if n == 3 { - id, err := rd.ReadStringReply() + id, err := rd.ReadString() if err != nil { return nil, err } @@ -1581,7 +1566,7 @@ func newGeoLocationParser(q *GeoRadiusQuery) proto.MultiBulkParse { var loc GeoLocation var err error - loc.Name, err = rd.ReadStringReply() + loc.Name, err = rd.ReadString() if err != nil { return nil, err } @@ -1629,9 +1614,9 @@ func newGeoLocationSliceParser(q *GeoRadiusQuery) proto.MultiBulkParse { return nil, err } switch vv := v.(type) { - case []byte: + case string: locs = append(locs, GeoLocation{ - Name: string(vv), + Name: vv, }) case *GeoLocation: locs = append(locs, *vv) @@ -1794,7 +1779,7 @@ func commandInfoParser(rd proto.Reader, n int64) (interface{}, error) { return nil, fmt.Errorf("redis: got %d elements in COMMAND reply, wanted 6", n) } - cmd.Name, err = rd.ReadStringReply() + cmd.Name, err = rd.ReadString() if err != nil { return nil, err } diff --git a/internal/proto/elastic_reader.go b/internal/proto/elastic_reader.go index c89a7ebe..f075e86f 100644 --- a/internal/proto/elastic_reader.go +++ b/internal/proto/elastic_reader.go @@ -183,3 +183,23 @@ func (b *ElasticBufReader) grow(n int) { b.buf = append(b.buf, make([]byte, d)...) } } + +func (b *ElasticBufReader) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, b.readErr() + } + + if b.r != b.w { + // copy as much as we can + n = copy(p, b.buf[b.r:b.w]) + b.r += n + return n, nil + } + + if b.err != nil { + return 0, b.readErr() + } + + n, b.err = b.rd.Read(p) + return n, b.readErr() +} diff --git a/internal/proto/reader.go b/internal/proto/reader.go index 3e9c38a7..43abed50 100644 --- a/internal/proto/reader.go +++ b/internal/proto/reader.go @@ -54,10 +54,6 @@ func (r Reader) Bytes() []byte { return r.src.Bytes() } -func (r Reader) ReadN(n int) ([]byte, error) { - return r.src.ReadN(n) -} - func (r Reader) ReadLine() ([]byte, error) { line, err := r.src.ReadLine() if err != nil { @@ -82,11 +78,11 @@ func (r Reader) ReadReply(m MultiBulkParse) (interface{}, error) { case ErrorReply: return nil, ParseErrorReply(line) case StatusReply: - return parseTmpStatusReply(line), nil + return string(line[1:]), nil case IntReply: return util.ParseInt(line[1:], 10, 64) case StringReply: - return r.readTmpBytesReply(line) + return r.readStringReply(line) case ArrayReply: n, err := parseArrayLen(line) if err != nil { @@ -112,47 +108,42 @@ func (r Reader) ReadIntReply() (int64, error) { } } -func (r Reader) ReadTmpBytesReply() ([]byte, error) { +func (r Reader) ReadString() (string, error) { line, err := r.ReadLine() - if err != nil { - return nil, err - } - switch line[0] { - case ErrorReply: - return nil, ParseErrorReply(line) - case StringReply: - return r.readTmpBytesReply(line) - case StatusReply: - return parseTmpStatusReply(line), nil - default: - return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) - } -} - -func (r Reader) ReadBytesReply() ([]byte, error) { - b, err := r.ReadTmpBytesReply() - if err != nil { - return nil, err - } - cp := make([]byte, len(b)) - copy(cp, b) - return cp, nil -} - -func (r Reader) ReadStringReply() (string, error) { - b, err := r.ReadTmpBytesReply() if err != nil { return "", err } - return string(b), nil + switch line[0] { + case ErrorReply: + return "", ParseErrorReply(line) + case StringReply: + return r.readStringReply(line) + case StatusReply: + return string(line[1:]), nil + case IntReply: + return string(line[1:]), nil + default: + return "", fmt.Errorf("redis: can't parse reply=%.100q reading string", line) + } } -func (r Reader) ReadFloatReply() (float64, error) { - b, err := r.ReadTmpBytesReply() - if err != nil { - return 0, err +func (r Reader) readStringReply(line []byte) (string, error) { + if isNilReply(line) { + return "", Nil } - return util.ParseFloat(b, 64) + + replyLen, err := strconv.Atoi(string(line[1:])) + if err != nil { + return "", err + } + + b := make([]byte, replyLen+2) + _, err = io.ReadFull(r.src, b) + if err != nil { + return "", err + } + + return util.BytesToString(b[:replyLen]), nil } func (r Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { @@ -210,7 +201,7 @@ func (r Reader) ReadScanReply() ([]string, uint64, error) { keys := make([]string, n) for i := int64(0); i < n; i++ { - key, err := r.ReadStringReply() + key, err := r.ReadString() if err != nil { return nil, 0, err } @@ -220,7 +211,48 @@ func (r Reader) ReadScanReply() ([]string, uint64, error) { return keys, cursor, err } -func (r Reader) readTmpBytesReply(line []byte) ([]byte, error) { +func (r Reader) ReadInt() (int64, error) { + b, err := r.readTmpBytesReply() + if err != nil { + return 0, err + } + return util.ParseInt(b, 10, 64) +} + +func (r Reader) ReadUint() (uint64, error) { + b, err := r.readTmpBytesReply() + if err != nil { + return 0, err + } + return util.ParseUint(b, 10, 64) +} + +func (r Reader) ReadFloatReply() (float64, error) { + b, err := r.readTmpBytesReply() + if err != nil { + return 0, err + } + return util.ParseFloat(b, 64) +} + +func (r Reader) readTmpBytesReply() ([]byte, error) { + line, err := r.ReadLine() + if err != nil { + return nil, err + } + switch line[0] { + case ErrorReply: + return nil, ParseErrorReply(line) + case StringReply: + return r._readTmpBytesReply(line) + case StatusReply: + return line[1:], nil + default: + return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) + } +} + +func (r Reader) _readTmpBytesReply(line []byte) ([]byte, error) { if isNilReply(line) { return nil, Nil } @@ -230,29 +262,13 @@ func (r Reader) readTmpBytesReply(line []byte) ([]byte, error) { return nil, err } - b, err := r.ReadN(replyLen + 2) + b, err := r.src.ReadN(replyLen + 2) if err != nil { return nil, err } return b[:replyLen], nil } -func (r Reader) ReadInt() (int64, error) { - b, err := r.ReadTmpBytesReply() - if err != nil { - return 0, err - } - return util.ParseInt(b, 10, 64) -} - -func (r Reader) ReadUint() (uint64, error) { - b, err := r.ReadTmpBytesReply() - if err != nil { - return 0, err - } - return util.ParseUint(b, 10, 64) -} - func isNilReply(b []byte) bool { return len(b) == 3 && (b[0] == StringReply || b[0] == ArrayReply) && @@ -263,10 +279,6 @@ func ParseErrorReply(line []byte) error { return RedisError(string(line[1:])) } -func parseTmpStatusReply(line []byte) []byte { - return line[1:] -} - func parseArrayLen(line []byte) (int64, error) { if isNilReply(line) { return 0, Nil diff --git a/internal/proto/reader_test.go b/internal/proto/reader_test.go index 20252489..687b20f5 100644 --- a/internal/proto/reader_test.go +++ b/internal/proto/reader_test.go @@ -6,42 +6,12 @@ import ( "testing" "github.com/go-redis/redis/internal/proto" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" ) func newReader(s string) proto.Reader { return proto.NewReader(proto.NewElasticBufReader(strings.NewReader(s))) } -var _ = Describe("Reader", func() { - - It("should read n bytes", func() { - data, err := newReader("ABCDEFGHIJKLMNO").ReadN(10) - Expect(err).NotTo(HaveOccurred()) - Expect(len(data)).To(Equal(10)) - Expect(string(data)).To(Equal("ABCDEFGHIJ")) - - data, err = newReader(strings.Repeat("x", 8192)).ReadN(6000) - Expect(err).NotTo(HaveOccurred()) - Expect(len(data)).To(Equal(6000)) - }) - - It("should read lines", func() { - r := newReader("$5\r\nhello\r\n") - - data, err := r.ReadLine() - Expect(err).NotTo(HaveOccurred()) - Expect(string(data)).To(Equal("$5")) - - data, err = r.ReadLine() - Expect(err).NotTo(HaveOccurred()) - Expect(string(data)).To(Equal("hello")) - }) - -}) - func BenchmarkReader_ParseReply_Status(b *testing.B) { benchmarkParseReply(b, "+OK\r\n", nil, false) } diff --git a/main_test.go b/main_test.go index e49d954b..e6a372b2 100644 --- a/main_test.go +++ b/main_test.go @@ -168,7 +168,7 @@ func perform(n int, cbs ...func(int)) { } func eventually(fn func() error, timeout time.Duration) error { - errCh := make(chan error) + errCh := make(chan error, 1) done := make(chan struct{}) exit := make(chan struct{}) @@ -202,7 +202,7 @@ func eventually(fn func() error, timeout time.Duration) error { case err := <-errCh: return err default: - return fmt.Errorf("timeout after %s", timeout) + return fmt.Errorf("timeout after %s without an error", timeout) } } } diff --git a/result.go b/result.go index e086e8e3..e438f260 100644 --- a/result.go +++ b/result.go @@ -53,7 +53,7 @@ func NewBoolResult(val bool, err error) *BoolCmd { // NewStringResult returns a StringCmd initialised with val and err for testing func NewStringResult(val string, err error) *StringCmd { var cmd StringCmd - cmd.val = []byte(val) + cmd.val = val cmd.setErr(err) return &cmd }