redis/internal/proto/reader.go

335 lines
6.3 KiB
Go
Raw Normal View History

2016-07-02 15:52:10 +03:00
package proto
import (
"bufio"
"fmt"
"io"
"strconv"
2017-02-18 17:42:34 +03:00
"github.com/go-redis/redis/internal"
2016-07-02 15:52:10 +03:00
)
2016-11-09 11:04:37 +03:00
// bytesAllocLimit is the largest number of bytes readN appends to its buffer
// in a single growth step.
const bytesAllocLimit = 1024 * 1024 // 1mb
2016-11-20 10:50:49 +03:00
// Reply type markers. The first byte of a reply line selects how the rest of
// the reply is parsed.
const (
ErrorReply = '-' // error reply: the text after the marker becomes the error
StatusReply = '+' // status reply: the text after the marker is the value
IntReply = ':' // integer reply: the text after the marker is parsed as an integer
StringReply = '$' // string reply: the marker is followed by the payload length
ArrayReply = '*' // array reply: the marker is followed by the element count
)
2016-07-02 15:52:10 +03:00
// MultiBulkParse is a callback that decodes the elements of an array reply.
// ReadReply and ReadArrayReply call it with the Reader and the array length
// parsed from the reply header.
type MultiBulkParse func(*Reader, int64) (interface{}, error)
2016-07-02 15:52:10 +03:00
// Reader parses replies from a buffered input stream.
type Reader struct {
src *bufio.Reader // buffered source that every read goes through
buf []byte // scratch buffer handed to readN and replaced by ReadN after each successful read
}
func NewReader(rd io.Reader) *Reader {
2016-07-02 15:52:10 +03:00
return &Reader{
src: bufio.NewReader(rd),
buf: make([]byte, 4096),
2016-07-02 15:52:10 +03:00
}
}
// Reset switches the Reader's buffered source to read from rd. The scratch
// buffer is left untouched.
func (r *Reader) Reset(rd io.Reader) {
	src := r.src
	src.Reset(rd)
}
2016-07-02 15:52:10 +03:00
// PeekBuffered returns the bytes currently held in the source's buffer
// without consuming them, or nil when nothing is buffered.
func (r *Reader) PeekBuffered() []byte {
	n := r.src.Buffered()
	if n == 0 {
		return nil
	}
	// Only the already-buffered amount is requested, so the error result is
	// deliberately ignored.
	b, _ := r.src.Peek(n)
	return b
}
func (p *Reader) ReadN(n int) ([]byte, error) {
b, err := readN(p.src, p.buf, n)
if err != nil {
return nil, err
}
p.buf = b
return b, nil
2016-07-02 15:52:10 +03:00
}
func (p *Reader) ReadLine() ([]byte, error) {
line, isPrefix, err := p.src.ReadLine()
if err != nil {
return nil, err
}
if isPrefix {
return nil, bufio.ErrBufferFull
}
if len(line) == 0 {
2016-11-20 10:50:49 +03:00
return nil, internal.RedisError("redis: reply is empty")
2016-07-02 15:52:10 +03:00
}
if isNilReply(line) {
return nil, internal.Nil
2016-07-02 15:52:10 +03:00
}
return line, nil
}
func (p *Reader) ReadReply(m MultiBulkParse) (interface{}, error) {
line, err := p.ReadLine()
if err != nil {
return nil, err
}
switch line[0] {
case ErrorReply:
2016-12-13 18:28:39 +03:00
return nil, ParseErrorReply(line)
2016-07-02 15:52:10 +03:00
case StatusReply:
return parseStatusValue(line), nil
2016-07-02 15:52:10 +03:00
case IntReply:
2017-01-13 14:39:59 +03:00
return parseInt(line[1:], 10, 64)
2016-07-02 15:52:10 +03:00
case StringReply:
return p.readTmpBytesValue(line)
2016-07-02 15:52:10 +03:00
case ArrayReply:
n, err := parseArrayLen(line)
if err != nil {
return nil, err
}
return m(p, n)
}
return nil, fmt.Errorf("redis: can't parse %.100q", line)
}
func (p *Reader) ReadIntReply() (int64, error) {
line, err := p.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case ErrorReply:
2016-12-13 18:28:39 +03:00
return 0, ParseErrorReply(line)
2016-07-02 15:52:10 +03:00
case IntReply:
2017-01-13 14:39:59 +03:00
return parseInt(line[1:], 10, 64)
2016-07-02 15:52:10 +03:00
default:
return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line)
}
}
2017-01-13 14:39:59 +03:00
func (p *Reader) ReadTmpBytesReply() ([]byte, error) {
2016-07-02 15:52:10 +03:00
line, err := p.ReadLine()
if err != nil {
return nil, err
}
switch line[0] {
case ErrorReply:
2016-12-13 18:28:39 +03:00
return nil, ParseErrorReply(line)
2016-07-02 15:52:10 +03:00
case StringReply:
return p.readTmpBytesValue(line)
2016-07-02 15:52:10 +03:00
case StatusReply:
return parseStatusValue(line), nil
2016-07-02 15:52:10 +03:00
default:
return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line)
}
}
2017-01-13 14:39:59 +03:00
// ReadBytesReply reads one string or status reply and returns a freshly
// allocated copy of its bytes, so the result stays valid after later reads.
func (r *Reader) ReadBytesReply() ([]byte, error) {
	tmp, err := r.ReadTmpBytesReply()
	if err != nil {
		return nil, err
	}
	// Detach the result from the Reader's reusable buffers.
	out := make([]byte, len(tmp))
	copy(out, tmp)
	return out, nil
}
2016-07-02 15:52:10 +03:00
func (p *Reader) ReadStringReply() (string, error) {
2017-01-13 14:39:59 +03:00
b, err := p.ReadTmpBytesReply()
2016-07-02 15:52:10 +03:00
if err != nil {
return "", err
}
return string(b), nil
}
func (p *Reader) ReadFloatReply() (float64, error) {
2017-01-13 14:39:59 +03:00
b, err := p.ReadTmpBytesReply()
2016-07-02 15:52:10 +03:00
if err != nil {
return 0, err
}
2017-01-13 14:39:59 +03:00
return parseFloat(b, 64)
2016-07-02 15:52:10 +03:00
}
func (p *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) {
line, err := p.ReadLine()
if err != nil {
return nil, err
}
switch line[0] {
case ErrorReply:
2016-12-13 18:28:39 +03:00
return nil, ParseErrorReply(line)
2016-07-02 15:52:10 +03:00
case ArrayReply:
n, err := parseArrayLen(line)
if err != nil {
return nil, err
}
return m(p, n)
default:
return nil, fmt.Errorf("redis: can't parse array reply: %.100q", line)
}
}
func (p *Reader) ReadArrayLen() (int64, error) {
line, err := p.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case ErrorReply:
2016-12-13 18:28:39 +03:00
return 0, ParseErrorReply(line)
2016-07-02 15:52:10 +03:00
case ArrayReply:
return parseArrayLen(line)
default:
return 0, fmt.Errorf("redis: can't parse array reply: %.100q", line)
}
}
func (p *Reader) ReadScanReply() ([]string, uint64, error) {
n, err := p.ReadArrayLen()
if err != nil {
return nil, 0, err
}
if n != 2 {
return nil, 0, fmt.Errorf("redis: got %d elements in scan reply, expected 2", n)
}
2017-01-13 14:39:59 +03:00
cursor, err := p.ReadUint()
2016-07-02 15:52:10 +03:00
if err != nil {
return nil, 0, err
}
n, err = p.ReadArrayLen()
if err != nil {
return nil, 0, err
}
keys := make([]string, n)
for i := int64(0); i < n; i++ {
key, err := p.ReadStringReply()
if err != nil {
return nil, 0, err
}
keys[i] = key
}
return keys, cursor, err
}
func (p *Reader) readTmpBytesValue(line []byte) ([]byte, error) {
2016-07-02 15:52:10 +03:00
if isNilReply(line) {
return nil, internal.Nil
2016-07-02 15:52:10 +03:00
}
replyLen, err := strconv.Atoi(string(line[1:]))
if err != nil {
return nil, err
}
b, err := p.ReadN(replyLen + 2)
if err != nil {
return nil, err
}
return b[:replyLen], nil
}
2017-01-13 14:39:59 +03:00
// ReadInt reads one string or status reply and parses its bytes as a base-10
// signed 64-bit integer.
func (r *Reader) ReadInt() (int64, error) {
	raw, err := r.ReadTmpBytesReply()
	if err == nil {
		return parseInt(raw, 10, 64)
	}
	return 0, err
}
// ReadUint reads one string or status reply and parses its bytes as a base-10
// unsigned 64-bit integer.
func (r *Reader) ReadUint() (uint64, error) {
	raw, err := r.ReadTmpBytesReply()
	if err == nil {
		return parseUint(raw, 10, 64)
	}
	return 0, err
}
2016-07-02 15:52:10 +03:00
// --------------------------------------------------------------------
2016-11-09 11:04:37 +03:00
// readN reads exactly n bytes from r, reusing b's storage when it is large
// enough, and returns the slice holding those n bytes.
//
// When b is too small the buffer is grown in steps of at most bytesAllocLimit
// bytes, reading after each step, so a bogus length cannot force one huge
// up-front allocation. A negative n is rejected with an error. If the read
// fails while growing, nil is returned; if it fails on the reuse path, the
// resized b is returned together with the error.
func readN(r io.Reader, b []byte, n int) ([]byte, error) {
	// Slicing with a negative length would panic.
	if n < 0 {
		return nil, fmt.Errorf("redis: invalid read length %d", n)
	}
	if n == 0 && b == nil {
		return make([]byte, 0), nil
	}

	// Fast path: the existing buffer can hold the whole read.
	if cap(b) >= n {
		b = b[:n]
		_, err := io.ReadFull(r, b)
		return b, err
	}

	// Slow path: fill what we have, then grow and fill in bounded steps.
	b = b[:cap(b)]
	pos := 0
	for pos < n {
		diff := n - len(b)
		if diff > bytesAllocLimit {
			diff = bytesAllocLimit
		}
		b = append(b, make([]byte, diff)...)

		nn, err := io.ReadFull(r, b[pos:])
		if err != nil {
			return nil, err
		}
		pos += nn
	}
	return b, nil
}
2016-07-02 15:52:10 +03:00
// formatInt renders n as a base-10 string.
func formatInt(n int64) string {
	const base = 10
	return strconv.FormatInt(n, base)
}
// formatUint renders u as a base-10 string.
func formatUint(u uint64) string {
	const base = 10
	return strconv.FormatUint(u, base)
}
// formatFloat renders f in plain decimal notation (no exponent) using the
// smallest number of digits that still round-trips as a float64.
func formatFloat(f float64) string {
	const (
		format    = 'f'
		precision = -1
		bitSize   = 64
	)
	return strconv.FormatFloat(f, format, precision, bitSize)
}
// isNilReply reports whether b is exactly a string or array marker followed
// by "-1", the form used for a nil reply.
func isNilReply(b []byte) bool {
	if len(b) != 3 {
		return false
	}
	if b[0] != StringReply && b[0] != ArrayReply {
		return false
	}
	return b[1] == '-' && b[2] == '1'
}
2016-12-13 18:28:39 +03:00
func ParseErrorReply(line []byte) error {
return internal.RedisError(string(line[1:]))
2016-07-02 15:52:10 +03:00
}
// parseStatusValue returns the status text of a status reply line: the line
// without its leading marker byte. The result shares storage with line, and
// line must not be empty.
func parseStatusValue(line []byte) []byte {
	return line[1:]
}
func parseArrayLen(line []byte) (int64, error) {
if isNilReply(line) {
return 0, internal.Nil
2016-07-02 15:52:10 +03:00
}
2017-01-13 14:39:59 +03:00
return parseInt(line[1:], 10, 64)
}
// atoi parses b as a base-10 int. The bytes are converted with
// internal.BytesToString (presumably to avoid a copy — confirm against that
// helper).
func atoi(b []byte) (int, error) {
	s := internal.BytesToString(b)
	return strconv.Atoi(s)
}
// parseInt parses b as a signed integer in the given base and bit size, with
// the same rules as strconv.ParseInt. The bytes are converted with
// internal.BytesToString (presumably to avoid a copy — confirm against that
// helper).
func parseInt(b []byte, base int, bitSize int) (int64, error) {
	s := internal.BytesToString(b)
	return strconv.ParseInt(s, base, bitSize)
}
// parseUint parses b as an unsigned integer in the given base and bit size,
// with the same rules as strconv.ParseUint. The bytes are converted with
// internal.BytesToString (presumably to avoid a copy — confirm against that
// helper).
func parseUint(b []byte, base int, bitSize int) (uint64, error) {
	s := internal.BytesToString(b)
	return strconv.ParseUint(s, base, bitSize)
}
func parseFloat(b []byte, bitSize int) (float64, error) {
return strconv.ParseFloat(internal.BytesToString(b), bitSize)
2016-07-02 15:52:10 +03:00
}