refactor code

This commit is contained in:
Jaden Weiss 2018-08-20 08:33:17 -04:00
parent 3ff3320c2a
commit 6f86c84a88
No known key found for this signature in database
GPG Key ID: 47D33FABE50962E0
5 changed files with 185 additions and 134 deletions

View File

@ -93,9 +93,7 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostNoPort = hostNoPort[:i] hostNoPort = hostNoPort[:i]
} else { } else {
switch u.Scheme { switch u.Scheme {
case "wss": case "wss", "https":
hostPort += ":443"
case "https":
hostPort += ":443" hostPort += ":443"
default: default:
hostPort += ":80" hostPort += ":80"
@ -111,7 +109,7 @@ var DefaultDialer = &Dialer{
} }
// nilDialer is dialer to use when receiver is nil. // nilDialer is dialer to use when receiver is nil.
var nilDialer Dialer = *DefaultDialer var nilDialer = *DefaultDialer
// Dial creates a new client connection. Use requestHeader to specify the // Dial creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).

42
conn.go
View File

@ -110,39 +110,37 @@ type CloseError struct {
} }
func (e *CloseError) Error() string { func (e *CloseError) Error() string {
s := []byte("websocket: close ") s := "websocket: close " + strconv.Itoa(e.Code)
s = strconv.AppendInt(s, int64(e.Code), 10)
switch e.Code { switch e.Code {
case CloseNormalClosure: case CloseNormalClosure:
s = append(s, " (normal)"...) s += " (normal)"
case CloseGoingAway: case CloseGoingAway:
s = append(s, " (going away)"...) s += " (going away)"
case CloseProtocolError: case CloseProtocolError:
s = append(s, " (protocol error)"...) s += " (protocol error)"
case CloseUnsupportedData: case CloseUnsupportedData:
s = append(s, " (unsupported data)"...) s += " (unsupported data)"
case CloseNoStatusReceived: case CloseNoStatusReceived:
s = append(s, " (no status)"...) s += " (no status)"
case CloseAbnormalClosure: case CloseAbnormalClosure:
s = append(s, " (abnormal closure)"...) s += " (abnormal closure)"
case CloseInvalidFramePayloadData: case CloseInvalidFramePayloadData:
s = append(s, " (invalid payload data)"...) s += " (invalid payload data)"
case ClosePolicyViolation: case ClosePolicyViolation:
s = append(s, " (policy violation)"...) s += " (policy violation)"
case CloseMessageTooBig: case CloseMessageTooBig:
s = append(s, " (message too big)"...) s += " (message too big)"
case CloseMandatoryExtension: case CloseMandatoryExtension:
s = append(s, " (mandatory extension missing)"...) s += " (mandatory extension missing)"
case CloseInternalServerErr: case CloseInternalServerErr:
s = append(s, " (internal server error)"...) s += " (internal server error)"
case CloseTLSHandshake: case CloseTLSHandshake:
s = append(s, " (TLS handshake error)"...) s += " (TLS handshake error)"
} }
if e.Text != "" { if e.Text != "" {
s = append(s, ": "...) s += ": " + e.Text
s = append(s, e.Text...)
} }
return string(s) return s
} }
// IsCloseError returns boolean indicating whether the error is a *CloseError // IsCloseError returns boolean indicating whether the error is a *CloseError
@ -223,6 +221,9 @@ func isValidReceivedCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
} }
// CloseHandler is a callback esed on closure.
type CloseHandler func(code int, text string) error
// The Conn type represents a WebSocket connection. // The Conn type represents a WebSocket connection.
type Conn struct { type Conn struct {
conn net.Conn conn net.Conn
@ -255,7 +256,7 @@ type Conn struct {
readMaskKey [4]byte readMaskKey [4]byte
handlePong func(string) error handlePong func(string) error
handlePing func(string) error handlePing func(string) error
handleClose func(int, string) error handleClose CloseHandler
readErrCount int readErrCount int
messageReader *messageReader // the current low-level reader messageReader *messageReader // the current low-level reader
@ -267,6 +268,7 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
} }
// writeHook is an io.Writer that steals the buffer that it is called with.
type writeHook struct { type writeHook struct {
p []byte p []byte
} }
@ -1041,7 +1043,7 @@ func (c *Conn) SetReadLimit(limit int64) {
} }
// CloseHandler returns the current close handler // CloseHandler returns the current close handler
func (c *Conn) CloseHandler() func(code int, text string) error { func (c *Conn) CloseHandler() CloseHandler {
return c.handleClose return c.handleClose
} }
@ -1059,7 +1061,7 @@ func (c *Conn) CloseHandler() func(code int, text string) error {
// normal error handling. Applications should only set a close handler when the // normal error handling. Applications should only set a close handler when the
// application must perform some action before sending a close message back to // application must perform some action before sending a close message back to
// the peer. // the peer.
func (c *Conn) SetCloseHandler(h func(code int, text string) error) { func (c *Conn) SetCloseHandler(h CloseHandler) {
if h == nil { if h == nil {
h = func(code int, text string) error { h = func(code int, text string) error {
message := FormatCloseMessage(code, "") message := FormatCloseMessage(code, "")

View File

@ -7,7 +7,6 @@ package websocket
import ( import (
"bufio" "bufio"
"errors" "errors"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -159,17 +158,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
} }
var (
netConn net.Conn
err error
)
h, ok := w.(http.Hijacker) h, ok := w.(http.Hijacker)
if !ok { if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
} }
var brw *bufio.ReadWriter var brw *bufio.ReadWriter
netConn, brw, err = h.Hijack() netConn, brw, err := h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error()) return u.returnError(w, r, http.StatusInternalServerError, err.Error())
} }
@ -187,48 +181,55 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c.newDecompressionReader = decompressNoContextTakeover c.newDecompressionReader = decompressNoContextTakeover
} }
p := c.writeBuf[:0] // workaround for haxe in newConnBRW
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) brw.Writer.Reset(netConn)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...) // Clear deadlines set by HTTP server.
netConn.SetReadDeadline(time.Time{})
// start handshake timeout
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
// write handshake
brw.WriteString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")
brw.WriteString(computeAcceptKey(challengeKey))
brw.WriteString("\r\n")
if c.subprotocol != "" { if c.subprotocol != "" {
p = append(p, "Sec-WebSocket-Protocol: "...) brw.WriteString("Sec-WebSocket-Protocol: ")
p = append(p, c.subprotocol...) brw.WriteString(c.subprotocol)
p = append(p, "\r\n"...) brw.WriteString("\r\n")
} }
if compress { if compress {
p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) brw.WriteString("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n")
} }
for k, vs := range responseHeader { for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" { if k == "Sec-Websocket-Protocol" {
continue continue
} }
for _, v := range vs { for _, v := range vs {
p = append(p, k...) brw.WriteString(k)
p = append(p, ": "...) brw.WriteString(": ")
for i := 0; i < len(v); i++ { for i := 0; i < len(v); i++ {
b := v[i] b := v[i]
if b <= 31 { if b <= 31 {
// prevent response splitting. // prevent response splitting.
b = ' ' b = ' '
} }
p = append(p, b) brw.WriteByte(b)
} }
p = append(p, "\r\n"...) brw.WriteString("\r\n")
} }
} }
p = append(p, "\r\n"...) brw.WriteString("\r\n")
// Clear deadlines set by HTTP server. // flush handshake
netConn.SetDeadline(time.Time{}) if err = brw.Writer.Flush(); err != nil {
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if _, err = netConn.Write(p); err != nil {
netConn.Close() netConn.Close()
return nil, err return nil, err
} }
// clear handshake write timeout
if u.HandshakeTimeout > 0 { if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{}) netConn.SetWriteDeadline(time.Time{})
} }

196
util.go
View File

@ -32,122 +32,152 @@ func generateChallengeKey() (string, error) {
} }
// Octet types from RFC 2616. // Octet types from RFC 2616.
var octetTypes [256]byte //
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
// TEXT = <any OCTET except CTLs, but including LWS>
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
// token = 1*<any CHAR except CTLs or separators>
// qdtext = <any TEXT except <">>
const ( func skipSpace(s string) string {
isTokenOctet = 1 << iota for i := 0; i < len(s); i++ {
isSpaceOctet switch s[i] {
) case ' ', '\t', '\r', '\n':
default:
func init() { return s[i:]
// From RFC 2616
//
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
// TEXT = <any OCTET except CTLs, but including LWS>
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
// token = 1*<any CHAR except CTLs or separators>
// qdtext = <any TEXT except <">>
for c := 0; c < 256; c++ {
var t byte
isCtl := c <= 31 || c == 127
isChar := 0 <= c && c <= 127
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
t |= isSpaceOctet
}
if isChar && !isCtl && !isSeparator {
t |= isTokenOctet
}
octetTypes[c] = t
}
}
func skipSpace(s string) (rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isSpaceOctet == 0 {
break
} }
} }
return s[i:] return ""
} }
func nextToken(s string) (token, rest string) { func nextToken(s string) (token, rest string) {
i := 0 i := 0
loop:
for ; i < len(s); i++ { for ; i < len(s); i++ {
if octetTypes[s[i]]&isTokenOctet == 0 { c := s[i]
if c <= 31 || c >= 127 { // control characters & non-ascii are not token octets
break break
} }
switch c { //separators are not token octets
case ' ', '\t', '"', '(', ')', ',', '/', ':', ';', '<',
'=', '>', '?', '@', '[', ']', '\\', '{', '}':
break loop
}
} }
return s[:i], s[i:] return s[:i], s[i:]
} }
// nextTokenOrQuoted gets the next token, unescaping and unquoting quoted tokens
func nextTokenOrQuoted(s string) (value string, rest string) { func nextTokenOrQuoted(s string) (value string, rest string) {
// if it isnt quoted, then regular tokenization rules apply
if !strings.HasPrefix(s, "\"") { if !strings.HasPrefix(s, "\"") {
return nextToken(s) return nextToken(s)
} }
// trim off opening quote
s = s[1:] s = s[1:]
for i := 0; i < len(s); i++ {
// find closing quote while counting escapes
escapes := 0 // count escapes
escaped := false // whether the next char is escaped
i := 0
scan:
for ; i < len(s); i++ {
// skip escaped characters
if escaped {
escaped = false
continue
}
switch s[i] { switch s[i] {
case '"': case '"':
return s[:i], s[i+1:] // closing quote
break scan
case '\\': case '\\':
p := make([]byte, len(s)-1) // escape sequence
j := copy(p, s[:i]) escaped = true
escape := true escapes++
for i = i + 1; i < len(s); i++ {
b := s[i]
switch {
case escape:
escape = false
p[j] = b
j++
case b == '\\':
escape = true
case b == '"':
return string(p[:j]), s[i+1:]
default:
p[j] = b
j++
}
}
return "", ""
} }
} }
return "", ""
// handle unterminated quoted token
if i == len(s) {
return "", ""
}
// split out token
value, rest = s[:i], s[i+1:]
// handle token without escapes
if escapes == 0 {
return value, rest
}
// unescape token
buf := make([]byte, len(value)-escapes)
j := 0
escaped = false
for i := 0; i < len(value); i++ {
c := value[i]
// handle escape sequence
if c == '\\' && !escaped {
escaped = true
continue
}
escaped = false
// copy character
buf[j] = c
j++
}
return string(buf), rest
} }
// equalASCIIFold returns true if s is equal to t with ASCII case folding. // equalASCIIFold returns true if s is equal to t with ASCII case folding.
func equalASCIIFold(s, t string) bool { func equalASCIIFold(s, t string) bool {
for s != "" && t != "" { for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s) // get first rune from both strings
s = s[size:] var sr, tr rune
tr, size := utf8.DecodeRuneInString(t) if s[0] < utf8.RuneSelf {
t = t[size:] sr, s = rune(s[0]), s[1:]
if sr == tr { } else {
continue r, size := utf8.DecodeRuneInString(s)
sr, s = r, s[size:]
} }
if 'A' <= sr && sr <= 'Z' { if t[0] < utf8.RuneSelf {
sr = sr + 'a' - 'A' tr, t = rune(t[0]), t[1:]
} else {
r, size := utf8.DecodeRuneInString(t)
tr, t = r, t[size:]
} }
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A' // compare runes
} switch {
if sr != tr { case sr == tr:
case 'A' <= sr && sr <= 'Z':
if sr+'a'-'A' != tr {
return false
}
case 'A' <= tr && tr <= 'Z':
if tr+'a'-'A' != sr {
return false
}
default:
return false return false
} }
} }
return s == t return s == t
} }
@ -178,7 +208,7 @@ headers:
return false return false
} }
// parseExtensiosn parses WebSocket extensions from a header. // parseExtensions parses WebSocket extensions from a header.
func parseExtensions(header http.Header) []map[string]string { func parseExtensions(header http.Header) []map[string]string {
// From RFC 6455: // From RFC 6455:
// //

View File

@ -10,12 +10,32 @@ import (
"testing" "testing"
) )
var nextTokenTests = []struct {
input string
token, next string
}{
{"other,websocket,more", "other", ",websocket,more"},
{"websocket,more", "websocket", ",more"},
{"more", "more", ""},
}
func TestNextToken(t *testing.T) {
for _, tt := range nextTokenTests {
token, next := nextToken(tt.input)
if token != tt.token || next != tt.next {
t.Errorf("nextToken(%q) = %q, %q, want %q, %q", tt.input, token, next, tt.token, tt.next)
}
}
}
var equalASCIIFoldTests = []struct { var equalASCIIFoldTests = []struct {
t, s string t, s string
eq bool eq bool
}{ }{
{"WebSocket", "websocket", true}, {"WebSocket", "websocket", true},
{"websocket", "WebSocket", true}, {"websocket", "WebSocket", true},
{"websocket", "WebRocket", false},
{"WebRocket", "websocket", false},
{"Öyster", "öyster", false}, {"Öyster", "öyster", false},
} }