tile38/controller/server/anyreader.go

296 lines
6.9 KiB
Go

package server
import (
"bufio"
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"io"
"net/url"
"strconv"
"strings"
"github.com/tidwall/resp"
)
// TelnetIsJSON controls how telnet-style commands are tagged by
// readMultiBulkMessage: when true they are reported as Telnet connections
// with JSON output; when false they are reported exactly like RESP.
const TelnetIsJSON = false
// Type identifies a wire protocol. Message carries two of them: one for the
// kind of connection a request arrived on and one for the format of the
// reply.
type Type int

// Protocol types, numbered from zero in declaration order.
const (
	// Null is the zero value of Type.
	Null Type = iota
	// RESP marks messages read through the resp multi-bulk reader.
	RESP
	// Telnet marks telnet-style commands; see TelnetIsJSON.
	Telnet
	// Native marks messages framed as "$<size> <command>\r\n".
	Native
	// HTTP marks commands carried by an HTTP/1.1 GET or POST request.
	HTTP
	// WebSocket marks an HTTP request that was answered with a WebSocket
	// upgrade handshake.
	WebSocket
	// JSON is used in this file as an output type for JSON replies.
	JSON
)
// errRESPProtocolError is an error whose text is the fixed prefix
// "Protocol error: " followed by msg.
type errRESPProtocolError struct {
	msg string // detail appended after the prefix
}

