ledisdb/cmd/vendor/github.com/siddontang/goredis/resp.go

457 lines
8.9 KiB
Go
Raw Normal View History

2015-03-23 14:50:37 +03:00
package goredis
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"strconv"
)
// Error is the text of a RESP error reply (a line starting with '-').
type Error string

// Error returns the reply text, so that Error satisfies the built-in
// error interface.
func (e Error) Error() string {
	return string(e)
}
// okReply and pongReply are pre-boxed status replies. Parse hands these
// out for "+OK" and "+PONG" lines instead of building a new string.
var okReply interface{} = "OK"
var pongReply interface{} = "PONG"
// RespReader decodes RESP messages read from a buffered reader.
type RespReader struct {
// br is the buffered source that Parse, ParseRequest and ParseBulkTo
// read from.
br *bufio.Reader
}
// NewRespReader returns a RespReader that decodes messages from br.
func NewRespReader(br *bufio.Reader) *RespReader {
	return &RespReader{br: br}
}
// Parse reads one complete RESP value and returns it as a Go value:
//
//	'+' simple string -> string
//	'-' error         -> Error (as the value; the returned error is nil)
//	':' integer       -> int64
//	'$' bulk string   -> []byte (nil, nil for a null bulk)
//	'*' array         -> []interface{} (nil, nil for a null array)
//
// Array elements are parsed recursively. Any other leading byte, a
// malformed line, or a read failure yields a non-nil error.
func (resp *RespReader) Parse() (interface{}, error) {
	header, err := readLine(resp.br)
	if err != nil {
		return nil, err
	}
	if len(header) == 0 {
		return nil, errors.New("short resp line")
	}

	payload := header[1:]
	switch header[0] {
	case '+':
		// Hand out the shared replies for the two most common status
		// lines so they cost no allocation.
		switch string(payload) {
		case "OK":
			return okReply, nil
		case "PONG":
			return pongReply, nil
		}
		return string(payload), nil

	case '-':
		return Error(string(payload)), nil

	case ':':
		value, perr := parseInt(payload)
		return value, perr

	case '$':
		size, err := parseLen(payload)
		if err != nil || size < 0 {
			// size < 0 with a nil error is the null bulk string.
			return nil, err
		}
		body := make([]byte, size)
		if _, err = io.ReadFull(resp.br, body); err != nil {
			return nil, err
		}
		// The payload must be followed by a bare CRLF.
		trailer, err := readLine(resp.br)
		if err != nil {
			return nil, err
		}
		if len(trailer) != 0 {
			return nil, errors.New("bad bulk string format")
		}
		return body, nil

	case '*':
		count, err := parseLen(payload)
		if err != nil || count < 0 {
			// count < 0 with a nil error is the null array.
			return nil, err
		}
		items := make([]interface{}, 0, count)
		for len(items) < count {
			item, err := resp.Parse()
			if err != nil {
				return nil, err
			}
			items = append(items, item)
		}
		return items, nil

	default:
		return nil, errors.New("unexpected response line")
	}
}
// ParseRequest reads one client -> server command, which must be a RESP
// array whose elements are all bulk strings, and returns the elements in
// order. A null array header yields nil, nil. Any other leading byte is
// rejected with an error.
func (resp *RespReader) ParseRequest() ([][]byte, error) {
	header, err := readLine(resp.br)
	if err != nil {
		return nil, err
	}
	if len(header) == 0 {
		return nil, errors.New("short resp line")
	}
	if header[0] != '*' {
		return nil, fmt.Errorf("not invalid array of bulk string type, but %c", header[0])
	}

	argc, err := parseLen(header[1:])
	if err != nil || argc < 0 {
		return nil, err
	}
	args := make([][]byte, argc)
	for i := 0; i < argc; i++ {
		if args[i], err = parseBulk(resp.br); err != nil {
			return nil, err
		}
	}
	return args, nil
}
// ParseBulkTo reads one reply that is expected to be a bulk string and
// streams its payload into w instead of buffering it.
//
// An error reply ('-') is returned as an Error. A null bulk string
// writes nothing and returns nil. Any other reply type is rejected.
func (resp *RespReader) ParseBulkTo(w io.Writer) error {
	header, err := readLine(resp.br)
	if err != nil {
		return err
	}
	if len(header) == 0 {
		return errors.New("ledis: short response line")
	}

	kind := header[0]
	if kind == '-' {
		return Error(string(header[1:]))
	}
	if kind != '$' {
		return fmt.Errorf("not invalid bulk string type, but %c", kind)
	}

	size, err := parseLen(header[1:])
	if err != nil || size < 0 {
		return err
	}

	want := int64(size)
	copied, err := io.CopyN(w, resp.br, want)
	if err != nil {
		return err
	}
	if copied != want {
		return io.ErrShortWrite
	}

	// The payload must be followed by a bare CRLF.
	trailer, err := readLine(resp.br)
	if err != nil {
		return err
	}
	if len(trailer) != 0 {
		return errors.New("bad bulk string format")
	}
	return nil
}
// readLine reads up to and including the next '\n' and returns the line
// with its trailing "\r\n" removed. The returned slice aliases br's
// internal buffer and is only valid until the next read on br.
//
// It fails if the line does not fit in br's buffer, if the read itself
// fails, or if the line is not terminated by "\r\n".
func readLine(br *bufio.Reader) ([]byte, error) {
	raw, err := br.ReadSlice('\n')
	switch {
	case err == bufio.ErrBufferFull:
		return nil, errors.New("long resp line")
	case err != nil:
		return nil, err
	}

	end := len(raw) - 2
	if end < 0 || raw[end] != '\r' {
		return nil, errors.New("bad resp line terminator")
	}
	return raw[:end], nil
}
// parseLen parses the decimal length that follows a '$' (bulk string) or
// '*' (array) type byte.
//
// It returns -1 with a nil error for the literal "-1", which marks the
// null replies "$-1" and "*-1". It returns -1 with a non-nil error when p
// is empty, contains a non-digit byte, or encodes a value that does not
// fit in an int.
func parseLen(p []byte) (int, error) {
	if len(p) == 0 {
		return -1, errors.New("malformed length")
	}
	if len(p) == 2 && p[0] == '-' && p[1] == '1' {
		// Null bulk string / null array.
		return -1, nil
	}

	const maxInt = int(^uint(0) >> 1)

	var n int
	for _, b := range p {
		if b < '0' || b > '9' {
			return -1, errors.New("illegal bytes in length")
		}
		d := int(b - '0')
		// Reject before multiplying: a wrapped length could turn
		// negative (and be mistaken for a null reply by callers) or
		// become an arbitrary bogus size.
		if n > (maxInt-d)/10 {
			return -1, errors.New("length out of range")
		}
		n = n*10 + d
	}
	return n, nil
}
// parseInt parses the payload of a ':' integer reply: an optional leading
// '-' followed by one or more decimal digits.
//
// It returns an error when p is empty, consists of only '-', contains a
// non-digit byte, or encodes a value outside the int64 range.
func parseInt(p []byte) (int64, error) {
	if len(p) == 0 {
		return 0, errors.New("malformed integer")
	}

	negate := p[0] == '-'
	if negate {
		p = p[1:]
		if len(p) == 0 {
			return 0, errors.New("malformed integer")
		}
	}

	// Accumulate the magnitude unsigned so the full int64 range, including
	// math.MinInt64, is representable. The largest allowed magnitude is
	// 2^63-1 for positive values and 2^63 for negative ones.
	limit := uint64(1<<63 - 1)
	if negate {
		limit++
	}

	var n uint64
	for _, b := range p {
		if b < '0' || b > '9' {
			return 0, errors.New("illegal bytes in length")
		}
		d := uint64(b - '0')
		// Reject before multiplying so the value never wraps silently.
		if n > (limit-d)/10 {
			return 0, errors.New("integer out of range")
		}
		n = n*10 + d
	}

	if negate {
		// For n == 2^63 the conversion wraps to MinInt64, whose negation
		// is itself, which is exactly the intended result.
		return -int64(n), nil
	}
	return int64(n), nil
}
// parseBulk reads one bulk string ("$<len>\r\n<payload>\r\n") from br and
// returns a freshly allocated copy of its payload. A null bulk string
// yields nil, nil. Any reply type other than '$' is rejected.
func parseBulk(br *bufio.Reader) ([]byte, error) {
	header, err := readLine(br)
	if err != nil {
		return nil, err
	}
	if len(header) == 0 {
		return nil, errors.New("short resp line")
	}
	if header[0] != '$' {
		return nil, fmt.Errorf("not invalid bulk string type, but %c", header[0])
	}

	size, err := parseLen(header[1:])
	if err != nil || size < 0 {
		return nil, err
	}

	body := make([]byte, size)
	if _, err = io.ReadFull(br, body); err != nil {
		return nil, err
	}

	// The payload must be followed by a bare CRLF.
	trailer, err := readLine(br)
	if err != nil {
		return nil, err
	}
	if len(trailer) != 0 {
		return nil, errors.New("bad bulk string format")
	}
	return body, nil
}
// intBuffer holds pre-rendered decimal text for small integers, indexed
// by value; it is populated at package initialisation.
var intBuffer [][]byte

