From d8057d9c4c715b904441b2cec5ca26ca7d56127c Mon Sep 17 00:00:00 2001 From: Josh Baker Date: Sun, 25 Jun 2017 21:53:51 -0700 Subject: [PATCH] added byte appending code --- append.go | 306 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ redcon.go | 73 ++++++------- 2 files changed, 337 insertions(+), 42 deletions(-) create mode 100644 append.go diff --git a/append.go b/append.go new file mode 100644 index 0000000..5a27cf9 --- /dev/null +++ b/append.go @@ -0,0 +1,306 @@ +package redcon + +import ( + "strconv" + "strings" +) + +// Kind is the kind of command +type Kind int + +const ( + // Redis is returned for Redis protocol commands + Redis Kind = iota + // Tile38 is returnd for Tile38 native protocol commands + Tile38 + // Telnet is returnd for plain telnet commands + Telnet +) + +var errInvalidMessage = &errProtocol{"invalid message"} + +// ReadNextCommand reads the next command from the provided packet. It's +// possibel that the packet contains multiple commands, or zero commands +// when the packet is incomplete. +// 'args' is an optional reusable buffer and it can be nil. +// 'argsout' are the output arguments for the command. 'kind' is the type of +// command that was read. +// 'stop' indicates that there are no more possible commands in the packet. +// 'err' is returned when a protocol error was encountered. +// 'leftover' is any remaining unused bytes which belong to the next command. +func ReadNextCommand(packet []byte, args [][]byte) ( + leftover []byte, argsout [][]byte, kind Kind, stop bool, err error, +) { + if len(packet) > 0 { + if packet[0] != '*' { + if packet[0] == '$' { + return readTile38Command(packet, args) + } + return readTelnetCommand(packet, args) + } + for i := 1; i < len(packet); i++ { + if packet[i] == '\n' { + if packet[i-1] != '\r' { + return packet, args, Redis, true, errInvalidMultiBulkLength + } + count, ok := parseInt(packet[1 : i-1]) + if !ok || count <= 0 { + return packet, args, Redis, true, errInvalidMultiBulkLength + } + i++ + nextArg: + for j := 0; j < count; j++ { + if i == len(packet) { + break + } + if packet[i] != '$' { + return packet, args, Redis, true, + &errProtocol{"expected '$', got '" + + string(packet[i]) + "'"} + } + for s := i + 1; i < len(packet); i++ { + if packet[i] == '\n' { + if packet[i-1] != '\r' { + return packet, args, Redis, true, errInvalidBulkLength + } + n, ok := parseInt(packet[s : i-1]) + if !ok || count <= 0 { + return packet, args, Redis, true, errInvalidBulkLength + } + i++ + if len(packet)-i >= n+2 { + if packet[i+n] != '\r' || packet[i+n+1] != '\n' { + return packet, args, Redis, true, errInvalidBulkLength + } + args = append(args, packet[i:i+n]) + i += n + 2 + if j == count-1 { + return packet[i:], args, Redis, i == len(packet), nil + } + continue nextArg + } + break + } + } + break + } + break + } + } + } + return packet, args, Redis, true, nil +} + +func readTile38Command(b []byte, argsbuf [][]byte) ( + leftover []byte, args [][]byte, kind Kind, stop bool, err error, +) { + for i := 1; i < len(b); i++ { + if b[i] == ' ' { + n, ok := parseInt(b[1:i]) + if !ok || n < 0 { + return b, args, Tile38, true, errInvalidMessage + } + i++ + if len(b) >= i+n+2 { + if b[i+n] != '\r' || b[i+n+1] != '\n' { + return b, args, Tile38, true, errInvalidMessage + } + line := b[i : i+n] + reading: + for len(line) != 0 { + if line[0] == '{' { + // The native protocol cannot understand json boundaries so it assumes that + // a json element must be at the end of the line. + args = append(args, line) + break + } + if line[0] == '"' && line[len(line)-1] == '"' { + if len(args) > 0 && + strings.ToLower(string(args[0])) == "set" && + strings.ToLower(string(args[len(args)-1])) == "string" { + // Setting a string value that is contained inside double quotes. + // This is only because of the boundary issues of the native protocol. + args = append(args, line[1:len(line)-1]) + break + } + } + i := 0 + for ; i < len(line); i++ { + if line[i] == ' ' { + value := line[:i] + if len(value) > 0 { + args = append(args, value) + } + line = line[i+1:] + continue reading + } + } + args = append(args, line) + break + } + return b[i+n+2:], args, Tile38, i == len(b), nil + } + break + } + } + return b, args, Tile38, true, nil +} +func readTelnetCommand(b []byte, argsbuf [][]byte) ( + leftover []byte, args [][]byte, kind Kind, stop bool, err error, +) { + // just a plain text command + for i := 0; i < len(b); i++ { + if b[i] == '\n' { + var line []byte + if i > 0 && b[i-1] == '\r' { + line = b[:i-1] + } else { + line = b[:i] + } + var quote bool + var quotech byte + var escape bool + outer: + for { + nline := make([]byte, 0, len(line)) + for i := 0; i < len(line); i++ { + c := line[i] + if !quote { + if c == ' ' { + if len(nline) > 0 { + args = append(args, nline) + } + line = line[i+1:] + continue outer + } + if c == '"' || c == '\'' { + if i != 0 { + return b, args, Telnet, true, errUnbalancedQuotes + } + quotech = c + quote = true + line = line[i+1:] + continue outer + } + } else { + if escape { + escape = false + switch c { + case 'n': + c = '\n' + case 'r': + c = '\r' + case 't': + c = '\t' + } + } else if c == quotech { + quote = false + quotech = 0 + args = append(args, nline) + line = line[i+1:] + if len(line) > 0 && line[0] != ' ' { + return b, args, Telnet, true, errUnbalancedQuotes + } + continue outer + } else if c == '\\' { + escape = true + continue + } + } + nline = append(nline, c) + } + if quote { + return b, args, Telnet, true, errUnbalancedQuotes + } + if len(line) > 0 { + args = append(args, line) + } + break + } + return b[i+1:], args, Telnet, i == len(b), nil + } + } + return b, args, Telnet, true, nil +} + +// AppendUint appends a Redis protocol uint64 to the input bytes. +func AppendUint(b []byte, n uint64) []byte { + b = append(b, ':') + b = strconv.AppendUint(b, n, 10) + return append(b, '\r', '\n') +} + +// AppendInt appends a Redis protocol int64 to the input bytes. +func AppendInt(b []byte, n int64) []byte { + b = append(b, ':') + b = strconv.AppendInt(b, n, 10) + return append(b, '\r', '\n') +} + +// AppendArray appends a Redis protocol array to the input bytes. +func AppendArray(b []byte, n int) []byte { + b = append(b, '*') + b = strconv.AppendInt(b, int64(n), 10) + return append(b, '\r', '\n') +} + +// AppendBulk appends a Redis protocol bulk byte slice to the input bytes. +func AppendBulk(b []byte, bulk []byte) []byte { + b = append(b, '$') + b = strconv.AppendInt(b, int64(len(bulk)), 10) + b = append(b, '\r', '\n') + b = append(b, bulk...) + return append(b, '\r', '\n') +} + +// AppendBulkString appends a Redis protocol bulk string to the input bytes. +func AppendBulkString(b []byte, bulk string) []byte { + b = append(b, '$') + b = strconv.AppendInt(b, int64(len(bulk)), 10) + b = append(b, '\r', '\n') + b = append(b, bulk...) + return append(b, '\r', '\n') +} + +// AppendString appends a Redis protocol string to the input bytes. +func AppendString(b []byte, s string) []byte { + b = append(b, '+') + b = append(b, stripNewlines(s)...) + return append(b, '\r', '\n') +} + +// AppendError appends a Redis protocol error to the input bytes. +func AppendError(b []byte, s string) []byte { + b = append(b, '-') + b = append(b, stripNewlines(s)...) + return append(b, '\r', '\n') +} + +// AppendOK appends a Redis protocol OK to the input bytes. +func AppendOK(b []byte) []byte { + return append(b, '+', 'O', 'K', '\r', '\n') +} +func stripNewlines(s string) string { + for i := 0; i < len(s); i++ { + if s[i] == '\r' || s[i] == '\n' { + s = strings.Replace(s, "\r", " ", -1) + s = strings.Replace(s, "\n", " ", -1) + break + } + } + return s +} + +// AppendTile38 appends a Tile38 message to the input bytes. +func AppendTile38(b []byte, data []byte) []byte { + b = append(b, '$') + b = strconv.AppendInt(b, int64(len(data)), 10) + b = append(b, ' ') + b = append(b, data...) + return append(b, '\r', '\n') +} + +// AppendNull appends a Redis protocol null to the input bytes. +func AppendNull(b []byte) []byte { + return append(b, '$', '-', '1', '\r', '\n') +} diff --git a/redcon.go b/redcon.go index ca777b4..26ba003 100644 --- a/redcon.go +++ b/redcon.go @@ -6,7 +6,6 @@ import ( "errors" "io" "net" - "strconv" "sync" ) @@ -413,7 +412,7 @@ func NewWriter(wr io.Writer) *Writer { // WriteNull writes a null to the client func (w *Writer) WriteNull() { - w.b = append(w.b, '$', '-', '1', '\r', '\n') + w.b = AppendNull(w.b) } // WriteArray writes an array header. You must then write addtional @@ -424,27 +423,17 @@ func (w *Writer) WriteNull() { // c.WriteBulk("item 1") // c.WriteBulk("item 2") func (w *Writer) WriteArray(count int) { - w.b = append(w.b, '*') - w.b = strconv.AppendInt(w.b, int64(count), 10) - w.b = append(w.b, '\r', '\n') + w.b = AppendArray(w.b, count) } // WriteBulk writes bulk bytes to the client. func (w *Writer) WriteBulk(bulk []byte) { - w.b = append(w.b, '$') - w.b = strconv.AppendInt(w.b, int64(len(bulk)), 10) - w.b = append(w.b, '\r', '\n') - w.b = append(w.b, bulk...) - w.b = append(w.b, '\r', '\n') + w.b = AppendBulk(w.b, bulk) } // WriteBulkString writes a bulk string to the client. func (w *Writer) WriteBulkString(bulk string) { - w.b = append(w.b, '$') - w.b = strconv.AppendInt(w.b, int64(len(bulk)), 10) - w.b = append(w.b, '\r', '\n') - w.b = append(w.b, bulk...) - w.b = append(w.b, '\r', '\n') + w.b = AppendBulkString(w.b, bulk) } // Buffer returns the unflushed buffer. This is a copy so changes @@ -470,16 +459,12 @@ func (w *Writer) Flush() error { // WriteError writes an error to the client. func (w *Writer) WriteError(msg string) { - w.b = append(w.b, '-') - w.b = append(w.b, msg...) - w.b = append(w.b, '\r', '\n') + w.b = AppendError(w.b, msg) } // WriteString writes a string to the client. func (w *Writer) WriteString(msg string) { - w.b = append(w.b, '+') - w.b = append(w.b, msg...) - w.b = append(w.b, '\r', '\n') + w.b = AppendString(w.b, msg) } // WriteInt writes an integer to the client. @@ -489,9 +474,7 @@ func (w *Writer) WriteInt(num int) { // WriteInt64 writes a 64-bit signed integer to the client. func (w *Writer) WriteInt64(num int64) { - w.b = append(w.b, ':') - w.b = strconv.AppendInt(w.b, num, 10) - w.b = append(w.b, '\r', '\n') + w.b = AppendInt(w.b, num) } // WriteRaw writes raw data to the client. @@ -516,21 +499,27 @@ func NewReader(rd io.Reader) *Reader { } } -func parseInt(b []byte) (int, error) { - // shortcut atoi for 0-99. fails for negative numbers. - switch len(b) { - case 1: - if b[0] >= '0' && b[0] <= '9' { - return int(b[0] - '0'), nil - } - case 2: - if b[0] >= '0' && b[0] <= '9' && b[1] >= '0' && b[1] <= '9' { - return int(b[0]-'0')*10 + int(b[1]-'0'), nil - } +func parseInt(b []byte) (int, bool) { + if len(b) == 1 && b[0] >= '0' && b[0] <= '9' { + return int(b[0] - '0'), true } - // fallback to standard library - n, err := strconv.ParseUint(string(b), 10, 64) - return int(n), err + var n int + var sign bool + var i int + if len(b) > 0 && b[0] == '-' { + sign = true + i++ + } + for ; i < len(b); i++ { + if b[i] < '0' || b[i] > '9' { + return 0, false + } + n = n*10 + int(b[i]-'0') + } + if sign { + n *= -1 + } + return n, true } func (rd *Reader) readCommands(leftover *int) ([]Command, error) { @@ -645,8 +634,8 @@ func (rd *Reader) readCommands(leftover *int) ([]Command, error) { if b[i-1] != '\r' { return nil, errInvalidMultiBulkLength } - count, err := parseInt(b[1 : i-1]) - if err != nil || count <= 0 { + count, ok := parseInt(b[1 : i-1]) + if !ok || count <= 0 { return nil, errInvalidMultiBulkLength } marks = marks[:0] @@ -664,8 +653,8 @@ func (rd *Reader) readCommands(leftover *int) ([]Command, error) { if b[i-1] != '\r' { return nil, errInvalidBulkLength } - size, err := parseInt(b[si+1 : i-1]) - if err != nil || size < 0 { + size, ok := parseInt(b[si+1 : i-1]) + if !ok || size < 0 { return nil, errInvalidBulkLength } if i+size+2 >= len(b) {