// Error implements the error interface.
func (e errRESPProtocolError) Error() string {
	const prefix = "Protocol error: "
	return prefix + e.msg
}
// Message is one client request after protocol detection and parsing.
type Message struct {
	// Command is the first element of Values, lower-cased; it is empty
	// when there are no values.
	Command string
	// Values holds every parsed argument, the command word included.
	Values []resp.Value
	// ConnType is the protocol the request arrived on.
	ConnType Type
	// OutputType is the format the reply should be written in.
	OutputType Type
	// Auth is the value of the HTTP Authorization header, when one was
	// sent; the other readers in this file leave it empty.
	Auth string
}
// AnyReaderWriter reads client messages from a single buffered stream,
// detecting the protocol of each message from its first bytes.
type AnyReaderWriter struct {
	// rd is the buffered source every message is parsed from.
	rd *bufio.Reader
	// wr is the write side of the source, when it has one; the WebSocket
	// handshake reply is written to it.
	wr io.Writer
	// ws is set once a WebSocket handshake reply has been written.
	ws bool
}
// NewAnyReaderWriter wraps rd for message reading.
//
// A *bufio.Reader is used as is so that bytes it has already buffered are
// not lost; any other reader gets a new bufio.Reader around it. When rd also
// implements io.Writer, it becomes the write side used for the WebSocket
// handshake reply.
func NewAnyReaderWriter(rd io.Reader) *AnyReaderWriter {
	buffered, ok := rd.(*bufio.Reader)
	if !ok {
		buffered = bufio.NewReader(rd)
	}
	ar := &AnyReaderWriter{rd: buffered}
	if w, ok := rd.(io.Writer); ok {
		ar.wr = w
	}
	return ar
}
// peekcrlfline returns the first line waiting in the reader, without its
// trailing CRLF and without consuming any input. An empty first line yields
// "".
//
// The peek window grows one byte at a time so the call never waits for more
// input than is needed to see the terminator. Any error from Peek is returned
// unchanged: typically io.EOF when the stream ends before a CRLF, or
// bufio.ErrBufferFull when no CRLF fits inside the reader's buffer.
func (ar *AnyReaderWriter) peekcrlfline() (string, error) {
	// Two bytes is the smallest window that can hold a CRLF.
	for n := 2; ; n++ {
		bb, err := ar.rd.Peek(n)
		if err != nil {
			return "", err
		}
		// Peek succeeded, so len(bb) == n; only the two newest bytes need
		// checking because earlier windows were checked already.
		if bb[n-2] == '\r' && bb[n-1] == '\n' {
			return string(bb[:n-2]), nil
		}
	}
}
// readcrlfline consumes one CRLF-terminated line from the reader and returns
// it without the terminator. A carriage return that is not followed by a line
// feed is kept as part of the line.
//
// Any read error, io.EOF included, is returned with an empty string; bytes
// consumed before the error are discarded.
func (ar *AnyReaderWriter) readcrlfline() (string, error) {
	var line []byte
	for {
		chunk, err := ar.rd.ReadBytes('\r')
		if err != nil {
			return "", err
		}
		line = append(line, chunk...)

		next, err := ar.rd.ReadByte()
		if err != nil {
			return "", err
		}
		switch next {
		case '\n':
			// line ends with the '\r' of the terminator; drop it.
			return string(line[:len(line)-1]), nil
		case '\r':
			// This CR may itself start the terminator ("\r\r\n"). Push
			// it back so the next ReadBytes stops on it instead of
			// running past the CRLF.
			if err := ar.rd.UnreadByte(); err != nil {
				return "", err
			}
		default:
			line = append(line, next)
		}
	}
}
// ReadMessage reads the next message, picking the parser from the first byte
// of the input:
//
//   - '$' selects the native protocol;
//   - 'G' or 'P' selects HTTP, but only when the first line ends in
//     " HTTP/1.1";
//   - everything else goes to the multi-bulk reader, which also handles
//     telnet-style commands.
//
// The message may be nil with a nil error when the multi-bulk reader parsed
// an empty command.
func (ar *AnyReaderWriter) ReadMessage() (*Message, error) {
	// Look at the first byte without consuming it; the chosen parser
	// re-reads the message from its start.
	head, err := ar.rd.Peek(1)
	if err != nil {
		return nil, err
	}
	first := head[0]
	if first == '$' {
		return ar.readNativeMessage()
	}
	if first == 'G' || first == 'P' {
		line, err := ar.peekcrlfline()
		if err != nil {
			return nil, err
		}
		if strings.HasSuffix(line, " HTTP/1.1") {
			return ar.readHTTPMessage()
		}
		// Not an HTTP request line: fall through to the generic reader.
	}
	return ar.readMultiBulkMessage()
}
// readNativeMessage reads one native-protocol frame of the form
//
//	$<size> <payload>\r\n
//
// where <size> is the decimal byte length of <payload>. The payload is split
// into values on single spaces; once a value starts with '{' the rest of the
// payload is kept as one value, because the native protocol cannot find the
// end of a JSON element and assumes it runs to the end of the line.
//
// The frame is rejected with "invalid message", "invalid size",
// "message too big" (size above 0x1FFFFFFF) or "expecting crlf"; read errors
// are returned unchanged.
func (ar *AnyReaderWriter) readNativeMessage() (*Message, error) {
	// ReadBytes returns the delimiter too, so head is "$<size> " and is
	// never empty when err is nil.
	head, err := ar.rd.ReadBytes(' ')
	if err != nil {
		return nil, err
	}
	if head[0] != '$' {
		return nil, errors.New("invalid message")
	}
	size, err := strconv.ParseUint(string(head[1:len(head)-1]), 10, 32)
	if err != nil {
		return nil, errors.New("invalid size")
	}
	if size > 0x1FFFFFFF { // 536,870,911 bytes
		return nil, errors.New("message too big")
	}

	// Payload plus the two terminator bytes.
	frame := make([]byte, int(size)+2)
	if _, err := io.ReadFull(ar.rd, frame); err != nil {
		return nil, err
	}
	end := len(frame) - 2
	if frame[end] != '\r' || frame[end+1] != '\n' {
		return nil, errors.New("expecting crlf")
	}

	values := make([]resp.Value, 0, 16)
	for rest := frame[:end]; len(rest) > 0; {
		if rest[0] == '{' {
			// JSON takes everything up to the end of the line.
			values = append(values, resp.StringValue(string(rest)))
			break
		}
		sp := bytes.IndexByte(rest, ' ')
		if sp < 0 {
			// Last value on the line.
			values = append(values, resp.StringValue(string(rest)))
			break
		}
		values = append(values, resp.StringValue(string(rest[:sp])))
		rest = rest[sp+1:]
	}
	return &Message{Command: commandValues(values), Values: values, ConnType: Native, OutputType: JSON}, nil
}
// commandValues returns the command word of a parsed message: the first
// value, lower-cased. It returns "" when values is empty.
func commandValues(values []resp.Value) string {
	if len(values) > 0 {
		return strings.ToLower(values[0].String())
	}
	return ""
}
// readMultiBulkMessage reads one command through the resp multi-bulk reader,
// which also accepts telnet-style input.
//
// It returns nil, nil when the command has no values. A telnet command is
// reported with ConnType Telnet and JSON output only when TelnetIsJSON is
// set; otherwise every message is reported as RESP in and RESP out.
func (ar *AnyReaderWriter) readMultiBulkMessage() (*Message, error) {
	v, telnet, _, err := resp.NewReader(ar.rd).ReadMultiBulk()
	if err != nil {
		return nil, err
	}
	values := v.Array()
	if len(values) == 0 {
		return nil, nil
	}
	msg := &Message{Command: commandValues(values), Values: values, ConnType: RESP, OutputType: RESP}
	if telnet && TelnetIsJSON {
		msg.ConnType = Telnet
		msg.OutputType = JSON
	}
	return msg, nil
}
// readHTTPMessage reads one HTTP/1.1 GET or POST request and turns it into a
// Message.
//
// The command is the URL path (leading '/' removed, query-unescaped) with the
// request body appended when a Content-Length is given. That text is parsed
// by a nested ReadMessage, so it may use any protocol ReadMessage accepts.
// The Authorization header value is stored in Message.Auth.
//
// A request carrying "Upgrade: websocket", a Sec-WebSocket-Version of at
// least 13 and a Sec-WebSocket-Key is answered immediately with the
// "101 Switching Protocols" handshake, and the message is tagged WebSocket;
// no body is read in that case.
//
// When the path and body are empty, or the embedded command has no values,
// the message is returned with an empty Command rather than an error.
//
// Errors: "invalid HTTP request" for a malformed request line or path,
// "invalid HTTP method" for anything but GET and POST, "message too big" for
// a Content-Length above 0x1FFFFFFF, "connection is nil" when a WebSocket
// upgrade is requested but there is nothing to write the handshake to, and
// any read, write or number-parsing error unchanged.
func (ar *AnyReaderWriter) readHTTPMessage() (*Message, error) {
	// Largest accepted request body: the same ceiling the native protocol
	// applies to its frames. Without it a client-supplied Content-Length
	// goes straight into make and can panic or exhaust memory.
	const maxContentLength = 0x1FFFFFFF // 536,870,911 bytes

	msg := &Message{ConnType: HTTP, OutputType: JSON}

	// Request line: METHOD SP PATH SP VERSION.
	line, err := ar.readcrlfline()
	if err != nil {
		return nil, err
	}
	parts := strings.Split(line, " ")
	if len(parts) != 3 {
		return nil, errors.New("invalid HTTP request")
	}
	method, path := parts[0], parts[1]
	if len(path) == 0 || path[0] != '/' {
		return nil, errors.New("invalid HTTP request")
	}
	path, err = url.QueryUnescape(path[1:])
	if err != nil {
		return nil, errors.New("invalid HTTP request")
	}
	if method != "GET" && method != "POST" {
		return nil, errors.New("invalid HTTP method")
	}

	var (
		contentLength    int
		websocket        bool
		websocketVersion uint64 // kept unsigned so a huge value cannot wrap negative
		websocketKey     string
	)
	for {
		header, err := ar.readcrlfline()
		if err != nil {
			return nil, err
		}
		if header == "" {
			break // end of headers
		}
		colon := strings.IndexByte(header, ':')
		if colon < 0 {
			continue // not a "name: value" line; ignore it
		}
		value := strings.TrimSpace(header[colon+1:])
		switch strings.ToLower(header[:colon]) {
		case "authorization":
			msg.Auth = value
		case "upgrade":
			if strings.ToLower(value) == "websocket" {
				websocket = true
			}
		case "sec-websocket-version":
			websocketVersion, err = strconv.ParseUint(value, 10, 64)
			if err != nil {
				return nil, err
			}
		case "sec-websocket-key":
			websocketKey = value
		case "content-length":
			n, err := strconv.ParseUint(value, 10, 64)
			if err != nil {
				return nil, err
			}
			if n > maxContentLength {
				return nil, errors.New("message too big")
			}
			contentLength = int(n)
		}
	}

	if websocket && websocketVersion >= 13 && websocketKey != "" {
		msg.ConnType = WebSocket
		if ar.wr == nil {
			return nil, errors.New("connection is nil")
		}
		// Sec-WebSocket-Accept is the base64 SHA-1 of the client key
		// followed by the fixed WebSocket handshake GUID.
		sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
		accept := base64.StdEncoding.EncodeToString(sum[:])
		wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n"
		if _, err = ar.wr.Write([]byte(wshead)); err != nil {
			return nil, err
		}
		ar.ws = true
	} else if contentLength > 0 {
		// The body continues the command started in the path.
		body := make([]byte, contentLength)
		if _, err = io.ReadFull(ar.rd, body); err != nil {
			return nil, err
		}
		path += string(body)
	}

	if path == "" {
		return msg, nil
	}
	if !strings.HasSuffix(path, "\r\n") {
		path += "\r\n"
	}
	nmsg, err := NewAnyReaderWriter(bytes.NewBufferString(path)).ReadMessage()
	if err != nil {
		return nil, err
	}
	if nmsg == nil {
		// ReadMessage returns nil, nil for a command with no values
		// (for example the path "*0"). Treat it like an empty path
		// instead of dereferencing nil.
		return msg, nil
	}
	msg.OutputType = nmsg.OutputType
	msg.Values = nmsg.Values
	msg.Command = commandValues(nmsg.Values)
	return msg, nil
}