From 6f86c84a886a6678b32798898eb0a6047e8ed10e Mon Sep 17 00:00:00 2001 From: Jaden Weiss Date: Mon, 20 Aug 2018 08:33:17 -0400 Subject: [PATCH] refactor code --- client.go | 6 +- conn.go | 42 +++++------ server.go | 55 ++++++++------- util.go | 196 +++++++++++++++++++++++++++++---------------------- util_test.go | 20 ++++++ 5 files changed, 185 insertions(+), 134 deletions(-) diff --git a/client.go b/client.go index 41f8ed5..070e182 100644 --- a/client.go +++ b/client.go @@ -93,9 +93,7 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { hostNoPort = hostNoPort[:i] } else { switch u.Scheme { - case "wss": - hostPort += ":443" - case "https": + case "wss", "https": hostPort += ":443" default: hostPort += ":80" @@ -111,7 +109,7 @@ var DefaultDialer = &Dialer{ } // nilDialer is dialer to use when receiver is nil. -var nilDialer Dialer = *DefaultDialer +var nilDialer = *DefaultDialer // Dial creates a new client connection. Use requestHeader to specify the // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). diff --git a/conn.go b/conn.go index 5f46bf4..59a4baf 100644 --- a/conn.go +++ b/conn.go @@ -110,39 +110,37 @@ type CloseError struct { } func (e *CloseError) Error() string { - s := []byte("websocket: close ") - s = strconv.AppendInt(s, int64(e.Code), 10) + s := "websocket: close " + strconv.Itoa(e.Code) switch e.Code { case CloseNormalClosure: - s = append(s, " (normal)"...) + s += " (normal)" case CloseGoingAway: - s = append(s, " (going away)"...) + s += " (going away)" case CloseProtocolError: - s = append(s, " (protocol error)"...) + s += " (protocol error)" case CloseUnsupportedData: - s = append(s, " (unsupported data)"...) + s += " (unsupported data)" case CloseNoStatusReceived: - s = append(s, " (no status)"...) + s += " (no status)" case CloseAbnormalClosure: - s = append(s, " (abnormal closure)"...) + s += " (abnormal closure)" case CloseInvalidFramePayloadData: - s = append(s, " (invalid payload data)"...) + s += " (invalid payload data)" case ClosePolicyViolation: - s = append(s, " (policy violation)"...) + s += " (policy violation)" case CloseMessageTooBig: - s = append(s, " (message too big)"...) + s += " (message too big)" case CloseMandatoryExtension: - s = append(s, " (mandatory extension missing)"...) + s += " (mandatory extension missing)" case CloseInternalServerErr: - s = append(s, " (internal server error)"...) + s += " (internal server error)" case CloseTLSHandshake: - s = append(s, " (TLS handshake error)"...) + s += " (TLS handshake error)" } if e.Text != "" { - s = append(s, ": "...) - s = append(s, e.Text...) + s += ": " + e.Text } - return string(s) + return s } // IsCloseError returns boolean indicating whether the error is a *CloseError @@ -223,6 +221,9 @@ func isValidReceivedCloseCode(code int) bool { return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) } +// CloseHandler is a callback esed on closure. +type CloseHandler func(code int, text string) error + // The Conn type represents a WebSocket connection. type Conn struct { conn net.Conn @@ -255,7 +256,7 @@ type Conn struct { readMaskKey [4]byte handlePong func(string) error handlePing func(string) error - handleClose func(int, string) error + handleClose CloseHandler readErrCount int messageReader *messageReader // the current low-level reader @@ -267,6 +268,7 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) } +// writeHook is an io.Writer that steals the buffer that it is called with. type writeHook struct { p []byte } @@ -1041,7 +1043,7 @@ func (c *Conn) SetReadLimit(limit int64) { } // CloseHandler returns the current close handler -func (c *Conn) CloseHandler() func(code int, text string) error { +func (c *Conn) CloseHandler() CloseHandler { return c.handleClose } @@ -1059,7 +1061,7 @@ func (c *Conn) CloseHandler() func(code int, text string) error { // normal error handling. Applications should only set a close handler when the // application must perform some action before sending a close message back to // the peer. -func (c *Conn) SetCloseHandler(h func(code int, text string) error) { +func (c *Conn) SetCloseHandler(h CloseHandler) { if h == nil { h = func(code int, text string) error { message := FormatCloseMessage(code, "") diff --git a/server.go b/server.go index 4834c38..ca5cb4f 100644 --- a/server.go +++ b/server.go @@ -7,7 +7,6 @@ package websocket import ( "bufio" "errors" - "net" "net/http" "net/url" "strings" @@ -159,17 +158,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } } - var ( - netConn net.Conn - err error - ) - h, ok := w.(http.Hijacker) if !ok { return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") } var brw *bufio.ReadWriter - netConn, brw, err = h.Hijack() + netConn, brw, err := h.Hijack() if err != nil { return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } @@ -187,48 +181,55 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade c.newDecompressionReader = decompressNoContextTakeover } - p := c.writeBuf[:0] - p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) - p = append(p, computeAcceptKey(challengeKey)...) - p = append(p, "\r\n"...) + // workaround for haxe in newConnBRW + brw.Writer.Reset(netConn) + + // Clear deadlines set by HTTP server. + netConn.SetReadDeadline(time.Time{}) + // start handshake timeout + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) + } + + // write handshake + brw.WriteString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ") + brw.WriteString(computeAcceptKey(challengeKey)) + brw.WriteString("\r\n") if c.subprotocol != "" { - p = append(p, "Sec-WebSocket-Protocol: "...) - p = append(p, c.subprotocol...) - p = append(p, "\r\n"...) + brw.WriteString("Sec-WebSocket-Protocol: ") + brw.WriteString(c.subprotocol) + brw.WriteString("\r\n") } if compress { - p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + brw.WriteString("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n") } for k, vs := range responseHeader { if k == "Sec-Websocket-Protocol" { continue } for _, v := range vs { - p = append(p, k...) - p = append(p, ": "...) + brw.WriteString(k) + brw.WriteString(": ") for i := 0; i < len(v); i++ { b := v[i] if b <= 31 { // prevent response splitting. b = ' ' } - p = append(p, b) + brw.WriteByte(b) } - p = append(p, "\r\n"...) + brw.WriteString("\r\n") } } - p = append(p, "\r\n"...) + brw.WriteString("\r\n") - // Clear deadlines set by HTTP server. - netConn.SetDeadline(time.Time{}) - - if u.HandshakeTimeout > 0 { - netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) - } - if _, err = netConn.Write(p); err != nil { + // flush handshake + if err = brw.Writer.Flush(); err != nil { netConn.Close() return nil, err } + + // clear handshake write timeout if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Time{}) } diff --git a/util.go b/util.go index 385fa01..4dce884 100644 --- a/util.go +++ b/util.go @@ -32,122 +32,152 @@ func generateChallengeKey() (string, error) { } // Octet types from RFC 2616. -var octetTypes [256]byte +// +// OCTET = +// CHAR = +// CTL = +// CR = +// LF = +// SP = +// HT = +// <"> = +// CRLF = CR LF +// LWS = [CRLF] 1*( SP | HT ) +// TEXT = +// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <"> +// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT +// token = 1* +// qdtext = > -const ( - isTokenOctet = 1 << iota - isSpaceOctet -) - -func init() { - // From RFC 2616 - // - // OCTET = - // CHAR = - // CTL = - // CR = - // LF = - // SP = - // HT = - // <"> = - // CRLF = CR LF - // LWS = [CRLF] 1*( SP | HT ) - // TEXT = - // separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <"> - // | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT - // token = 1* - // qdtext = > - - for c := 0; c < 256; c++ { - var t byte - isCtl := c <= 31 || c == 127 - isChar := 0 <= c && c <= 127 - isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0 - if strings.IndexRune(" \t\r\n", rune(c)) >= 0 { - t |= isSpaceOctet - } - if isChar && !isCtl && !isSeparator { - t |= isTokenOctet - } - octetTypes[c] = t - } -} - -func skipSpace(s string) (rest string) { - i := 0 - for ; i < len(s); i++ { - if octetTypes[s[i]]&isSpaceOctet == 0 { - break +func skipSpace(s string) string { + for i := 0; i < len(s); i++ { + switch s[i] { + case ' ', '\t', '\r', '\n': + default: + return s[i:] } } - return s[i:] + return "" } func nextToken(s string) (token, rest string) { i := 0 +loop: for ; i < len(s); i++ { - if octetTypes[s[i]]&isTokenOctet == 0 { + c := s[i] + if c <= 31 || c >= 127 { // control characters & non-ascii are not token octets break } + switch c { //separators are not token octets + case ' ', '\t', '"', '(', ')', ',', '/', ':', ';', '<', + '=', '>', '?', '@', '[', ']', '\\', '{', '}': + break loop + } } return s[:i], s[i:] } +// nextTokenOrQuoted gets the next token, unescaping and unquoting quoted tokens func nextTokenOrQuoted(s string) (value string, rest string) { + // if it isnt quoted, then regular tokenization rules apply if !strings.HasPrefix(s, "\"") { return nextToken(s) } + + // trim off opening quote s = s[1:] - for i := 0; i < len(s); i++ { + + // find closing quote while counting escapes + escapes := 0 // count escapes + escaped := false // whether the next char is escaped + i := 0 +scan: + for ; i < len(s); i++ { + // skip escaped characters + if escaped { + escaped = false + continue + } + switch s[i] { case '"': - return s[:i], s[i+1:] + // closing quote + break scan case '\\': - p := make([]byte, len(s)-1) - j := copy(p, s[:i]) - escape := true - for i = i + 1; i < len(s); i++ { - b := s[i] - switch { - case escape: - escape = false - p[j] = b - j++ - case b == '\\': - escape = true - case b == '"': - return string(p[:j]), s[i+1:] - default: - p[j] = b - j++ - } - } - return "", "" + // escape sequence + escaped = true + escapes++ } } - return "", "" + + // handle unterminated quoted token + if i == len(s) { + return "", "" + } + + // split out token + value, rest = s[:i], s[i+1:] + + // handle token without escapes + if escapes == 0 { + return value, rest + } + + // unescape token + buf := make([]byte, len(value)-escapes) + j := 0 + escaped = false + for i := 0; i < len(value); i++ { + c := value[i] + + // handle escape sequence + if c == '\\' && !escaped { + escaped = true + continue + } + escaped = false + + // copy character + buf[j] = c + j++ + } + return string(buf), rest } // equalASCIIFold returns true if s is equal to t with ASCII case folding. func equalASCIIFold(s, t string) bool { for s != "" && t != "" { - sr, size := utf8.DecodeRuneInString(s) - s = s[size:] - tr, size := utf8.DecodeRuneInString(t) - t = t[size:] - if sr == tr { - continue + // get first rune from both strings + var sr, tr rune + if s[0] < utf8.RuneSelf { + sr, s = rune(s[0]), s[1:] + } else { + r, size := utf8.DecodeRuneInString(s) + sr, s = r, s[size:] } - if 'A' <= sr && sr <= 'Z' { - sr = sr + 'a' - 'A' + if t[0] < utf8.RuneSelf { + tr, t = rune(t[0]), t[1:] + } else { + r, size := utf8.DecodeRuneInString(t) + tr, t = r, t[size:] } - if 'A' <= tr && tr <= 'Z' { - tr = tr + 'a' - 'A' - } - if sr != tr { + + // compare runes + switch { + case sr == tr: + case 'A' <= sr && sr <= 'Z': + if sr+'a'-'A' != tr { + return false + } + case 'A' <= tr && tr <= 'Z': + if tr+'a'-'A' != sr { + return false + } + default: return false } } + return s == t } @@ -178,7 +208,7 @@ headers: return false } -// parseExtensiosn parses WebSocket extensions from a header. +// parseExtensions parses WebSocket extensions from a header. func parseExtensions(header http.Header) []map[string]string { // From RFC 6455: // diff --git a/util_test.go b/util_test.go index 6e15965..bd6a75d 100644 --- a/util_test.go +++ b/util_test.go @@ -10,12 +10,32 @@ import ( "testing" ) +var nextTokenTests = []struct { + input string + token, next string +}{ + {"other,websocket,more", "other", ",websocket,more"}, + {"websocket,more", "websocket", ",more"}, + {"more", "more", ""}, +} + +func TestNextToken(t *testing.T) { + for _, tt := range nextTokenTests { + token, next := nextToken(tt.input) + if token != tt.token || next != tt.next { + t.Errorf("nextToken(%q) = %q, %q, want %q, %q", tt.input, token, next, tt.token, tt.next) + } + } +} + var equalASCIIFoldTests = []struct { t, s string eq bool }{ {"WebSocket", "websocket", true}, {"websocket", "WebSocket", true}, + {"websocket", "WebRocket", false}, + {"WebRocket", "websocket", false}, {"Öyster", "öyster", false}, }