mirror of https://github.com/gorilla/websocket.git
Merge remote-tracking branch 'gorilla/master'
This commit is contained in:
commit
7de010c67d
19
.travis.yml
19
.travis.yml
|
@ -1,6 +1,17 @@
|
||||||
language: go
|
language: go
|
||||||
|
sudo: false
|
||||||
|
|
||||||
go:
|
matrix:
|
||||||
- 1.1
|
include:
|
||||||
- 1.2
|
- go: 1.4
|
||||||
- tip
|
- go: 1.5
|
||||||
|
- go: 1.6
|
||||||
|
- go: tip
|
||||||
|
allow_failures:
|
||||||
|
- go: tip
|
||||||
|
|
||||||
|
script:
|
||||||
|
- go get -t -v ./...
|
||||||
|
- diff -u <(echo -n) <(gofmt -d .)
|
||||||
|
- go vet $(go list ./... | grep -v /vendor/)
|
||||||
|
- go test -v -race ./...
|
||||||
|
|
65
client.go
65
client.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -98,13 +99,20 @@ func parseURL(s string) (*url.URL, error) {
|
||||||
return nil, errMalformedURL
|
return nil, errMalformedURL
|
||||||
}
|
}
|
||||||
|
|
||||||
u.Host = s
|
if i := strings.Index(s, "?"); i >= 0 {
|
||||||
u.Opaque = "/"
|
u.RawQuery = s[i+1:]
|
||||||
if i := strings.Index(s, "/"); i >= 0 {
|
s = s[:i]
|
||||||
u.Host = s[:i]
|
|
||||||
u.Opaque = s[i:]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if i := strings.Index(s, "/"); i >= 0 {
|
||||||
|
u.Opaque = s[i:]
|
||||||
|
s = s[:i]
|
||||||
|
} else {
|
||||||
|
u.Opaque = "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Host = s
|
||||||
|
|
||||||
if strings.Contains(u.Host, "@") {
|
if strings.Contains(u.Host, "@") {
|
||||||
// Don't bother parsing user information because user information is
|
// Don't bother parsing user information because user information is
|
||||||
// not allowed in websocket URIs.
|
// not allowed in websocket URIs.
|
||||||
|
@ -266,11 +274,19 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
}
|
}
|
||||||
|
|
||||||
if proxyURL != nil {
|
if proxyURL != nil {
|
||||||
|
connectHeader := make(http.Header)
|
||||||
|
if user := proxyURL.User; user != nil {
|
||||||
|
proxyUser := user.Username()
|
||||||
|
if proxyPassword, passwordSet := user.Password(); passwordSet {
|
||||||
|
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
|
||||||
|
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
|
||||||
|
}
|
||||||
|
}
|
||||||
connectReq := &http.Request{
|
connectReq := &http.Request{
|
||||||
Method: "CONNECT",
|
Method: "CONNECT",
|
||||||
URL: &url.URL{Opaque: hostPort},
|
URL: &url.URL{Opaque: hostPort},
|
||||||
Host: hostPort,
|
Host: hostPort,
|
||||||
Header: make(http.Header),
|
Header: connectHeader,
|
||||||
}
|
}
|
||||||
|
|
||||||
connectReq.Write(netConn)
|
connectReq.Write(netConn)
|
||||||
|
@ -290,12 +306,8 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.Scheme == "https" {
|
if u.Scheme == "https" {
|
||||||
cfg := d.TLSClientConfig
|
cfg := cloneTLSConfig(d.TLSClientConfig)
|
||||||
if cfg == nil {
|
if cfg.ServerName == "" {
|
||||||
cfg = &tls.Config{ServerName: hostNoPort}
|
|
||||||
} else if cfg.ServerName == "" {
|
|
||||||
shallowCopy := *cfg
|
|
||||||
cfg = &shallowCopy
|
|
||||||
cfg.ServerName = hostNoPort
|
cfg.ServerName = hostNoPort
|
||||||
}
|
}
|
||||||
tlsConn := tls.Client(netConn, cfg)
|
tlsConn := tls.Client(netConn, cfg)
|
||||||
|
@ -344,3 +356,32 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
netConn = nil // to avoid close in defer.
|
netConn = nil // to avoid close in defer.
|
||||||
return conn, resp, nil
|
return conn, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cloneTLSConfig clones all public fields except the fields
|
||||||
|
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
|
||||||
|
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
|
||||||
|
// config in active use.
|
||||||
|
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||||
|
if cfg == nil {
|
||||||
|
return &tls.Config{}
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
Rand: cfg.Rand,
|
||||||
|
Time: cfg.Time,
|
||||||
|
Certificates: cfg.Certificates,
|
||||||
|
NameToCertificate: cfg.NameToCertificate,
|
||||||
|
GetCertificate: cfg.GetCertificate,
|
||||||
|
RootCAs: cfg.RootCAs,
|
||||||
|
NextProtos: cfg.NextProtos,
|
||||||
|
ServerName: cfg.ServerName,
|
||||||
|
ClientAuth: cfg.ClientAuth,
|
||||||
|
ClientCAs: cfg.ClientCAs,
|
||||||
|
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||||
|
CipherSuites: cfg.CipherSuites,
|
||||||
|
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
||||||
|
ClientSessionCache: cfg.ClientSessionCache,
|
||||||
|
MinVersion: cfg.MinVersion,
|
||||||
|
MaxVersion: cfg.MaxVersion,
|
||||||
|
CurvePreferences: cfg.CurvePreferences,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
|
@ -41,9 +42,16 @@ type cstServer struct {
|
||||||
URL string
|
URL string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
cstPath = "/a/b"
|
||||||
|
cstRawQuery = "x=y"
|
||||||
|
cstRequestURI = cstPath + "?" + cstRawQuery
|
||||||
|
)
|
||||||
|
|
||||||
func newServer(t *testing.T) *cstServer {
|
func newServer(t *testing.T) *cstServer {
|
||||||
var s cstServer
|
var s cstServer
|
||||||
s.Server = httptest.NewServer(cstHandler{t})
|
s.Server = httptest.NewServer(cstHandler{t})
|
||||||
|
s.Server.URL += cstRequestURI
|
||||||
s.URL = makeWsProto(s.Server.URL)
|
s.URL = makeWsProto(s.Server.URL)
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
@ -51,11 +59,22 @@ func newServer(t *testing.T) *cstServer {
|
||||||
func newTLSServer(t *testing.T) *cstServer {
|
func newTLSServer(t *testing.T) *cstServer {
|
||||||
var s cstServer
|
var s cstServer
|
||||||
s.Server = httptest.NewTLSServer(cstHandler{t})
|
s.Server = httptest.NewTLSServer(cstHandler{t})
|
||||||
|
s.Server.URL += cstRequestURI
|
||||||
s.URL = makeWsProto(s.Server.URL)
|
s.URL = makeWsProto(s.Server.URL)
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != cstPath {
|
||||||
|
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
|
||||||
|
http.Error(w, "bad path", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.URL.RawQuery != cstRawQuery {
|
||||||
|
t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
|
||||||
|
http.Error(w, "bad path", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
subprotos := Subprotocols(r)
|
subprotos := Subprotocols(r)
|
||||||
if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
|
if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
|
||||||
t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
|
t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
|
||||||
|
@ -157,6 +176,46 @@ func TestProxyDial(t *testing.T) {
|
||||||
cstDialer.Proxy = http.ProxyFromEnvironment
|
cstDialer.Proxy = http.ProxyFromEnvironment
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyAuthorizationDial(t *testing.T) {
|
||||||
|
s := newServer(t)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
surl, _ := url.Parse(s.URL)
|
||||||
|
surl.User = url.UserPassword("username", "password")
|
||||||
|
cstDialer.Proxy = http.ProxyURL(surl)
|
||||||
|
|
||||||
|
connect := false
|
||||||
|
origHandler := s.Server.Config.Handler
|
||||||
|
|
||||||
|
// Capture the request Host header.
|
||||||
|
s.Server.Config.Handler = http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
proxyAuth := r.Header.Get("Proxy-Authorization")
|
||||||
|
expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
|
||||||
|
if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
|
||||||
|
connect = true
|
||||||
|
w.WriteHeader(200)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !connect {
|
||||||
|
t.Log("connect with proxy authorization not recieved")
|
||||||
|
http.Error(w, "connect with proxy authorization not recieved", 405)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
origHandler.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
ws, _, err := cstDialer.Dial(s.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial: %v", err)
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
sendRecv(t, ws)
|
||||||
|
|
||||||
|
cstDialer.Proxy = http.ProxyFromEnvironment
|
||||||
|
}
|
||||||
|
|
||||||
func TestDial(t *testing.T) {
|
func TestDial(t *testing.T) {
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
@ -188,7 +247,7 @@ func TestDialTLS(t *testing.T) {
|
||||||
d := cstDialer
|
d := cstDialer
|
||||||
d.NetDial = func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) }
|
d.NetDial = func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) }
|
||||||
d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
||||||
ws, _, err := d.Dial("wss://example.com/", nil)
|
ws, _, err := d.Dial("wss://example.com"+cstRequestURI, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Dial: %v", err)
|
t.Fatalf("Dial: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,14 +13,17 @@ import (
|
||||||
var parseURLTests = []struct {
|
var parseURLTests = []struct {
|
||||||
s string
|
s string
|
||||||
u *url.URL
|
u *url.URL
|
||||||
|
rui string
|
||||||
}{
|
}{
|
||||||
{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
|
{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"},
|
||||||
{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
|
{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"},
|
||||||
{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}},
|
{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"},
|
||||||
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}},
|
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"},
|
||||||
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}},
|
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"},
|
||||||
{"ss://example.com/a/b", nil},
|
{"ss://example.com/a/b", nil, ""},
|
||||||
{"ws://webmaster@example.com/", nil},
|
{"ws://webmaster@example.com/", nil, ""},
|
||||||
|
{"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"},
|
||||||
|
{"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseURL(t *testing.T) {
|
func TestParseURL(t *testing.T) {
|
||||||
|
@ -30,14 +33,19 @@ func TestParseURL(t *testing.T) {
|
||||||
t.Errorf("parseURL(%q) returned error %v", tt.s, err)
|
t.Errorf("parseURL(%q) returned error %v", tt.s, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if tt.u == nil && err == nil {
|
if tt.u == nil {
|
||||||
|
if err == nil {
|
||||||
t.Errorf("parseURL(%q) did not return error", tt.s)
|
t.Errorf("parseURL(%q) did not return error", tt.s)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(u, tt.u) {
|
if !reflect.DeepEqual(u, tt.u) {
|
||||||
t.Errorf("parseURL(%q) returned %v, want %v", tt.s, u, tt.u)
|
t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if u.RequestURI() != tt.rui {
|
||||||
|
t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
103
conn.go
103
conn.go
|
@ -102,7 +102,66 @@ type CloseError struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *CloseError) Error() string {
|
func (e *CloseError) Error() string {
|
||||||
return "websocket: close " + strconv.Itoa(e.Code) + " " + e.Text
|
s := []byte("websocket: close ")
|
||||||
|
s = strconv.AppendInt(s, int64(e.Code), 10)
|
||||||
|
switch e.Code {
|
||||||
|
case CloseNormalClosure:
|
||||||
|
s = append(s, " (normal)"...)
|
||||||
|
case CloseGoingAway:
|
||||||
|
s = append(s, " (going away)"...)
|
||||||
|
case CloseProtocolError:
|
||||||
|
s = append(s, " (protocol error)"...)
|
||||||
|
case CloseUnsupportedData:
|
||||||
|
s = append(s, " (unsupported data)"...)
|
||||||
|
case CloseNoStatusReceived:
|
||||||
|
s = append(s, " (no status)"...)
|
||||||
|
case CloseAbnormalClosure:
|
||||||
|
s = append(s, " (abnormal closure)"...)
|
||||||
|
case CloseInvalidFramePayloadData:
|
||||||
|
s = append(s, " (invalid payload data)"...)
|
||||||
|
case ClosePolicyViolation:
|
||||||
|
s = append(s, " (policy violation)"...)
|
||||||
|
case CloseMessageTooBig:
|
||||||
|
s = append(s, " (message too big)"...)
|
||||||
|
case CloseMandatoryExtension:
|
||||||
|
s = append(s, " (mandatory extension missing)"...)
|
||||||
|
case CloseInternalServerErr:
|
||||||
|
s = append(s, " (internal server error)"...)
|
||||||
|
case CloseTLSHandshake:
|
||||||
|
s = append(s, " (TLS handshake error)"...)
|
||||||
|
}
|
||||||
|
if e.Text != "" {
|
||||||
|
s = append(s, ": "...)
|
||||||
|
s = append(s, e.Text...)
|
||||||
|
}
|
||||||
|
return string(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCloseError returns boolean indicating whether the error is a *CloseError
|
||||||
|
// with one of the specified codes.
|
||||||
|
func IsCloseError(err error, codes ...int) bool {
|
||||||
|
if e, ok := err.(*CloseError); ok {
|
||||||
|
for _, code := range codes {
|
||||||
|
if e.Code == code {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUnexpectedCloseError returns boolean indicating whether the error is a
|
||||||
|
// *CloseError with a code not in the list of expected codes.
|
||||||
|
func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
|
||||||
|
if e, ok := err.(*CloseError); ok {
|
||||||
|
for _, code := range expectedCodes {
|
||||||
|
if e.Code == code {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -162,6 +221,7 @@ type Conn struct {
|
||||||
writeFrameType int // type of the current frame.
|
writeFrameType int // type of the current frame.
|
||||||
writeSeq int // incremented to invalidate message writers.
|
writeSeq int // incremented to invalidate message writers.
|
||||||
writeDeadline time.Time
|
writeDeadline time.Time
|
||||||
|
isWriting bool // for best-effort concurrent write detection
|
||||||
|
|
||||||
// Read fields
|
// Read fields
|
||||||
readMessageCompressed bool
|
readMessageCompressed bool
|
||||||
|
@ -177,6 +237,7 @@ type Conn struct {
|
||||||
readMaskKey [4]byte
|
readMaskKey [4]byte
|
||||||
handlePong func(string) error
|
handlePong func(string) error
|
||||||
handlePing func(string) error
|
handlePing func(string) error
|
||||||
|
readErrCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
||||||
|
@ -323,9 +384,6 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
||||||
//
|
//
|
||||||
// There can be at most one open writer on a connection. NextWriter closes the
|
// There can be at most one open writer on a connection. NextWriter closes the
|
||||||
// previous writer if the application has not already done so.
|
// previous writer if the application has not already done so.
|
||||||
//
|
|
||||||
// The NextWriter method and the writers returned from the method cannot be
|
|
||||||
// accessed by more than one goroutine at a time.
|
|
||||||
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||||
if c.writeErr != nil {
|
if c.writeErr != nil {
|
||||||
return nil, c.writeErr
|
return nil, c.writeErr
|
||||||
|
@ -414,9 +472,22 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write the buffers to the connection.
|
// Write the buffers to the connection with best-effort detection of
|
||||||
|
// concurrent writes. See the concurrency section in the package
|
||||||
|
// documentation for more info.
|
||||||
|
|
||||||
|
if c.isWriting {
|
||||||
|
panic("concurrent write to websocket connection")
|
||||||
|
}
|
||||||
|
c.isWriting = true
|
||||||
|
|
||||||
c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra)
|
c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra)
|
||||||
|
|
||||||
|
if !c.isWriting {
|
||||||
|
panic("concurrent write to websocket connection")
|
||||||
|
}
|
||||||
|
c.isWriting = false
|
||||||
|
|
||||||
// Setup for next frame.
|
// Setup for next frame.
|
||||||
c.writePos = maxFrameHeaderSize
|
c.writePos = maxFrameHeaderSize
|
||||||
c.writeFrameType = continuationFrame
|
c.writeFrameType = continuationFrame
|
||||||
|
@ -725,13 +796,15 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
case CloseMessage:
|
case CloseMessage:
|
||||||
c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait))
|
echoMessage := []byte{}
|
||||||
closeCode := CloseNoStatusReceived
|
closeCode := CloseNoStatusReceived
|
||||||
closeText := ""
|
closeText := ""
|
||||||
if len(payload) >= 2 {
|
if len(payload) >= 2 {
|
||||||
|
echoMessage = payload[:2]
|
||||||
closeCode = int(binary.BigEndian.Uint16(payload))
|
closeCode = int(binary.BigEndian.Uint16(payload))
|
||||||
closeText = string(payload[2:])
|
closeText = string(payload[2:])
|
||||||
}
|
}
|
||||||
|
c.WriteControl(CloseMessage, echoMessage, time.Now().Add(writeWait))
|
||||||
return noFrame, &CloseError{Code: closeCode, Text: closeText}
|
return noFrame, &CloseError{Code: closeCode, Text: closeText}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -749,8 +822,10 @@ func (c *Conn) handleProtocolError(message string) error {
|
||||||
// There can be at most one open reader on a connection. NextReader discards
|
// There can be at most one open reader on a connection. NextReader discards
|
||||||
// the previous message if the application has not already consumed it.
|
// the previous message if the application has not already consumed it.
|
||||||
//
|
//
|
||||||
// The NextReader method and the readers returned from the method cannot be
|
// Applications must break out of the application's read loop when this method
|
||||||
// accessed by more than one goroutine at a time.
|
// returns a non-nil error value. Errors returned from this method are
|
||||||
|
// permanent. Once this method returns a non-nil error, all subsequent calls to
|
||||||
|
// this method return the same error.
|
||||||
func (c *Conn) NextReader() (int, io.Reader, error) {
|
func (c *Conn) NextReader() (int, io.Reader, error) {
|
||||||
|
|
||||||
c.readSeq++
|
c.readSeq++
|
||||||
|
@ -772,6 +847,15 @@ func (c *Conn) NextReader() (int, io.Reader, error) {
|
||||||
return frameType, r, nil
|
return frameType, r, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Applications that do handle the error returned from this method spin in
|
||||||
|
// tight loop on connection failure. To help application developers detect
|
||||||
|
// this error, panic on repeated reads to the failed connection.
|
||||||
|
c.readErrCount++
|
||||||
|
if c.readErrCount >= 1000 {
|
||||||
|
panic("repeated read on failed websocket connection")
|
||||||
|
}
|
||||||
|
|
||||||
return noFrame, nil, c.readErr
|
return noFrame, nil, c.readErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -798,6 +882,9 @@ func (r messageReader) Read(b []byte) (int, error) {
|
||||||
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
|
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
|
||||||
}
|
}
|
||||||
r.c.readRemaining -= int64(n)
|
r.c.readRemaining -= int64(n)
|
||||||
|
if r.c.readRemaining > 0 && r.c.readErr == io.EOF {
|
||||||
|
r.c.readErr = errUnexpectedEOF
|
||||||
|
}
|
||||||
return n, r.c.readErr
|
return n, r.c.readErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
130
conn_test.go
130
conn_test.go
|
@ -7,6 +7,7 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -173,6 +174,41 @@ func TestCloseBeforeFinalFrame(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEOFWithinFrame(t *testing.T) {
|
||||||
|
const bufSize = 64
|
||||||
|
|
||||||
|
for n := 0; ; n++ {
|
||||||
|
var b bytes.Buffer
|
||||||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
|
||||||
|
rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
|
||||||
|
|
||||||
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
|
w.Write(make([]byte, bufSize))
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
if n >= b.Len() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
b.Truncate(n)
|
||||||
|
|
||||||
|
op, r, err := rc.NextReader()
|
||||||
|
if err == errUnexpectedEOF {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if op != BinaryMessage || err != nil {
|
||||||
|
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
||||||
|
}
|
||||||
|
_, err = io.Copy(ioutil.Discard, r)
|
||||||
|
if err != errUnexpectedEOF {
|
||||||
|
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||||
|
}
|
||||||
|
_, _, err = rc.NextReader()
|
||||||
|
if err != errUnexpectedEOF {
|
||||||
|
t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestEOFBeforeFinalFrame(t *testing.T) {
|
func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
const bufSize = 512
|
const bufSize = 512
|
||||||
|
|
||||||
|
@ -270,3 +306,97 @@ func TestBufioReadBytes(t *testing.T) {
|
||||||
t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m))
|
t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var closeErrorTests = []struct {
|
||||||
|
err error
|
||||||
|
codes []int
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
|
||||||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
|
||||||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
|
||||||
|
{errors.New("hello"), []int{CloseNormalClosure}, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseError(t *testing.T) {
|
||||||
|
for _, tt := range closeErrorTests {
|
||||||
|
ok := IsCloseError(tt.err, tt.codes...)
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var unexpectedCloseErrorTests = []struct {
|
||||||
|
err error
|
||||||
|
codes []int
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
|
||||||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
|
||||||
|
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
|
||||||
|
{errors.New("hello"), []int{CloseNormalClosure}, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnexpectedCloseErrors(t *testing.T) {
|
||||||
|
for _, tt := range unexpectedCloseErrorTests {
|
||||||
|
ok := IsUnexpectedCloseError(tt.err, tt.codes...)
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type blockingWriter struct {
|
||||||
|
c1, c2 chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w blockingWriter) Write(p []byte) (int, error) {
|
||||||
|
// Allow main to continue
|
||||||
|
close(w.c1)
|
||||||
|
// Wait for panic in main
|
||||||
|
<-w.c2
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentWritePanic(t *testing.T) {
|
||||||
|
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
||||||
|
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
||||||
|
go func() {
|
||||||
|
c.WriteMessage(TextMessage, []byte{})
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for goroutine to block in write.
|
||||||
|
<-w.c1
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
close(w.c2)
|
||||||
|
if v := recover(); v != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
c.WriteMessage(TextMessage, []byte{})
|
||||||
|
t.Fatal("should not get here")
|
||||||
|
}
|
||||||
|
|
||||||
|
type failingReader struct{}
|
||||||
|
|
||||||
|
func (r failingReader) Read(p []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFailedConnectionReadPanic(t *testing.T) {
|
||||||
|
c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if v := recover(); v != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < 20000; i++ {
|
||||||
|
c.ReadMessage()
|
||||||
|
}
|
||||||
|
t.Fatal("should not get here")
|
||||||
|
}
|
||||||
|
|
51
doc.go
51
doc.go
|
@ -46,8 +46,7 @@
|
||||||
// method to get an io.WriteCloser, write the message to the writer and close
|
// method to get an io.WriteCloser, write the message to the writer and close
|
||||||
// the writer when done. To receive a message, call the connection NextReader
|
// the writer when done. To receive a message, call the connection NextReader
|
||||||
// method to get an io.Reader and read until io.EOF is returned. This snippet
|
// method to get an io.Reader and read until io.EOF is returned. This snippet
|
||||||
// snippet shows how to echo messages using the NextWriter and NextReader
|
// shows how to echo messages using the NextWriter and NextReader methods:
|
||||||
// methods:
|
|
||||||
//
|
//
|
||||||
// for {
|
// for {
|
||||||
// messageType, r, err := conn.NextReader()
|
// messageType, r, err := conn.NextReader()
|
||||||
|
@ -86,14 +85,32 @@
|
||||||
// and pong. Call the connection WriteControl, WriteMessage or NextWriter
|
// and pong. Call the connection WriteControl, WriteMessage or NextWriter
|
||||||
// methods to send a control message to the peer.
|
// methods to send a control message to the peer.
|
||||||
//
|
//
|
||||||
// Connections handle received ping and pong messages by invoking a callback
|
// Connections handle received close messages by sending a close message to the
|
||||||
// function set with SetPingHandler and SetPongHandler methods. These callback
|
// peer and returning a *CloseError from the the NextReader, ReadMessage or the
|
||||||
// functions can be invoked from the ReadMessage method, the NextReader method
|
// message Read method.
|
||||||
// or from a call to the data message reader returned from NextReader.
|
|
||||||
//
|
//
|
||||||
// Connections handle received close messages by returning an error from the
|
// Connections handle received ping and pong messages by invoking callback
|
||||||
// ReadMessage method, the NextReader method or from a call to the data message
|
// functions set with SetPingHandler and SetPongHandler methods. The callback
|
||||||
// reader returned from NextReader.
|
// functions are called from the NextReader, ReadMessage and the message Read
|
||||||
|
// methods.
|
||||||
|
//
|
||||||
|
// The default ping handler sends a pong to the peer. The application's reading
|
||||||
|
// goroutine can block for a short time while the handler writes the pong data
|
||||||
|
// to the connection.
|
||||||
|
//
|
||||||
|
// The application must read the connection to process ping, pong and close
|
||||||
|
// messages sent from the peer. If the application is not otherwise interested
|
||||||
|
// in messages from the peer, then the application should start a goroutine to
|
||||||
|
// read and discard messages from the peer. A simple example is:
|
||||||
|
//
|
||||||
|
// func readLoop(c *websocket.Conn) {
|
||||||
|
// for {
|
||||||
|
// if _, _, err := c.NextReader(); err != nil {
|
||||||
|
// c.Close()
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
//
|
//
|
||||||
// Concurrency
|
// Concurrency
|
||||||
//
|
//
|
||||||
|
@ -108,22 +125,6 @@
|
||||||
// The Close and WriteControl methods can be called concurrently with all other
|
// The Close and WriteControl methods can be called concurrently with all other
|
||||||
// methods.
|
// methods.
|
||||||
//
|
//
|
||||||
// Read is Required
|
|
||||||
//
|
|
||||||
// The application must read the connection to process ping and close messages
|
|
||||||
// sent from the peer. If the application is not otherwise interested in
|
|
||||||
// messages from the peer, then the application should start a goroutine to read
|
|
||||||
// and discard messages from the peer. A simple example is:
|
|
||||||
//
|
|
||||||
// func readLoop(c *websocket.Conn) {
|
|
||||||
// for {
|
|
||||||
// if _, _, err := c.NextReader(); err != nil {
|
|
||||||
// c.Close()
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// Origin Considerations
|
// Origin Considerations
|
||||||
//
|
//
|
||||||
// Web browsers allow Javascript applications to open a WebSocket connection to
|
// Web browsers allow Javascript applications to open a WebSocket connection to
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
// Copyright 2015 The Gorilla WebSocket Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package websocket_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
c *websocket.Conn
|
||||||
|
req *http.Request
|
||||||
|
)
|
||||||
|
|
||||||
|
// The websocket.IsUnexpectedCloseError function is useful for identifying
|
||||||
|
// application and protocol errors.
|
||||||
|
//
|
||||||
|
// This server application works with a client application running in the
|
||||||
|
// browser. The client application does not explicitly close the websocket. The
|
||||||
|
// only expected close message from the client has the code
|
||||||
|
// websocket.CloseGoingAway. All other other close messages are likely the
|
||||||
|
// result of an application or protocol error and are logged to aid debugging.
|
||||||
|
func ExampleIsUnexpectedCloseError() {
|
||||||
|
|
||||||
|
for {
|
||||||
|
messageType, p, err := c.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
|
||||||
|
log.Printf("error: %v, user-agent: %v", err, req.Header.Get("User-Agent"))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
processMesage(messageType, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func processMesage(mt int, p []byte) {}
|
||||||
|
|
||||||
|
// TestX prevents godoc from showing this entire file in the example. Remove
|
||||||
|
// this function when a second example is added.
|
||||||
|
func TestX(t *testing.T) {}
|
|
@ -51,6 +51,9 @@ func (c *connection) readPump() {
|
||||||
for {
|
for {
|
||||||
_, message, err := c.ws.ReadMessage()
|
_, message, err := c.ws.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
|
||||||
|
log.Printf("error: %v", err)
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
h.broadcast <- message
|
h.broadcast <- message
|
||||||
|
|
|
@ -10,6 +10,8 @@ import (
|
||||||
"flag"
|
"flag"
|
||||||
"log"
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
@ -21,6 +23,9 @@ func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
log.SetFlags(0)
|
log.SetFlags(0)
|
||||||
|
|
||||||
|
interrupt := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(interrupt, os.Interrupt)
|
||||||
|
|
||||||
u := url.URL{Scheme: "ws", Host: *addr, Path: "/echo"}
|
u := url.URL{Scheme: "ws", Host: *addr, Path: "/echo"}
|
||||||
log.Printf("connecting to %s", u.String())
|
log.Printf("connecting to %s", u.String())
|
||||||
|
|
||||||
|
@ -30,13 +35,16 @@ func main() {
|
||||||
}
|
}
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
defer close(done)
|
||||||
for {
|
for {
|
||||||
_, message, err := c.ReadMessage()
|
_, message, err := c.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("read:", err)
|
log.Println("read:", err)
|
||||||
break
|
return
|
||||||
}
|
}
|
||||||
log.Printf("recv: %s", message)
|
log.Printf("recv: %s", message)
|
||||||
}
|
}
|
||||||
|
@ -45,11 +53,29 @@ func main() {
|
||||||
ticker := time.NewTicker(time.Second)
|
ticker := time.NewTicker(time.Second)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for t := range ticker.C {
|
for {
|
||||||
|
select {
|
||||||
|
case t := <-ticker.C:
|
||||||
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
|
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("write:", err)
|
log.Println("write:", err)
|
||||||
break
|
return
|
||||||
|
}
|
||||||
|
case <-interrupt:
|
||||||
|
log.Println("interrupt")
|
||||||
|
// To cleanly close a connection, a client should send a close
|
||||||
|
// frame and wait for the server to close the connection.
|
||||||
|
err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("write close:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -120,7 +120,7 @@ func serveWs(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastMod time.Time
|
var lastMod time.Time
|
||||||
if n, err := strconv.ParseInt(r.FormValue("lastMod"), 16, 64); err != nil {
|
if n, err := strconv.ParseInt(r.FormValue("lastMod"), 16, 64); err == nil {
|
||||||
lastMod = time.Unix(0, n)
|
lastMod = time.Unix(0, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
10
server.go
10
server.go
|
@ -124,6 +124,9 @@ func (u *Upgrader) selectCompressionExtension(r *http.Request) (string, bool, er
|
||||||
// The responseHeader is included in the response to the client's upgrade
|
// The responseHeader is included in the response to the client's upgrade
|
||||||
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
|
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
|
||||||
// application negotiated subprotocol (Sec-Websocket-Protocol).
|
// application negotiated subprotocol (Sec-Websocket-Protocol).
|
||||||
|
//
|
||||||
|
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
|
||||||
|
// response.
|
||||||
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
|
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
|
||||||
if r.Method != "GET" {
|
if r.Method != "GET" {
|
||||||
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
|
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
|
||||||
|
@ -291,3 +294,10 @@ func Subprotocols(r *http.Request) []string {
|
||||||
}
|
}
|
||||||
return protocols
|
return protocols
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsWebSocketUpgrade returns true if the client requested upgrade to the
|
||||||
|
// WebSocket protocol.
|
||||||
|
func IsWebSocketUpgrade(r *http.Request) bool {
|
||||||
|
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
|
||||||
|
tokenListContainsValue(r.Header, "Upgrade", "websocket")
|
||||||
|
}
|
||||||
|
|
|
@ -31,3 +31,21 @@ func TestSubprotocols(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var isWebSocketUpgradeTests = []struct {
|
||||||
|
ok bool
|
||||||
|
h http.Header
|
||||||
|
}{
|
||||||
|
{false, http.Header{"Upgrade": {"websocket"}}},
|
||||||
|
{false, http.Header{"Connection": {"upgrade"}}},
|
||||||
|
{true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsWebSocketUpgrade(t *testing.T) {
|
||||||
|
for _, tt := range isWebSocketUpgradeTests {
|
||||||
|
ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
|
||||||
|
if tt.ok != ok {
|
||||||
|
t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue