diff --git a/.travis.yml b/.travis.yml index 8687342..66435ac 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,17 @@ language: go +sudo: false -go: - - 1.1 - - 1.2 - - tip +matrix: + include: + - go: 1.4 + - go: 1.5 + - go: 1.6 + - go: tip + allow_failures: + - go: tip + +script: + - go get -t -v ./... + - diff -u <(echo -n) <(gofmt -d .) + - go vet $(go list ./... | grep -v /vendor/) + - go test -v -race ./... diff --git a/client.go b/client.go index 8242c66..bd87bb5 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "crypto/tls" + "encoding/base64" "errors" "io" "io/ioutil" @@ -98,13 +99,20 @@ func parseURL(s string) (*url.URL, error) { return nil, errMalformedURL } - u.Host = s - u.Opaque = "/" - if i := strings.Index(s, "/"); i >= 0 { - u.Host = s[:i] - u.Opaque = s[i:] + if i := strings.Index(s, "?"); i >= 0 { + u.RawQuery = s[i+1:] + s = s[:i] } + if i := strings.Index(s, "/"); i >= 0 { + u.Opaque = s[i:] + s = s[:i] + } else { + u.Opaque = "/" + } + + u.Host = s + if strings.Contains(u.Host, "@") { // Don't bother parsing user information because user information is // not allowed in websocket URIs. @@ -266,11 +274,19 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } if proxyURL != nil { + connectHeader := make(http.Header) + if user := proxyURL.User; user != nil { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) + connectHeader.Set("Proxy-Authorization", "Basic "+credential) + } + } connectReq := &http.Request{ Method: "CONNECT", URL: &url.URL{Opaque: hostPort}, Host: hostPort, - Header: make(http.Header), + Header: connectHeader, } connectReq.Write(netConn) @@ -290,12 +306,8 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } if u.Scheme == "https" { - cfg := d.TLSClientConfig - if cfg == nil { - cfg = &tls.Config{ServerName: hostNoPort} - } else if cfg.ServerName == "" { - shallowCopy := *cfg - cfg = &shallowCopy + cfg := cloneTLSConfig(d.TLSClientConfig) + if cfg.ServerName == "" { cfg.ServerName = hostNoPort } tlsConn := tls.Client(netConn, cfg) @@ -344,3 +356,32 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re netConn = nil // to avoid close in defer. return conn, resp, nil } + +// cloneTLSConfig clones all public fields except the fields +// SessionTicketsDisabled and SessionTicketKey. This avoids copying the +// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a +// config in active use. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + } +} diff --git a/client_server_test.go b/client_server_test.go index 05a7888..3f7345d 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -7,6 +7,7 @@ package websocket import ( "crypto/tls" "crypto/x509" + "encoding/base64" "io" "io/ioutil" "net" @@ -41,9 +42,16 @@ type cstServer struct { URL string } +const ( + cstPath = "/a/b" + cstRawQuery = "x=y" + cstRequestURI = cstPath + "?" + cstRawQuery +) + func newServer(t *testing.T) *cstServer { var s cstServer s.Server = httptest.NewServer(cstHandler{t}) + s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s } @@ -51,11 +59,22 @@ func newServer(t *testing.T) *cstServer { func newTLSServer(t *testing.T) *cstServer { var s cstServer s.Server = httptest.NewTLSServer(cstHandler{t}) + s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s } func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != cstPath { + t.Logf("path=%v, want %v", r.URL.Path, cstPath) + http.Error(w, "bad path", 400) + return + } + if r.URL.RawQuery != cstRawQuery { + t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) + http.Error(w, "bad path", 400) + return + } subprotos := Subprotocols(r) if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) @@ -157,6 +176,46 @@ func TestProxyDial(t *testing.T) { cstDialer.Proxy = http.ProxyFromEnvironment } +func TestProxyAuthorizationDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.URL) + surl.User = url.UserPassword("username", "password") + cstDialer.Proxy = http.ProxyURL(surl) + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + proxyAuth := r.Header.Get("Proxy-Authorization") + expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) + if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { + connect = true + w.WriteHeader(200) + return + } + + if !connect { + t.Log("connect with proxy authorization not recieved") + http.Error(w, "connect with proxy authorization not recieved", 405) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) + + cstDialer.Proxy = http.ProxyFromEnvironment +} + func TestDial(t *testing.T) { s := newServer(t) defer s.Close() @@ -188,7 +247,7 @@ func TestDialTLS(t *testing.T) { d := cstDialer d.NetDial = func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) } d.TLSClientConfig = &tls.Config{RootCAs: certs} - ws, _, err := d.Dial("wss://example.com/", nil) + ws, _, err := d.Dial("wss://example.com"+cstRequestURI, nil) if err != nil { t.Fatalf("Dial: %v", err) } diff --git a/client_test.go b/client_test.go index 07a9cb4..7d2b084 100644 --- a/client_test.go +++ b/client_test.go @@ -11,16 +11,19 @@ import ( ) var parseURLTests = []struct { - s string - u *url.URL + s string + u *url.URL + rui string }{ - {"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}}, - {"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}}, - {"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}}, - {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}}, - {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}}, - {"ss://example.com/a/b", nil}, - {"ws://webmaster@example.com/", nil}, + {"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, + {"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, + {"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"}, + {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"}, + {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"}, + {"ss://example.com/a/b", nil, ""}, + {"ws://webmaster@example.com/", nil, ""}, + {"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"}, + {"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"}, } func TestParseURL(t *testing.T) { @@ -30,14 +33,19 @@ func TestParseURL(t *testing.T) { t.Errorf("parseURL(%q) returned error %v", tt.s, err) continue } - if tt.u == nil && err == nil { - t.Errorf("parseURL(%q) did not return error", tt.s) + if tt.u == nil { + if err == nil { + t.Errorf("parseURL(%q) did not return error", tt.s) + } continue } if !reflect.DeepEqual(u, tt.u) { - t.Errorf("parseURL(%q) returned %v, want %v", tt.s, u, tt.u) + t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u) continue } + if u.RequestURI() != tt.rui { + t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui) + } } } diff --git a/conn.go b/conn.go index a7c5345..eb7b223 100644 --- a/conn.go +++ b/conn.go @@ -102,7 +102,66 @@ type CloseError struct { } func (e *CloseError) Error() string { - return "websocket: close " + strconv.Itoa(e.Code) + " " + e.Text + s := []byte("websocket: close ") + s = strconv.AppendInt(s, int64(e.Code), 10) + switch e.Code { + case CloseNormalClosure: + s = append(s, " (normal)"...) + case CloseGoingAway: + s = append(s, " (going away)"...) + case CloseProtocolError: + s = append(s, " (protocol error)"...) + case CloseUnsupportedData: + s = append(s, " (unsupported data)"...) + case CloseNoStatusReceived: + s = append(s, " (no status)"...) + case CloseAbnormalClosure: + s = append(s, " (abnormal closure)"...) + case CloseInvalidFramePayloadData: + s = append(s, " (invalid payload data)"...) + case ClosePolicyViolation: + s = append(s, " (policy violation)"...) + case CloseMessageTooBig: + s = append(s, " (message too big)"...) + case CloseMandatoryExtension: + s = append(s, " (mandatory extension missing)"...) + case CloseInternalServerErr: + s = append(s, " (internal server error)"...) + case CloseTLSHandshake: + s = append(s, " (TLS handshake error)"...) + } + if e.Text != "" { + s = append(s, ": "...) + s = append(s, e.Text...) + } + return string(s) +} + +// IsCloseError returns boolean indicating whether the error is a *CloseError +// with one of the specified codes. +func IsCloseError(err error, codes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range codes { + if e.Code == code { + return true + } + } + } + return false +} + +// IsUnexpectedCloseError returns boolean indicating whether the error is a +// *CloseError with a code not in the list of expected codes. +func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range expectedCodes { + if e.Code == code { + return false + } + } + return true + } + return false } var ( @@ -162,6 +221,7 @@ type Conn struct { writeFrameType int // type of the current frame. writeSeq int // incremented to invalidate message writers. writeDeadline time.Time + isWriting bool // for best-effort concurrent write detection // Read fields readMessageCompressed bool @@ -177,6 +237,7 @@ type Conn struct { readMaskKey [4]byte handlePong func(string) error handlePing func(string) error + readErrCount int } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -323,9 +384,6 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er // // There can be at most one open writer on a connection. NextWriter closes the // previous writer if the application has not already done so. -// -// The NextWriter method and the writers returned from the method cannot be -// accessed by more than one goroutine at a time. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { if c.writeErr != nil { return nil, c.writeErr @@ -414,9 +472,22 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { } } - // Write the buffers to the connection. + // Write the buffers to the connection with best-effort detection of + // concurrent writes. See the concurrency section in the package + // documentation for more info. + + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + // Setup for next frame. c.writePos = maxFrameHeaderSize c.writeFrameType = continuationFrame @@ -725,13 +796,15 @@ func (c *Conn) advanceFrame() (int, error) { return noFrame, err } case CloseMessage: - c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait)) + echoMessage := []byte{} closeCode := CloseNoStatusReceived closeText := "" if len(payload) >= 2 { + echoMessage = payload[:2] closeCode = int(binary.BigEndian.Uint16(payload)) closeText = string(payload[2:]) } + c.WriteControl(CloseMessage, echoMessage, time.Now().Add(writeWait)) return noFrame, &CloseError{Code: closeCode, Text: closeText} } @@ -749,8 +822,10 @@ func (c *Conn) handleProtocolError(message string) error { // There can be at most one open reader on a connection. NextReader discards // the previous message if the application has not already consumed it. // -// The NextReader method and the readers returned from the method cannot be -// accessed by more than one goroutine at a time. +// Applications must break out of the application's read loop when this method +// returns a non-nil error value. Errors returned from this method are +// permanent. Once this method returns a non-nil error, all subsequent calls to +// this method return the same error. func (c *Conn) NextReader() (int, io.Reader, error) { c.readSeq++ @@ -772,6 +847,15 @@ func (c *Conn) NextReader() (int, io.Reader, error) { return frameType, r, nil } } + + // Applications that do handle the error returned from this method spin in + // tight loop on connection failure. To help application developers detect + // this error, panic on repeated reads to the failed connection. + c.readErrCount++ + if c.readErrCount >= 1000 { + panic("repeated read on failed websocket connection") + } + return noFrame, nil, c.readErr } @@ -798,6 +882,9 @@ func (r messageReader) Read(b []byte) (int, error) { r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n]) } r.c.readRemaining -= int64(n) + if r.c.readRemaining > 0 && r.c.readErr == io.EOF { + r.c.readErr = errUnexpectedEOF + } return n, r.c.readErr } diff --git a/conn_test.go b/conn_test.go index 02f2d4b..0243c11 100644 --- a/conn_test.go +++ b/conn_test.go @@ -7,6 +7,7 @@ package websocket import ( "bufio" "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -173,6 +174,41 @@ func TestCloseBeforeFinalFrame(t *testing.T) { } } +func TestEOFWithinFrame(t *testing.T) { + const bufSize = 64 + + for n := 0; ; n++ { + var b bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024) + rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize)) + w.Close() + + if n >= b.Len() { + break + } + b.Truncate(n) + + op, r, err := rc.NextReader() + if err == errUnexpectedEOF { + continue + } + if op != BinaryMessage || err != nil { + t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) + } + } +} + func TestEOFBeforeFinalFrame(t *testing.T) { const bufSize = 512 @@ -270,3 +306,97 @@ func TestBufioReadBytes(t *testing.T) { t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m)) } } + +var closeErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestCloseError(t *testing.T) { + for _, tt := range closeErrorTests { + ok := IsCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +var unexpectedCloseErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestUnexpectedCloseErrors(t *testing.T) { + for _, tt := range unexpectedCloseErrorTests { + ok := IsUnexpectedCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +type blockingWriter struct { + c1, c2 chan struct{} +} + +func (w blockingWriter) Write(p []byte) (int, error) { + // Allow main to continue + close(w.c1) + // Wait for panic in main + <-w.c2 + return len(p), nil +} + +func TestConcurrentWritePanic(t *testing.T) { + w := blockingWriter{make(chan struct{}), make(chan struct{})} + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + go func() { + c.WriteMessage(TextMessage, []byte{}) + }() + + // wait for goroutine to block in write. + <-w.c1 + + defer func() { + close(w.c2) + if v := recover(); v != nil { + return + } + }() + + c.WriteMessage(TextMessage, []byte{}) + t.Fatal("should not get here") +} + +type failingReader struct{} + +func (r failingReader) Read(p []byte) (int, error) { + return 0, io.EOF +} + +func TestFailedConnectionReadPanic(t *testing.T) { + c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024) + + defer func() { + if v := recover(); v != nil { + return + } + }() + + for i := 0; i < 20000; i++ { + c.ReadMessage() + } + t.Fatal("should not get here") +} diff --git a/doc.go b/doc.go index 7228627..c901a7a 100644 --- a/doc.go +++ b/doc.go @@ -46,8 +46,7 @@ // method to get an io.WriteCloser, write the message to the writer and close // the writer when done. To receive a message, call the connection NextReader // method to get an io.Reader and read until io.EOF is returned. This snippet -// snippet shows how to echo messages using the NextWriter and NextReader -// methods: +// shows how to echo messages using the NextWriter and NextReader methods: // // for { // messageType, r, err := conn.NextReader() @@ -86,14 +85,32 @@ // and pong. Call the connection WriteControl, WriteMessage or NextWriter // methods to send a control message to the peer. // -// Connections handle received ping and pong messages by invoking a callback -// function set with SetPingHandler and SetPongHandler methods. These callback -// functions can be invoked from the ReadMessage method, the NextReader method -// or from a call to the data message reader returned from NextReader. +// Connections handle received close messages by sending a close message to the +// peer and returning a *CloseError from the the NextReader, ReadMessage or the +// message Read method. // -// Connections handle received close messages by returning an error from the -// ReadMessage method, the NextReader method or from a call to the data message -// reader returned from NextReader. +// Connections handle received ping and pong messages by invoking callback +// functions set with SetPingHandler and SetPongHandler methods. The callback +// functions are called from the NextReader, ReadMessage and the message Read +// methods. +// +// The default ping handler sends a pong to the peer. The application's reading +// goroutine can block for a short time while the handler writes the pong data +// to the connection. +// +// The application must read the connection to process ping, pong and close +// messages sent from the peer. If the application is not otherwise interested +// in messages from the peer, then the application should start a goroutine to +// read and discard messages from the peer. A simple example is: +// +// func readLoop(c *websocket.Conn) { +// for { +// if _, _, err := c.NextReader(); err != nil { +// c.Close() +// break +// } +// } +// } // // Concurrency // @@ -108,22 +125,6 @@ // The Close and WriteControl methods can be called concurrently with all other // methods. // -// Read is Required -// -// The application must read the connection to process ping and close messages -// sent from the peer. If the application is not otherwise interested in -// messages from the peer, then the application should start a goroutine to read -// and discard messages from the peer. A simple example is: -// -// func readLoop(c *websocket.Conn) { -// for { -// if _, _, err := c.NextReader(); err != nil { -// c.Close() -// break -// } -// } -// } -// // Origin Considerations // // Web browsers allow Javascript applications to open a WebSocket connection to @@ -141,9 +142,9 @@ // An application can allow connections from any origin by specifying a // function that always returns true: // -// var upgrader = websocket.Upgrader{ +// var upgrader = websocket.Upgrader{ // CheckOrigin: func(r *http.Request) bool { return true }, -// } +// } // // The deprecated Upgrade function does not enforce an origin policy. It's the // application's responsibility to check the Origin header before calling diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..96449ea --- /dev/null +++ b/example_test.go @@ -0,0 +1,46 @@ +// Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket_test + +import ( + "log" + "net/http" + "testing" + + "github.com/gorilla/websocket" +) + +var ( + c *websocket.Conn + req *http.Request +) + +// The websocket.IsUnexpectedCloseError function is useful for identifying +// application and protocol errors. +// +// This server application works with a client application running in the +// browser. The client application does not explicitly close the websocket. The +// only expected close message from the client has the code +// websocket.CloseGoingAway. All other other close messages are likely the +// result of an application or protocol error and are logged to aid debugging. +func ExampleIsUnexpectedCloseError() { + + for { + messageType, p, err := c.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + log.Printf("error: %v, user-agent: %v", err, req.Header.Get("User-Agent")) + } + return + } + processMesage(messageType, p) + } +} + +func processMesage(mt int, p []byte) {} + +// TestX prevents godoc from showing this entire file in the example. Remove +// this function when a second example is added. +func TestX(t *testing.T) {} diff --git a/examples/chat/conn.go b/examples/chat/conn.go index 22816f0..40fd38c 100644 --- a/examples/chat/conn.go +++ b/examples/chat/conn.go @@ -51,6 +51,9 @@ func (c *connection) readPump() { for { _, message, err := c.ws.ReadMessage() if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + log.Printf("error: %v", err) + } break } h.broadcast <- message diff --git a/examples/echo/client.go b/examples/echo/client.go index 45a0231..6578094 100644 --- a/examples/echo/client.go +++ b/examples/echo/client.go @@ -10,6 +10,8 @@ import ( "flag" "log" "net/url" + "os" + "os/signal" "time" "github.com/gorilla/websocket" @@ -21,6 +23,9 @@ func main() { flag.Parse() log.SetFlags(0) + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + u := url.URL{Scheme: "ws", Host: *addr, Path: "/echo"} log.Printf("connecting to %s", u.String()) @@ -30,13 +35,16 @@ func main() { } defer c.Close() + done := make(chan struct{}) + go func() { defer c.Close() + defer close(done) for { _, message, err := c.ReadMessage() if err != nil { log.Println("read:", err) - break + return } log.Printf("recv: %s", message) } @@ -45,11 +53,29 @@ func main() { ticker := time.NewTicker(time.Second) defer ticker.Stop() - for t := range ticker.C { - err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) - if err != nil { - log.Println("write:", err) - break + for { + select { + case t := <-ticker.C: + err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) + if err != nil { + log.Println("write:", err) + return + } + case <-interrupt: + log.Println("interrupt") + // To cleanly close a connection, a client should send a close + // frame and wait for the server to close the connection. + err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + log.Println("write close:", err) + return + } + select { + case <-done: + case <-time.After(time.Second): + } + c.Close() + return } } } diff --git a/examples/filewatch/main.go b/examples/filewatch/main.go index a2c7b85..2ac2b32 100644 --- a/examples/filewatch/main.go +++ b/examples/filewatch/main.go @@ -120,7 +120,7 @@ func serveWs(w http.ResponseWriter, r *http.Request) { } var lastMod time.Time - if n, err := strconv.ParseInt(r.FormValue("lastMod"), 16, 64); err != nil { + if n, err := strconv.ParseInt(r.FormValue("lastMod"), 16, 64); err == nil { lastMod = time.Unix(0, n) } diff --git a/server.go b/server.go index 8167e13..befa0c9 100644 --- a/server.go +++ b/server.go @@ -124,6 +124,9 @@ func (u *Upgrader) selectCompressionExtension(r *http.Request) (string, bool, er // The responseHeader is included in the response to the client's upgrade // request. Use the responseHeader to specify cookies (Set-Cookie) and the // application negotiated subprotocol (Sec-Websocket-Protocol). +// +// If the upgrade fails, then Upgrade replies to the client with an HTTP error +// response. func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { if r.Method != "GET" { return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET") @@ -291,3 +294,10 @@ func Subprotocols(r *http.Request) []string { } return protocols } + +// IsWebSocketUpgrade returns true if the client requested upgrade to the +// WebSocket protocol. +func IsWebSocketUpgrade(r *http.Request) bool { + return tokenListContainsValue(r.Header, "Connection", "upgrade") && + tokenListContainsValue(r.Header, "Upgrade", "websocket") +} diff --git a/server_test.go b/server_test.go index ead0776..0a28141 100644 --- a/server_test.go +++ b/server_test.go @@ -31,3 +31,21 @@ func TestSubprotocols(t *testing.T) { } } } + +var isWebSocketUpgradeTests = []struct { + ok bool + h http.Header +}{ + {false, http.Header{"Upgrade": {"websocket"}}}, + {false, http.Header{"Connection": {"upgrade"}}}, + {true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}}, +} + +func TestIsWebSocketUpgrade(t *testing.T) { + for _, tt := range isWebSocketUpgradeTests { + ok := IsWebSocketUpgrade(&http.Request{Header: tt.h}) + if tt.ok != ok { + t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok) + } + } +}