// respTerm is the CRLF that terminates every RESP line.
var respTerm = []byte("\r\n")

// nullBulk and nullArray are the length fields written for a null bulk
// string and a null array respectively.
var nullBulk = []byte("-1")
var nullArray = []byte("-1")
// init fills intBuffer with the decimal text of 0..9999 so that writing
// small integers needs no formatting at run time.
func init() {
	const cached = 10000
	intBuffer = make([][]byte, cached)
	for i := range intBuffer {
		intBuffer[i] = []byte(strconv.Itoa(i))
	}
}
// RespWriter encodes RESP messages into a buffered writer.
type RespWriter struct {
// bw is the buffered destination that every write method appends to.
bw *bufio.Writer
// Scratch space for formatting integers and floats.
numScratch [40]byte
}
// NewRespWriter returns a RespWriter that encodes messages into bw.
func NewRespWriter(bw *bufio.Writer) *RespWriter {
	return &RespWriter{bw: bw}
}
// Flush pushes any buffered output to the underlying writer and reports
// the result.
func (resp *RespWriter) Flush() error {
	err := resp.bw.Flush()
	return err
}
// writeTerm appends the CRLF line terminator to the buffer.
func (resp *RespWriter) writeTerm() error {
	if _, err := resp.bw.Write(respTerm); err != nil {
		return err
	}
	return nil
}
// writeInteger appends the decimal text of n to the buffer, without any
// type byte or terminator.
//
// Values in [0, len(intBuffer)) are served from the pre-rendered cache.
// Everything else, including all negative values, is formatted on the
// fly. The lower bound matters: indexing intBuffer with a negative n
// would panic.
func (resp *RespWriter) writeInteger(n int64) error {
	if n >= 0 && n < int64(len(intBuffer)) {
		_, err := resp.bw.Write(intBuffer[n])
		return err
	}
	_, err := resp.bw.Write(strconv.AppendInt(nil, n, 10))
	return err
}
// WriteInteger appends n as a RESP integer reply (":<n>\r\n") to the
// buffer. It does not flush.
func (resp *RespWriter) WriteInteger(n int64) error {
	// Intermediate errors are deliberately dropped: bufio.Writer keeps
	// the first write error and returns it from every later call, so the
	// final writeTerm reports it.
	_ = resp.bw.WriteByte(':')
	_ = resp.writeInteger(n)
	return resp.writeTerm()
}
// FlushInteger appends n as a RESP integer reply and then flushes the
// buffer, returning the result of the flush.
func (resp *RespWriter) FlushInteger(n int64) error {
	// The write result is not inspected here; only Flush's result is
	// returned.
	_ = resp.WriteInteger(n)
	return resp.Flush()
}
// WriteString appends s as a RESP simple string ("+<s>\r\n") to the
// buffer. s is written verbatim. It does not flush.
func (resp *RespWriter) WriteString(s string) error {
	// Intermediate errors are deliberately dropped: bufio.Writer keeps
	// the first write error and returns it from every later call, so the
	// final writeTerm reports it.
	_ = resp.bw.WriteByte('+')
	_, _ = resp.bw.WriteString(s)
	return resp.writeTerm()
}
// FlushString appends s as a RESP simple string and then flushes the
// buffer, returning the result of the flush.
func (resp *RespWriter) FlushString(s string) error {
	// The write result is not inspected here; only Flush's result is
	// returned.
	_ = resp.WriteString(s)
	return resp.Flush()
}
// WriteError appends e as a RESP error reply ("-<message>\r\n") to the
// buffer. A nil e is reported to the peer with a fixed placeholder
// message. It does not flush.
func (resp *RespWriter) WriteError(e error) error {
	// Intermediate errors are deliberately dropped: bufio.Writer keeps
	// the first write error and returns it from every later call, so the
	// final writeTerm reports it.
	_ = resp.bw.WriteByte('-')

	msg := "error is nil, invalid"
	if e != nil {
		msg = e.Error()
	}
	_, _ = resp.bw.WriteString(msg)

	return resp.writeTerm()
}
// FlushError appends e as a RESP error reply and then flushes the
// buffer, returning the result of the flush.
func (resp *RespWriter) FlushError(e error) error {
	// The write result is not inspected here; only Flush's result is
	// returned.
	_ = resp.WriteError(e)
	return resp.Flush()
}
// WriteBulk appends b as a RESP bulk string ("$<len>\r\n<b>\r\n") to the
// buffer. A nil b is written as the null bulk string ("$-1\r\n"); an
// empty non-nil b is written as a zero-length bulk string. It does not
// flush.
func (resp *RespWriter) WriteBulk(b []byte) error {
	// Intermediate errors are deliberately dropped: bufio.Writer keeps
	// the first write error and returns it from every later call, so the
	// final writeTerm reports it.
	_ = resp.bw.WriteByte('$')

	if b == nil {
		_, _ = resp.bw.Write(nullBulk)
		return resp.writeTerm()
	}

	_ = resp.writeInteger(int64(len(b)))
	_ = resp.writeTerm()
	_, _ = resp.bw.Write(b)
	return resp.writeTerm()
}
// FlushBulk appends b as a RESP bulk string and then flushes the buffer,
// returning the result of the flush.
func (resp *RespWriter) FlushBulk(b []byte) error {
	// The write result is not inspected here; only Flush's result is
	// returned.
	_ = resp.WriteBulk(b)
	return resp.Flush()
}
// WriteArray appends ay as a RESP array to the buffer. A nil ay is
// written as the null array ("*-1\r\n"). It does not flush.
//
// Each element is encoded according to its dynamic type:
//
//	[]interface{} -> nested array
//	[]byte        -> bulk string
//	nil           -> null bulk string
//	int64         -> integer
//	string        -> simple string
//	error         -> error reply
//
// Encoding stops at the first element that fails or has any other type.
// The array header and the preceding elements are already in the buffer
// at that point.
func (resp *RespWriter) WriteArray(ay []interface{}) error {
	// Header write errors are deliberately dropped: bufio.Writer keeps
	// the first write error and returns it from every later write.
	_ = resp.bw.WriteByte('*')

	if ay == nil {
		_, _ = resp.bw.Write(nullArray)
		return resp.writeTerm()
	}

	_ = resp.writeInteger(int64(len(ay)))
	_ = resp.writeTerm()

	for _, elem := range ay {
		var err error
		switch v := elem.(type) {
		case []interface{}:
			err = resp.WriteArray(v)
		case []byte:
			err = resp.WriteBulk(v)
		case nil:
			err = resp.WriteBulk(nil)
		case int64:
			err = resp.WriteInteger(v)
		case string:
			err = resp.WriteString(v)
		case error:
			err = resp.WriteError(v)
		default:
			err = fmt.Errorf("invalid array type %T %v", elem, v)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// FlushArray appends ay as a RESP array and then flushes the buffer.
//
// If encoding fails, that error is returned and the buffer is not
// flushed, matching WriteCommand. This matters for elements of an
// unsupported type: that failure is produced by WriteArray itself rather
// than by the buffered writer, so Flush alone would report success after
// sending a truncated array.
func (resp *RespWriter) FlushArray(ay []interface{}) error {
	if err := resp.WriteArray(ay); err != nil {
		return err
	}
	return resp.Flush()
}
// writeBulkString appends s as a RESP bulk string
// ("$<len>\r\n<s>\r\n") to the buffer without converting it to a byte
// slice first. An empty s is written as a zero-length bulk string.
func (resp *RespWriter) writeBulkString(s string) error {
	// Intermediate errors are deliberately dropped: bufio.Writer keeps
	// the first write error and returns it from every later call, so the
	// final writeTerm reports it.
	_ = resp.bw.WriteByte('$')
	_ = resp.writeInteger(int64(len(s)))
	_ = resp.writeTerm()
	_, _ = resp.bw.WriteString(s)
	return resp.writeTerm()
}
// writeBulkInt64 appends the decimal text of n as a bulk string. The
// digits are formatted into the writer's scratch array.
func (resp *RespWriter) writeBulkInt64(n int64) error {
	digits := strconv.AppendInt(resp.numScratch[:0], n, 10)
	return resp.WriteBulk(digits)
}
// writeBulkFloat64 appends n as a bulk string, formatted with the 'g'
// verb at the shortest precision that round-trips a float64. The text is
// formatted into the writer's scratch array.
func (resp *RespWriter) writeBulkFloat64(n float64) error {
	text := strconv.AppendFloat(resp.numScratch[:0], n, 'g', -1, 64)
	return resp.WriteBulk(text)
}
// WriteCommand encodes cmd and args as a RESP array of bulk strings and
// flushes it.
//
// Arguments are rendered by dynamic type: string and []byte are sent
// as-is, int, int64 and float64 as decimal text, bool as "1" or "0", and
// nil as an empty string. Anything else is formatted with fmt.Fprint.
//
// The first encoding error is returned without flushing.
func (resp *RespWriter) WriteCommand(cmd string, args ...interface{}) error {
	// Header write errors are deliberately dropped: bufio.Writer keeps
	// the first write error and returns it from every later write.
	_ = resp.bw.WriteByte('*')
	_ = resp.writeInteger(int64(1 + len(args)))
	_ = resp.writeTerm()

	if err := resp.writeBulkString(cmd); err != nil {
		return err
	}

	for _, arg := range args {
		var err error
		switch v := arg.(type) {
		case string:
			err = resp.writeBulkString(v)
		case []byte:
			err = resp.WriteBulk(v)
		case int:
			err = resp.writeBulkInt64(int64(v))
		case int64:
			err = resp.writeBulkInt64(v)
		case float64:
			err = resp.writeBulkFloat64(v)
		case bool:
			flag := "0"
			if v {
				flag = "1"
			}
			err = resp.writeBulkString(flag)
		case nil:
			err = resp.writeBulkString("")
		default:
			// Keep the bytes.Buffer: when the value renders to nothing
			// its Bytes() is nil, which WriteBulk encodes as a null bulk.
			var buf bytes.Buffer
			fmt.Fprint(&buf, arg)
			err = resp.WriteBulk(buf.Bytes())
		}
		if err != nil {
			return err
		}
	}

	return resp.Flush()
}