diff --git a/internal/proto/reader.go b/internal/proto/reader.go index 84f1488..bd76e3a 100644 --- a/internal/proto/reader.go +++ b/internal/proto/reader.go @@ -9,6 +9,8 @@ import ( "gopkg.in/redis.v5/internal" ) +const bytesAllocLimit = 1024 * 1024 // 1mb + const errEmptyReply = internal.RedisError("redis: reply is empty") type MultiBulkParse func(*Reader, int64) (interface{}, error) @@ -34,16 +36,7 @@ func (p *Reader) PeekBuffered() []byte { } func (p *Reader) ReadN(n int) ([]byte, error) { - // grow internal buffer, if necessary - if d := n - cap(p.buf); d > 0 { - p.buf = p.buf[:cap(p.buf)] - p.buf = append(p.buf, make([]byte, d)...) - } else { - p.buf = p.buf[:n] - } - - _, err := io.ReadFull(p.src, p.buf) - return p.buf, err + return readN(p.src, p.buf, n) } func (p *Reader) ReadLine() ([]byte, error) { @@ -225,6 +218,36 @@ func (p *Reader) readBytesValue(line []byte) ([]byte, error) { // -------------------------------------------------------------------- +func readN(r io.Reader, b []byte, n int) ([]byte, error) { + if n == 0 && b == nil { + return make([]byte, 0), nil + } + + if cap(b) >= n { + b = b[:n] + _, err := io.ReadFull(r, b) + return b, err + } + b = b[:cap(b)] + + pos := 0 + for pos < n { + diff := n - len(b) + if diff > bytesAllocLimit { + diff = bytesAllocLimit + } + b = append(b, make([]byte, diff)...) + + nn, err := io.ReadFull(r, b[pos:]) + if err != nil { + return nil, err + } + pos += nn + } + + return b, nil +} + func formatInt(n int64) string { return strconv.FormatInt(n, 10) } diff --git a/race_test.go b/race_test.go index d6503eb..1e44bf0 100644 --- a/race_test.go +++ b/race_test.go @@ -104,7 +104,9 @@ var _ = Describe("races", func() { }) It("should handle big vals in Get", func() { - bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb + C, N = 4, 100 + + bigVal := bytes.Repeat([]byte{'*'}, 1<<17) // 128kb err := client.Set("key", bigVal, 0).Err() Expect(err).NotTo(HaveOccurred()) @@ -115,7 +117,7 @@ var _ = Describe("races", func() { perform(C, func(id int) { for i := 0; i < N; i++ { - got, err := client.Get("key").Result() + got, err := client.Get("key").Bytes() Expect(err).NotTo(HaveOccurred()) Expect(got).To(Equal(bigVal)) } @@ -124,7 +126,8 @@ var _ = Describe("races", func() { It("should handle big vals in Set", func() { C, N = 4, 100 - bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb + + bigVal := bytes.Repeat([]byte{'*'}, 1<<17) // 128kb perform(C, func(id int) { for i := 0; i < N; i++ { diff --git a/redis_test.go b/redis_test.go index 8389e00..9bb4b68 100644 --- a/redis_test.go +++ b/redis_test.go @@ -185,7 +185,7 @@ var _ = Describe("Client", func() { }) It("should handle big vals", func() { - bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb + bigVal := bytes.Repeat([]byte{'*'}, 2e6) err := client.Set("key", bigVal, 0).Err() Expect(err).NotTo(HaveOccurred()) @@ -194,9 +194,8 @@ var _ = Describe("Client", func() { Expect(client.Close()).To(BeNil()) client = redis.NewClient(redisOptions()) - got, err := client.Get("key").Result() + got, err := client.Get("key").Bytes() Expect(err).NotTo(HaveOccurred()) - Expect(len(got)).To(Equal(len(bigVal))) Expect(got).To(Equal(bigVal)) })