mirror of https://github.com/gorilla/websocket.git
Merge remote-tracking branch 'gorilla/master'
This commit is contained in:
commit
7de010c67d
19
.travis.yml
19
.travis.yml
|
@ -1,6 +1,17 @@
|
|||
language: go
|
||||
sudo: false
|
||||
|
||||
go:
|
||||
- 1.1
|
||||
- 1.2
|
||||
- tip
|
||||
matrix:
|
||||
include:
|
||||
- go: 1.4
|
||||
- go: 1.5
|
||||
- go: 1.6
|
||||
- go: tip
|
||||
allow_failures:
|
||||
- go: tip
|
||||
|
||||
script:
|
||||
- go get -t -v ./...
|
||||
- diff -u <(echo -n) <(gofmt -d .)
|
||||
- go vet $(go list ./... | grep -v /vendor/)
|
||||
- go test -v -race ./...
|
||||
|
|
65
client.go
65
client.go
|
@ -8,6 +8,7 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
@ -98,13 +99,20 @@ func parseURL(s string) (*url.URL, error) {
|
|||
return nil, errMalformedURL
|
||||
}
|
||||
|
||||
u.Host = s
|
||||
u.Opaque = "/"
|
||||
if i := strings.Index(s, "/"); i >= 0 {
|
||||
u.Host = s[:i]
|
||||
u.Opaque = s[i:]
|
||||
if i := strings.Index(s, "?"); i >= 0 {
|
||||
u.RawQuery = s[i+1:]
|
||||
s = s[:i]
|
||||
}
|
||||
|
||||
if i := strings.Index(s, "/"); i >= 0 {
|
||||
u.Opaque = s[i:]
|
||||
s = s[:i]
|
||||
} else {
|
||||
u.Opaque = "/"
|
||||
}
|
||||
|
||||
u.Host = s
|
||||
|
||||
if strings.Contains(u.Host, "@") {
|
||||
// Don't bother parsing user information because user information is
|
||||
// not allowed in websocket URIs.
|
||||
|
@ -266,11 +274,19 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
}
|
||||
|
||||
if proxyURL != nil {
|
||||
connectHeader := make(http.Header)
|
||||
if user := proxyURL.User; user != nil {
|
||||
proxyUser := user.Username()
|
||||
if proxyPassword, passwordSet := user.Password(); passwordSet {
|
||||
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
|
||||
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
|
||||
}
|
||||
}
|
||||
connectReq := &http.Request{
|
||||
Method: "CONNECT",
|
||||
URL: &url.URL{Opaque: hostPort},
|
||||
Host: hostPort,
|
||||
Header: make(http.Header),
|
||||
Header: connectHeader,
|
||||
}
|
||||
|
||||
connectReq.Write(netConn)
|
||||
|
@ -290,12 +306,8 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
}
|
||||
|
||||
if u.Scheme == "https" {
|
||||
cfg := d.TLSClientConfig
|
||||
if cfg == nil {
|
||||
cfg = &tls.Config{ServerName: hostNoPort}
|
||||
} else if cfg.ServerName == "" {
|
||||
shallowCopy := *cfg
|
||||
cfg = &shallowCopy
|
||||
cfg := cloneTLSConfig(d.TLSClientConfig)
|
||||
if cfg.ServerName == "" {
|
||||
cfg.ServerName = hostNoPort
|
||||
}
|
||||
tlsConn := tls.Client(netConn, cfg)
|
||||
|
@ -344,3 +356,32 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
netConn = nil // to avoid close in defer.
|
||||
return conn, resp, nil
|
||||
}
|
||||
|
||||
// cloneTLSConfig clones all public fields except the fields
|
||||
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
|
||||
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
|
||||
// config in active use.
|
||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||
if cfg == nil {
|
||||
return &tls.Config{}
|
||||
}
|
||||
return &tls.Config{
|
||||
Rand: cfg.Rand,
|
||||
Time: cfg.Time,
|
||||
Certificates: cfg.Certificates,
|
||||
NameToCertificate: cfg.NameToCertificate,
|
||||
GetCertificate: cfg.GetCertificate,
|
||||
RootCAs: cfg.RootCAs,
|
||||
NextProtos: cfg.NextProtos,
|
||||
ServerName: cfg.ServerName,
|
||||
ClientAuth: cfg.ClientAuth,
|
||||
ClientCAs: cfg.ClientCAs,
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
CipherSuites: cfg.CipherSuites,
|
||||
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
||||
ClientSessionCache: cfg.ClientSessionCache,
|
||||
MinVersion: cfg.MinVersion,
|
||||
MaxVersion: cfg.MaxVersion,
|
||||
CurvePreferences: cfg.CurvePreferences,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ package websocket
|
|||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
|
@ -41,9 +42,16 @@ type cstServer struct {
|
|||
URL string
|
||||
}
|
||||
|
||||
const (
|
||||
cstPath = "/a/b"
|
||||
cstRawQuery = "x=y"
|
||||
cstRequestURI = cstPath + "?" + cstRawQuery
|
||||
)
|
||||
|
||||
func newServer(t *testing.T) *cstServer {
|
||||
var s cstServer
|
||||
s.Server = httptest.NewServer(cstHandler{t})
|
||||
s.Server.URL += cstRequestURI
|
||||
s.URL = makeWsProto(s.Server.URL)
|
||||
return &s
|
||||
}
|
||||
|
@ -51,11 +59,22 @@ func newServer(t *testing.T) *cstServer {
|
|||
func newTLSServer(t *testing.T) *cstServer {
|
||||
var s cstServer
|
||||
s.Server = httptest.NewTLSServer(cstHandler{t})
|
||||
s.Server.URL += cstRequestURI
|
||||
s.URL = makeWsProto(s.Server.URL)
|
||||
return &s
|
||||
}
|
||||
|
||||
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != cstPath {
|
||||
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
|
||||
http.Error(w, "bad path", 400)
|
||||
return
|
||||
}
|
||||
if r.URL.RawQuery != cstRawQuery {
|
||||
t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
|
||||
http.Error(w, "bad path", 400)
|
||||
return
|
||||
}
|
||||
subprotos := Subprotocols(r)
|
||||
if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
|
||||
t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
|
||||
|
@ -157,6 +176,46 @@ func TestProxyDial(t *testing.T) {
|
|||
cstDialer.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
func TestProxyAuthorizationDial(t *testing.T) {
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
surl, _ := url.Parse(s.URL)
|
||||
surl.User = url.UserPassword("username", "password")
|
||||
cstDialer.Proxy = http.ProxyURL(surl)
|
||||
|
||||
connect := false
|
||||
origHandler := s.Server.Config.Handler
|
||||
|
||||
// Capture the request Host header.
|
||||
s.Server.Config.Handler = http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
proxyAuth := r.Header.Get("Proxy-Authorization")
|
||||
expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
|
||||
if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
|
||||
connect = true
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
|
||||
if !connect {
|
||||
t.Log("connect with proxy authorization not recieved")
|
||||
http.Error(w, "connect with proxy authorization not recieved", 405)
|
||||
return
|
||||
}
|
||||
origHandler.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
ws, _, err := cstDialer.Dial(s.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
sendRecv(t, ws)
|
||||
|
||||
cstDialer.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
func TestDial(t *testing.T) {
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
@ -188,7 +247,7 @@ func TestDialTLS(t *testing.T) {
|
|||
d := cstDialer
|
||||
d.NetDial = func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) }
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
||||
ws, _, err := d.Dial("wss://example.com/", nil)
|
||||
ws, _, err := d.Dial("wss://example.com"+cstRequestURI, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
|
|
|
@ -11,16 +11,19 @@ import (
|
|||
)
|
||||
|
||||
var parseURLTests = []struct {
|
||||
s string
|
||||
u *url.URL
|
||||
s string
|
||||
u *url.URL
|
||||
rui string
|
||||
}{
|
||||
{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
|
||||
{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
|
||||
{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}},
|
||||
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}},
|
||||
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}},
|
||||
{"ss://example.com/a/b", nil},
|
||||
{"ws://webmaster@example.com/", nil},
|
||||
{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"},
|
||||
{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"},
|
||||
{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"},
|
||||
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"},
|
||||
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"},
|
||||
{"ss://example.com/a/b", nil, ""},
|
||||
{"ws://webmaster@example.com/", nil, ""},
|
||||
{"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"},
|
||||
{"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"},
|
||||
}
|
||||
|
||||
func TestParseURL(t *testing.T) {
|
||||
|
@ -30,14 +33,19 @@ func TestParseURL(t *testing.T) {
|
|||
t.Errorf("parseURL(%q) returned error %v", tt.s, err)
|
||||
continue
|
||||
}
|
||||
if tt.u == nil && err == nil {
|
||||
t.Errorf("parseURL(%q) did not return error", tt.s)
|
||||
if tt.u == nil {
|
||||
if err == nil {
|
||||
t.Errorf("parseURL(%q) did not return error", tt.s)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(u, tt.u) {
|
||||
t.Errorf("parseURL(%q) returned %v, want %v", tt.s, u, tt.u)
|
||||
t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u)
|
||||
continue
|
||||
}
|
||||
if u.RequestURI() != tt.rui {
|
||||
t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
103
conn.go
103
conn.go
|
@ -102,7 +102,66 @@ type CloseError struct {
|
|||
}
|
||||
|
||||
func (e *CloseError) Error() string {
|
||||
return "websocket: close " + strconv.Itoa(e.Code) + " " + e.Text
|
||||
s := []byte("websocket: close ")
|
||||
s = strconv.AppendInt(s, int64(e.Code), 10)
|
||||
switch e.Code {
|
||||
case CloseNormalClosure:
|
||||
s = append(s, " (normal)"...)
|
||||
case CloseGoingAway:
|
||||
s = append(s, " (going away)"...)
|
||||
case CloseProtocolError:
|
||||
s = append(s, " (protocol error)"...)
|
||||
case CloseUnsupportedData:
|
||||
s = append(s, " (unsupported data)"...)
|
||||
case CloseNoStatusReceived:
|
||||
s = append(s, " (no status)"...)
|
||||
case CloseAbnormalClosure:
|
||||
s = append(s, " (abnormal closure)"...)
|
||||
case CloseInvalidFramePayloadData:
|
||||
s = append(s, " (invalid payload data)"...)
|
||||
case ClosePolicyViolation:
|
||||
s = append(s, " (policy violation)"...)
|
||||
case CloseMessageTooBig:
|
||||
s = append(s, " (message too big)"...)
|
||||
case CloseMandatoryExtension:
|
||||
s = append(s, " (mandatory extension missing)"...)
|
||||
case CloseInternalServerErr:
|
||||
s = append(s, " (internal server error)"...)
|
||||
case CloseTLSHandshake:
|
||||
s = append(s, " (TLS handshake error)"...)
|
||||
}
|
||||
if e.Text != "" {
|
||||
s = append(s, ": "...)
|
||||
s = append(s, e.Text...)
|
||||
}
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// IsCloseError returns boolean indicating whether the error is a *CloseError
|
||||
// with one of the specified codes.
|
||||
func IsCloseError(err error, codes ...int) bool {
|
||||
if e, ok := err.(*CloseError); ok {
|
||||
for _, code := range codes {
|
||||
if e.Code == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsUnexpectedCloseError returns boolean indicating whether the error is a
|
||||
// *CloseError with a code not in the list of expected codes.
|
||||
func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
|
||||
if e, ok := err.(*CloseError); ok {
|
||||
for _, code := range expectedCodes {
|
||||
if e.Code == code {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -162,6 +221,7 @@ type Conn struct {
|
|||
writeFrameType int // type of the current frame.
|
||||
writeSeq int // incremented to invalidate message writers.
|
||||
writeDeadline time.Time
|
||||
isWriting bool // for best-effort concurrent write detection
|
||||
|
||||
// Read fields
|
||||
readMessageCompressed bool
|
||||
|
@ -177,6 +237,7 @@ type Conn struct {
|
|||
readMaskKey [4]byte
|
||||
handlePong func(string) error
|
||||
handlePing func(string) error
|
||||
readErrCount int
|
||||
}
|
||||
|
||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
||||
|
@ -323,9 +384,6 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
|||
//
|
||||
// There can be at most one open writer on a connection. NextWriter closes the
|
||||
// previous writer if the application has not already done so.
|
||||
//
|
||||
// The NextWriter method and the writers returned from the method cannot be
|
||||
// accessed by more than one goroutine at a time.
|
||||
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||
if c.writeErr != nil {
|
||||
return nil, c.writeErr
|
||||
|
@ -414,9 +472,22 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Write the buffers to the connection.
|
||||
// Write the buffers to the connection with best-effort detection of
|
||||
// concurrent writes. See the concurrency section in the package
|
||||
// documentation for more info.
|
||||
|
||||
if c.isWriting {
|
||||
panic("concurrent write to websocket connection")
|
||||
}
|
||||
c.isWriting = true
|
||||
|
||||
c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra)
|
||||
|
||||
if !c.isWriting {
|
||||
panic("concurrent write to websocket connection")
|
||||
}
|
||||
c.isWriting = false
|
||||
|
||||
// Setup for next frame.
|
||||
c.writePos = maxFrameHeaderSize
|
||||
c.writeFrameType = continuationFrame
|
||||
|
@ -725,13 +796,15 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
return noFrame, err
|
||||
}
|
||||
case CloseMessage:
|
||||
c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait))
|
||||
echoMessage := []byte{}
|
||||
closeCode := CloseNoStatusReceived
|
||||
closeText := ""
|
||||
if len(payload) >= 2 {
|
||||
echoMessage = payload[:2]
|
||||
closeCode = int(binary.BigEndian.Uint16(payload))
|
||||
closeText = string(payload[2:])
|
||||
}
|
||||
c.WriteControl(CloseMessage, echoMessage, time.Now().Add(writeWait))
|
||||
return noFrame, &CloseError{Code: closeCode, Text: closeText}
|
||||
}
|
||||
|
||||
|
@ -749,8 +822,10 @@ func (c *Conn) handleProtocolError(message string) error {
|
|||
// There can be at most one open reader on a connection. NextReader discards
|
||||
// the previous message if the application has not already consumed it.
|
||||
//
|
||||
// The NextReader method and the readers returned from the method cannot be
|
||||
// accessed by more than one goroutine at a time.
|
||||
// Applications must break out of the application's read loop when this method
|
||||
// returns a non-nil error value. Errors returned from this method are
|
||||
// permanent. Once this method returns a non-nil error, all subsequent calls to
|
||||
// this method return the same error.
|
||||
func (c *Conn) NextReader() (int, io.Reader, error) {
|
||||
|
||||
c.readSeq++
|
||||
|
@ -772,6 +847,15 @@ func (c *Conn) NextReader() (int, io.Reader, error) {
|
|||
return frameType, r, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Applications that do handle the error returned from this method spin in
|
||||
// tight loop on connection failure. To help application developers detect
|
||||
// this error, panic on repeated reads to the failed connection.
|
||||
c.readErrCount++
|
||||
if c.readErrCount >= 1000 {
|
||||
panic("repeated read on failed websocket connection")
|
||||
}
|
||||
|
||||
return noFrame, nil, c.readErr
|
||||
}
|
||||
|
||||
|
@ -798,6 +882,9 @@ func (r messageReader) Read(b []byte) (int, error) {
|
|||
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
|
||||
}
|
||||
r.c.readRemaining -= int64(n)
|
||||
if r.c.readRemaining > 0 && r.c.readErr == io.EOF {
|
||||
r.c.readErr = errUnexpectedEOF
|
||||
}
|
||||
return n, r.c.readErr
|
||||
}
|
||||
|
||||
|
|
130
conn_test.go
130
conn_test.go
|
@ -7,6 +7,7 @@ package websocket
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
@ -173,6 +174,41 @@ func TestCloseBeforeFinalFrame(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestEOFWithinFrame(t *testing.T) {
|
||||
const bufSize = 64
|
||||
|
||||
for n := 0; ; n++ {
|
||||
var b bytes.Buffer
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
|
||||
rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
|
||||
|
||||
w, _ := wc.NextWriter(BinaryMessage)
|
||||
w.Write(make([]byte, bufSize))
|
||||
w.Close()
|
||||
|
||||
if n >= b.Len() {
|
||||
break
|
||||
}
|
||||
b.Truncate(n)
|
||||
|
||||
op, r, err := rc.NextReader()
|
||||
if err == errUnexpectedEOF {
|
||||
continue
|
||||
}
|
||||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
||||
}
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
if err != errUnexpectedEOF {
|
||||
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||
}
|
||||
_, _, err = rc.NextReader()
|
||||
if err != errUnexpectedEOF {
|
||||
t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||
const bufSize = 512
|
||||
|
||||
|
@ -270,3 +306,97 @@ func TestBufioReadBytes(t *testing.T) {
|
|||
t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m))
|
||||
}
|
||||
}
|
||||
|
||||
var closeErrorTests = []struct {
|
||||
err error
|
||||
codes []int
|
||||
ok bool
|
||||
}{
|
||||
{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
|
||||
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
|
||||
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
|
||||
{errors.New("hello"), []int{CloseNormalClosure}, false},
|
||||
}
|
||||
|
||||
func TestCloseError(t *testing.T) {
|
||||
for _, tt := range closeErrorTests {
|
||||
ok := IsCloseError(tt.err, tt.codes...)
|
||||
if ok != tt.ok {
|
||||
t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var unexpectedCloseErrorTests = []struct {
|
||||
err error
|
||||
codes []int
|
||||
ok bool
|
||||
}{
|
||||
{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
|
||||
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
|
||||
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
|
||||
{errors.New("hello"), []int{CloseNormalClosure}, false},
|
||||
}
|
||||
|
||||
func TestUnexpectedCloseErrors(t *testing.T) {
|
||||
for _, tt := range unexpectedCloseErrorTests {
|
||||
ok := IsUnexpectedCloseError(tt.err, tt.codes...)
|
||||
if ok != tt.ok {
|
||||
t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type blockingWriter struct {
|
||||
c1, c2 chan struct{}
|
||||
}
|
||||
|
||||
func (w blockingWriter) Write(p []byte) (int, error) {
|
||||
// Allow main to continue
|
||||
close(w.c1)
|
||||
// Wait for panic in main
|
||||
<-w.c2
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func TestConcurrentWritePanic(t *testing.T) {
|
||||
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
||||
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
||||
go func() {
|
||||
c.WriteMessage(TextMessage, []byte{})
|
||||
}()
|
||||
|
||||
// wait for goroutine to block in write.
|
||||
<-w.c1
|
||||
|
||||
defer func() {
|
||||
close(w.c2)
|
||||
if v := recover(); v != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
c.WriteMessage(TextMessage, []byte{})
|
||||
t.Fatal("should not get here")
|
||||
}
|
||||
|
||||
type failingReader struct{}
|
||||
|
||||
func (r failingReader) Read(p []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func TestFailedConnectionReadPanic(t *testing.T) {
|
||||
c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
|
||||
|
||||
defer func() {
|
||||
if v := recover(); v != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 20000; i++ {
|
||||
c.ReadMessage()
|
||||
}
|
||||
t.Fatal("should not get here")
|
||||
}
|
||||
|
|
55
doc.go
55
doc.go
|
@ -46,8 +46,7 @@
|
|||
// method to get an io.WriteCloser, write the message to the writer and close
|
||||
// the writer when done. To receive a message, call the connection NextReader
|
||||
// method to get an io.Reader and read until io.EOF is returned. This snippet
|
||||
// snippet shows how to echo messages using the NextWriter and NextReader
|
||||
// methods:
|
||||
// shows how to echo messages using the NextWriter and NextReader methods:
|
||||
//
|
||||
// for {
|
||||
// messageType, r, err := conn.NextReader()
|
||||
|
@ -86,14 +85,32 @@
|
|||
// and pong. Call the connection WriteControl, WriteMessage or NextWriter
|
||||
// methods to send a control message to the peer.
|
||||
//
|
||||
// Connections handle received ping and pong messages by invoking a callback
|
||||
// function set with SetPingHandler and SetPongHandler methods. These callback
|
||||
// functions can be invoked from the ReadMessage method, the NextReader method
|
||||
// or from a call to the data message reader returned from NextReader.
|
||||
// Connections handle received close messages by sending a close message to the
|
||||
// peer and returning a *CloseError from the the NextReader, ReadMessage or the
|
||||
// message Read method.
|
||||
//
|
||||
// Connections handle received close messages by returning an error from the
|
||||
// ReadMessage method, the NextReader method or from a call to the data message
|
||||
// reader returned from NextReader.
|
||||
// Connections handle received ping and pong messages by invoking callback
|
||||
// functions set with SetPingHandler and SetPongHandler methods. The callback
|
||||
// functions are called from the NextReader, ReadMessage and the message Read
|
||||
// methods.
|
||||
//
|
||||
// The default ping handler sends a pong to the peer. The application's reading
|
||||
// goroutine can block for a short time while the handler writes the pong data
|
||||
// to the connection.
|
||||
//
|
||||
// The application must read the connection to process ping, pong and close
|
||||
// messages sent from the peer. If the application is not otherwise interested
|
||||
// in messages from the peer, then the application should start a goroutine to
|
||||
// read and discard messages from the peer. A simple example is:
|
||||
//
|
||||
// func readLoop(c *websocket.Conn) {
|
||||
// for {
|
||||
// if _, _, err := c.NextReader(); err != nil {
|
||||
// c.Close()
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Concurrency
|
||||
//
|
||||
|
@ -108,22 +125,6 @@
|
|||
// The Close and WriteControl methods can be called concurrently with all other
|
||||
// methods.
|
||||
//
|
||||
// Read is Required
|
||||
//
|
||||
// The application must read the connection to process ping and close messages
|
||||
// sent from the peer. If the application is not otherwise interested in
|
||||
// messages from the peer, then the application should start a goroutine to read
|
||||
// and discard messages from the peer. A simple example is:
|
||||
//
|
||||
// func readLoop(c *websocket.Conn) {
|
||||
// for {
|
||||
// if _, _, err := c.NextReader(); err != nil {
|
||||
// c.Close()
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Origin Considerations
|
||||
//
|
||||
// Web browsers allow Javascript applications to open a WebSocket connection to
|
||||
|
@ -141,9 +142,9 @@
|
|||
// An application can allow connections from any origin by specifying a
|
||||
// function that always returns true:
|
||||
//
|
||||
// var upgrader = websocket.Upgrader{
|
||||
// var upgrader = websocket.Upgrader{
|
||||
// CheckOrigin: func(r *http.Request) bool { return true },
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// The deprecated Upgrade function does not enforce an origin policy. It's the
|
||||
// application's responsibility to check the Origin header before calling
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright 2015 The Gorilla WebSocket Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket_test
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
c *websocket.Conn
|
||||
req *http.Request
|
||||
)
|
||||
|
||||
// The websocket.IsUnexpectedCloseError function is useful for identifying
|
||||
// application and protocol errors.
|
||||
//
|
||||
// This server application works with a client application running in the
|
||||
// browser. The client application does not explicitly close the websocket. The
|
||||
// only expected close message from the client has the code
|
||||
// websocket.CloseGoingAway. All other other close messages are likely the
|
||||
// result of an application or protocol error and are logged to aid debugging.
|
||||
func ExampleIsUnexpectedCloseError() {
|
||||
|
||||
for {
|
||||
messageType, p, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
|
||||
log.Printf("error: %v, user-agent: %v", err, req.Header.Get("User-Agent"))
|
||||
}
|
||||
return
|
||||
}
|
||||
processMesage(messageType, p)
|
||||
}
|
||||
}
|
||||
|
||||
func processMesage(mt int, p []byte) {}
|
||||
|
||||
// TestX prevents godoc from showing this entire file in the example. Remove
|
||||
// this function when a second example is added.
|
||||
func TestX(t *testing.T) {}
|
|
@ -51,6 +51,9 @@ func (c *connection) readPump() {
|
|||
for {
|
||||
_, message, err := c.ws.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
|
||||
log.Printf("error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
h.broadcast <- message
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
"flag"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
@ -21,6 +23,9 @@ func main() {
|
|||
flag.Parse()
|
||||
log.SetFlags(0)
|
||||
|
||||
interrupt := make(chan os.Signal, 1)
|
||||
signal.Notify(interrupt, os.Interrupt)
|
||||
|
||||
u := url.URL{Scheme: "ws", Host: *addr, Path: "/echo"}
|
||||
log.Printf("connecting to %s", u.String())
|
||||
|
||||
|
@ -30,13 +35,16 @@ func main() {
|
|||
}
|
||||
defer c.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer c.Close()
|
||||
defer close(done)
|
||||
for {
|
||||
_, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("read:", err)
|
||||
break
|
||||
return
|
||||
}
|
||||
log.Printf("recv: %s", message)
|
||||
}
|
||||
|
@ -45,11 +53,29 @@ func main() {
|
|||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for t := range ticker.C {
|
||||
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
|
||||
if err != nil {
|
||||
log.Println("write:", err)
|
||||
break
|
||||
for {
|
||||
select {
|
||||
case t := <-ticker.C:
|
||||
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
|
||||
if err != nil {
|
||||
log.Println("write:", err)
|
||||
return
|
||||
}
|
||||
case <-interrupt:
|
||||
log.Println("interrupt")
|
||||
// To cleanly close a connection, a client should send a close
|
||||
// frame and wait for the server to close the connection.
|
||||
err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
if err != nil {
|
||||
log.Println("write close:", err)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ func serveWs(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var lastMod time.Time
|
||||
if n, err := strconv.ParseInt(r.FormValue("lastMod"), 16, 64); err != nil {
|
||||
if n, err := strconv.ParseInt(r.FormValue("lastMod"), 16, 64); err == nil {
|
||||
lastMod = time.Unix(0, n)
|
||||
}
|
||||
|
||||
|
|
10
server.go
10
server.go
|
@ -124,6 +124,9 @@ func (u *Upgrader) selectCompressionExtension(r *http.Request) (string, bool, er
|
|||
// The responseHeader is included in the response to the client's upgrade
|
||||
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
|
||||
// application negotiated subprotocol (Sec-Websocket-Protocol).
|
||||
//
|
||||
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
|
||||
// response.
|
||||
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
|
||||
if r.Method != "GET" {
|
||||
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
|
||||
|
@ -291,3 +294,10 @@ func Subprotocols(r *http.Request) []string {
|
|||
}
|
||||
return protocols
|
||||
}
|
||||
|
||||
// IsWebSocketUpgrade returns true if the client requested upgrade to the
|
||||
// WebSocket protocol.
|
||||
func IsWebSocketUpgrade(r *http.Request) bool {
|
||||
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
|
||||
tokenListContainsValue(r.Header, "Upgrade", "websocket")
|
||||
}
|
||||
|
|
|
@ -31,3 +31,21 @@ func TestSubprotocols(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
var isWebSocketUpgradeTests = []struct {
|
||||
ok bool
|
||||
h http.Header
|
||||
}{
|
||||
{false, http.Header{"Upgrade": {"websocket"}}},
|
||||
{false, http.Header{"Connection": {"upgrade"}}},
|
||||
{true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}},
|
||||
}
|
||||
|
||||
func TestIsWebSocketUpgrade(t *testing.T) {
|
||||
for _, tt := range isWebSocketUpgradeTests {
|
||||
ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
|
||||
if tt.ok != ok {
|
||||
t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue