merge: master

This commit is contained in:
misu 2019-02-02 13:48:39 +09:00
commit a8d45c1b96
9 changed files with 371 additions and 154 deletions

View File

@ -27,7 +27,7 @@ package API is stable.
### Protocol Compliance
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
### Gorilla WebSocket compared with other packages

View File

@ -71,7 +71,7 @@ type Dialer struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then a useful default size is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
@ -144,7 +144,7 @@ var nilDialer = *DefaultDialer
// Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// The context will be used in the request and in the Dialer
// The context will be used in the request and in the Dialer.
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,

View File

@ -11,8 +11,10 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/cookiejar"
@ -51,15 +53,10 @@ type cstHandler struct {
cstHandlerConfig
}
var cstDialerWithoutHandshakeTimeout = Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
type cstServer struct {
*httptest.Server
URL string
t *testing.T
}
const (
@ -341,10 +338,7 @@ func TestDialCookieJar(t *testing.T) {
sendRecv(t, ws)
}
func TestDialTLS(t *testing.T) {
s := newTLSServer(t, cstHandlerConfig{})
defer s.Close()
func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
@ -355,35 +349,15 @@ func TestDialTLS(t *testing.T) {
certs.AddCert(root)
}
}
d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
return certs
}
func xTestDialTLSBadCert(t *testing.T) {
// This test is deactivated because of noisy logging from the net/http package.
s := newTLSServer(t, cstHandlerConfig{})
defer s.Close()
ws, _, err := cstDialer.Dial(s.URL, nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
}
func TestDialTLSNoVerify(t *testing.T) {
func TestDialTLS(t *testing.T) {
s := newTLSServer(t, cstHandlerConfig{})
defer s.Close()
d := cstDialer
d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
@ -468,7 +442,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
s := newServer(t, cstHandlerConfig{})
defer s.Close()
d := cstDialerWithoutHandshakeTimeout
d := cstDialer
d.HandshakeTimeout = 0
d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
netDialer := &net.Dialer{}
c, err := netDialer.DialContext(ctx, n, a)
@ -619,33 +594,195 @@ func TestRespOnBadHandshake(t *testing.T) {
}
}
// TestHostHeader confirms that the host header provided in the call to Dial is
// sent to the server.
func TestHostHeader(t *testing.T) {
s := newServer(t, cstHandlerConfig{})
defer s.Close()
type testLogWriter struct {
t *testing.T
}
specifiedHost := make(chan string, 1)
origHandler := s.Server.Config.Handler
func (w testLogWriter) Write(p []byte) (int, error) {
w.t.Logf("%s", p)
return len(p), nil
}
// Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
specifiedHost <- r.Host
origHandler.ServeHTTP(w, r)
})
// TestHost tests handling of host names and confirms that it matches net/http.
func TestHost(t *testing.T) {
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
upgrader := Upgrader{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if IsWebSocketUpgrade(r) {
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
if err != nil {
t.Fatal(err)
}
c.Close()
} else {
w.Header().Set("X-Test-Host", r.Host)
}
})
if gotHost := <-specifiedHost; gotHost != "testhost" {
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
server := httptest.NewServer(handler)
defer server.Close()
tlsServer := httptest.NewTLSServer(handler)
defer tlsServer.Close()
addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
// Avoid log noise from net/http server by logging to testing.T
server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
tlsServer.Config.ErrorLog = server.Config.ErrorLog
cas := rootCAs(t, tlsServer)
tests := []struct {
fail bool // true if dial / get should fail
server *httptest.Server // server to use
url string // host for request URI
header string // optional request host header
tls string // optiona host for tls ServerName
wantAddr string // expected host for dial
wantHeader string // expected request header on server
insecureSkipVerify bool
}{
{
server: server,
url: addrs[server],
wantAddr: addrs[server],
wantHeader: addrs[server],
},
{
server: tlsServer,
url: addrs[tlsServer],
wantAddr: addrs[tlsServer],
wantHeader: addrs[tlsServer],
},
{
server: server,
url: addrs[server],
header: "badhost.com",
wantAddr: addrs[server],
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: addrs[tlsServer],
header: "badhost.com",
wantAddr: addrs[tlsServer],
wantHeader: "badhost.com",
},
{
server: server,
url: "example.com",
header: "badhost.com",
wantAddr: "example.com:80",
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: "example.com",
header: "badhost.com",
wantAddr: "example.com:443",
wantHeader: "badhost.com",
},
{
server: server,
url: "badhost.com",
header: "example.com",
wantAddr: "badhost.com:80",
wantHeader: "example.com",
},
{
fail: true,
server: tlsServer,
url: "badhost.com",
header: "example.com",
wantAddr: "badhost.com:443",
},
{
server: tlsServer,
url: "badhost.com",
insecureSkipVerify: true,
wantAddr: "badhost.com:443",
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: "badhost.com",
tls: "example.com",
wantAddr: "badhost.com:443",
wantHeader: "badhost.com",
},
}
sendRecv(t, ws)
for i, tt := range tests {
tls := &tls.Config{
RootCAs: cas,
ServerName: tt.tls,
InsecureSkipVerify: tt.insecureSkipVerify,
}
var gotAddr string
dialer := Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
gotAddr = addr
return net.Dial(network, addrs[tt.server])
},
TLSClientConfig: tls,
}
// Test websocket dial
h := http.Header{}
if tt.header != "" {
h.Set("Host", tt.header)
}
c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
if err == nil {
c.Close()
}
check := func(protos map[*httptest.Server]string) {
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
if gotAddr != tt.wantAddr {
t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
}
switch {
case tt.fail && err == nil:
t.Errorf("%s: unexpected success", name)
case !tt.fail && err != nil:
t.Errorf("%s: unexpected error %v", name, err)
case !tt.fail && err == nil:
if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
}
}
}
check(wsProtos)
// Confirm that net/http has same result
transport := &http.Transport{
Dial: dialer.NetDial,
TLSClientConfig: dialer.TLSClientConfig,
}
req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
if tt.header != "" {
req.Host = tt.header
}
client := &http.Client{Transport: transport}
resp, err = client.Do(req)
if err == nil {
resp.Body.Close()
}
transport.CloseIdleConnections()
check(httpProtos)
}
}
func TestDialCompression(t *testing.T) {
@ -785,19 +922,8 @@ func TestTracingDialWithContext(t *testing.T) {
s := newTLSServer(t, cstHandlerConfig{})
defer s.Close()
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}
d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.DialContext(ctx, s.URL, nil)
if err != nil {
@ -835,19 +961,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
s := newTLSServer(t, cstHandlerConfig{})
defer s.Close()
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}
d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.DialContext(ctx, s.URL, nil)
if err != nil {

49
conn.go
View File

@ -453,7 +453,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return err
}
func (c *Conn) prepWrite(messageType int) error {
// beginMessage prepares a connection and message writer for a new message.
func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot
// change this without breaking existing applications.
@ -473,6 +474,10 @@ func (c *Conn) prepWrite(messageType int) error {
return err
}
mw.c = c
mw.frameType = messageType
mw.pos = maxFrameHeaderSize
if c.writeBuf == nil {
wpd, ok := c.writePool.Get().(writePoolData)
if ok {
@ -493,16 +498,11 @@ func (c *Conn) prepWrite(messageType int) error {
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
if err := c.prepWrite(messageType); err != nil {
var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return nil, err
}
mw := &messageWriter{
c: c,
frameType: messageType,
pos: maxFrameHeaderSize,
}
c.writer = mw
c.writer = &mw
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w := c.newCompressionWriter(c.writer, c.compressionLevel)
mw.compress = true
@ -519,10 +519,16 @@ type messageWriter struct {
err error
}
func (w *messageWriter) fatal(err error) error {
func (w *messageWriter) endMessage(err error) error {
if w.err != nil {
w.err = err
w.c.writer = nil
return err
}
c := w.c
w.err = err
c.writer = nil
if c.writePool != nil {
c.writePool.Put(writePoolData{buf: c.writeBuf})
c.writeBuf = nil
}
return err
}
@ -536,7 +542,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
// Check for invalid control frames.
if isControl(w.frameType) &&
(!final || length > maxControlFramePayloadSize) {
return w.fatal(errInvalidControlFrame)
return w.endMessage(errInvalidControlFrame)
}
b0 := byte(w.frameType)
@ -581,7 +587,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
if len(extra) > 0 {
return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
}
}
@ -602,15 +608,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
c.isWriting = false
if err != nil {
return w.fatal(err)
return w.endMessage(err)
}
if final {
c.writer = nil
if c.writePool != nil {
c.writePool.Put(writePoolData{buf: c.writeBuf})
c.writeBuf = nil
}
w.endMessage(errWriteClosed)
return nil
}
@ -711,7 +713,6 @@ func (w *messageWriter) Close() error {
if err := w.flushFrame(true, nil); err != nil {
return err
}
w.err = errWriteClosed
return nil
}
@ -744,10 +745,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
// Fast path with no allocations and single frame.
if err := c.prepWrite(messageType); err != nil {
var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return err
}
mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
n := copy(c.writeBuf[mw.pos:], data)
mw.pos += n
data = data[n:]
@ -1043,7 +1044,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetReadLimit sets the maximum size for a message read from the peer. If a
// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
// message exceeds the limit, the connection sends a close message to the peer
// and returns ErrReadLimit to the application.
func (c *Conn) SetReadLimit(limit int64) {

View File

@ -227,11 +227,16 @@ func (p *simpleBufferPool) Put(v interface{}) {
}
func TestWriteBufferPool(t *testing.T) {
const message = "Now is the time for all good people to come to the aid of the party."
var buf bytes.Buffer
var pool simpleBufferPool
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
rc := newTestConn(&buf, nil, false)
// Specify writeBufferSize smaller than message size to ensure that pooling
// works with fragmented messages.
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
if wc.writeBuf != nil {
t.Fatal("writeBuf not nil after create")
}
@ -249,8 +254,6 @@ func TestWriteBufferPool(t *testing.T) {
writeBufAddr := &wc.writeBuf[0]
const message = "Hello World!"
if _, err := io.WriteString(w, message); err != nil {
t.Fatalf("io.WriteString(w, message) returned %v", err)
}
@ -300,6 +303,7 @@ func TestWriteBufferPool(t *testing.T) {
}
}
// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
func TestWriteBufferPoolSync(t *testing.T) {
var buf bytes.Buffer
var pool sync.Pool
@ -321,6 +325,56 @@ func TestWriteBufferPoolSync(t *testing.T) {
}
}
// errorWriter is an io.Writer than returns an error on all writes.
type errorWriter struct{}
func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("Error!") }
// TestWriteBufferPoolError ensures that buffer is returned to pool after error
// on write.
func TestWriteBufferPoolError(t *testing.T) {
// Part 1: Test NextWriter/Write/Close
var pool simpleBufferPool
wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
w, err := wc.NextWriter(TextMessage)
if err != nil {
t.Fatalf("wc.NextWriter() returned %v", err)
}
if wc.writeBuf == nil {
t.Fatal("writeBuf is nil after NextWriter")
}
writeBufAddr := &wc.writeBuf[0]
if _, err := io.WriteString(w, "Hello"); err != nil {
t.Fatalf("io.WriteString(w, message) returned %v", err)
}
if err := w.Close(); err == nil {
t.Fatalf("w.Close() did not return error")
}
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
t.Fatal("writeBuf not returned to pool")
}
// Part 2: Test WriteMessage
wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
t.Fatalf("wc.WriteMessage did not return error")
}
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
t.Fatal("writeBuf not returned to pool")
}
}
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
const bufSize = 512

View File

@ -1,6 +1,6 @@
# Test Server
This package contains a server for the [Autobahn WebSockets Test Suite](http://autobahn.ws/testsuite).
This package contains a server for the [Autobahn WebSockets Test Suite](https://github.com/crossbario/autobahn-testsuite).
To test the server, run

View File

@ -28,7 +28,7 @@ type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent
// or received.

132
util.go
View File

@ -31,68 +31,113 @@ func generateChallengeKey() (string, error) {
return base64.StdEncoding.EncodeToString(p), nil
}
// Octet types from RFC 2616.
var octetTypes [256]byte
const (
isTokenOctet = 1 << iota
isSpaceOctet
)
func init() {
// From RFC 2616
//
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
// TEXT = <any OCTET except CTLs, but including LWS>
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
// token = 1*<any CHAR except CTLs or separators>
// qdtext = <any TEXT except <">>
for c := 0; c < 256; c++ {
var t byte
isCtl := c <= 31 || c == 127
isChar := 0 <= c && c <= 127
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
t |= isSpaceOctet
}
if isChar && !isCtl && !isSeparator {
t |= isTokenOctet
}
octetTypes[c] = t
}
// Token octets per RFC 2616.
var isTokenOctet = [256]bool{
'!': true,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'*': true,
'+': true,
'-': true,
'.': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'W': true,
'V': true,
'X': true,
'Y': true,
'Z': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'|': true,
'~': true,
}
// skipSpace returns a slice of the string s with all leading RFC 2616 linear
// whitespace removed.
func skipSpace(s string) (rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isSpaceOctet == 0 {
if b := s[i]; b != ' ' && b != '\t' {
break
}
}
return s[i:]
}
// nextToken returns the leading RFC 2616 token of s and the string following
// the token.
func nextToken(s string) (token, rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isTokenOctet == 0 {
if !isTokenOctet[s[i]] {
break
}
}
return s[:i], s[i:]
}
// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616
// and the string following the token or quoted string.
func nextTokenOrQuoted(s string) (value string, rest string) {
if !strings.HasPrefix(s, "\"") {
return nextToken(s)
@ -128,7 +173,8 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
return "", ""
}
// equalASCIIFold returns true if s is equal to t with ASCII case folding.
// equalASCIIFold returns true if s is equal to t with ASCII case folding as
// defined in RFC 4790.
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)

View File

@ -17,6 +17,7 @@ var equalASCIIFoldTests = []struct {
{"WebSocket", "websocket", true},
{"websocket", "WebSocket", true},
{"Öyster", "öyster", false},
{"WebSocket", "WetSocket", false},
}
func TestEqualASCIIFold(t *testing.T) {