mirror of https://github.com/gorilla/websocket.git
commit
f5e80e4017
36
client.go
36
client.go
|
@ -5,8 +5,11 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -127,6 +130,11 @@ func parseURL(s string) (*url.URL, error) {
|
||||||
u.Opaque = s[i:]
|
u.Opaque = s[i:]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if strings.Contains(u.Host, "@") {
|
||||||
|
// WebSocket URIs do not contain user information.
|
||||||
|
return nil, errMalformedURL
|
||||||
|
}
|
||||||
|
|
||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,7 +163,8 @@ var DefaultDialer *Dialer
|
||||||
//
|
//
|
||||||
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
|
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
|
||||||
// non-nil *http.Response so that callers can handle redirects, authentication,
|
// non-nil *http.Response so that callers can handle redirects, authentication,
|
||||||
// etc.
|
// etcetera. The response body may not contain the entire response and does not
|
||||||
|
// need to be closed by the application.
|
||||||
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
||||||
u, err := parseURL(urlStr)
|
u, err := parseURL(urlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -224,8 +233,33 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
requestHeader = h
|
requestHeader = h
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(requestHeader["Host"]) > 0 {
|
||||||
|
// This can be used to supply a Host: header which is different from
|
||||||
|
// the dial address.
|
||||||
|
u.Host = requestHeader.Get("Host")
|
||||||
|
|
||||||
|
// Drop "Host" header
|
||||||
|
h := http.Header{}
|
||||||
|
for k, v := range requestHeader {
|
||||||
|
if k == "Host" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h[k] = v
|
||||||
|
}
|
||||||
|
requestHeader = h
|
||||||
|
}
|
||||||
|
|
||||||
conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
|
conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == ErrBadHandshake {
|
||||||
|
// Before closing the network connection on return from this
|
||||||
|
// function, slurp up some of the response to aid application
|
||||||
|
// debugging.
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := io.ReadFull(resp.Body, buf)
|
||||||
|
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
||||||
|
}
|
||||||
return nil, resp, err
|
return nil, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,11 +8,13 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -34,22 +36,22 @@ var cstDialer = Dialer{
|
||||||
|
|
||||||
type cstHandler struct{ *testing.T }
|
type cstHandler struct{ *testing.T }
|
||||||
|
|
||||||
type Server struct {
|
type cstServer struct {
|
||||||
*httptest.Server
|
*httptest.Server
|
||||||
URL string
|
URL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServer(t *testing.T) *Server {
|
func newServer(t *testing.T) *cstServer {
|
||||||
var s Server
|
var s cstServer
|
||||||
s.Server = httptest.NewServer(cstHandler{t})
|
s.Server = httptest.NewServer(cstHandler{t})
|
||||||
s.URL = "ws" + s.Server.URL[len("http"):]
|
s.URL = makeWsProto(s.Server.URL)
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTLSServer(t *testing.T) *Server {
|
func newTLSServer(t *testing.T) *cstServer {
|
||||||
var s Server
|
var s cstServer
|
||||||
s.Server = httptest.NewTLSServer(cstHandler{t})
|
s.Server = httptest.NewTLSServer(cstHandler{t})
|
||||||
s.URL = "ws" + s.Server.URL[len("http"):]
|
s.URL = makeWsProto(s.Server.URL)
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,6 +99,10 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeWsProto(s string) string {
|
||||||
|
return "ws" + strings.TrimPrefix(s, "http")
|
||||||
|
}
|
||||||
|
|
||||||
func sendRecv(t *testing.T, ws *Conn) {
|
func sendRecv(t *testing.T, ws *Conn) {
|
||||||
const message = "Hello World!"
|
const message = "Hello World!"
|
||||||
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
|
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
|
||||||
|
@ -157,6 +163,7 @@ func TestDialTLS(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func xTestDialTLSBadCert(t *testing.T) {
|
func xTestDialTLSBadCert(t *testing.T) {
|
||||||
|
// This test is deactivated because of noisy logging from the net/http package.
|
||||||
s := newTLSServer(t)
|
s := newTLSServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -247,3 +254,70 @@ func TestHandshake(t *testing.T) {
|
||||||
}
|
}
|
||||||
sendRecv(t, ws)
|
sendRecv(t, ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRespOnBadHandshake(t *testing.T) {
|
||||||
|
const expectedStatus = http.StatusGone
|
||||||
|
const expectedBody = "This is the response body."
|
||||||
|
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(expectedStatus)
|
||||||
|
io.WriteString(w, expectedBody)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
|
||||||
|
if err == nil {
|
||||||
|
ws.Close()
|
||||||
|
t.Fatalf("Dial: nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp == nil {
|
||||||
|
t.Fatalf("resp=nil, err=%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != expectedStatus {
|
||||||
|
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(p) != expectedBody {
|
||||||
|
t.Errorf("resp.Body=%s, want %s", p, expectedBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the Host header is specified in `Dial()`, the server must receive it as
|
||||||
|
// the `Host:` header.
|
||||||
|
func TestHostHeader(t *testing.T) {
|
||||||
|
s := newServer(t)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
specifiedHost := make(chan string, 1)
|
||||||
|
origHandler := s.Server.Config.Handler
|
||||||
|
|
||||||
|
// Capture the request Host header.
|
||||||
|
s.Server.Config.Handler = http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
specifiedHost <- r.Host
|
||||||
|
origHandler.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial: %v", err)
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||||
|
t.Fatalf("resp.StatusCode = %v, want http.StatusSwitchingProtocols", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotHost := <-specifiedHost; gotHost != "testhost" {
|
||||||
|
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
sendRecv(t, ws)
|
||||||
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ var parseURLTests = []struct {
|
||||||
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}},
|
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}},
|
||||||
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}},
|
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}},
|
||||||
{"ss://example.com/a/b", nil},
|
{"ss://example.com/a/b", nil},
|
||||||
|
{"ws://webmaster@example.com/", nil},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseURL(t *testing.T) {
|
func TestParseURL(t *testing.T) {
|
||||||
|
|
2
conn.go
2
conn.go
|
@ -801,7 +801,7 @@ func (c *Conn) SetPingHandler(h func(string) error) {
|
||||||
c.handlePing = h
|
c.handlePing = h
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPongHandler sets then handler for pong messages received from the peer.
|
// SetPongHandler sets the handler for pong messages received from the peer.
|
||||||
// The default pong handler does nothing.
|
// The default pong handler does nothing.
|
||||||
func (c *Conn) SetPongHandler(h func(string) error) {
|
func (c *Conn) SetPongHandler(h func(string) error) {
|
||||||
if h == nil {
|
if h == nil {
|
||||||
|
|
2
doc.go
2
doc.go
|
@ -24,7 +24,7 @@
|
||||||
// ... Use conn to send and receive messages.
|
// ... Use conn to send and receive messages.
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// Call the connection WriteMessage and ReadMessages methods to send and
|
// Call the connection's WriteMessage and ReadMessage methods to send and
|
||||||
// receive messages as a slice of bytes. This snippet of code shows how to echo
|
// receive messages as a slice of bytes. This snippet of code shows how to echo
|
||||||
// messages using these methods:
|
// messages using these methods:
|
||||||
//
|
//
|
||||||
|
|
10
json.go
10
json.go
|
@ -6,6 +6,7 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WriteJSON is deprecated, use c.WriteJSON instead.
|
// WriteJSON is deprecated, use c.WriteJSON instead.
|
||||||
|
@ -45,5 +46,12 @@ func (c *Conn) ReadJSON(v interface{}) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return json.NewDecoder(r).Decode(v)
|
err = json.NewDecoder(r).Decode(v)
|
||||||
|
if err == io.EOF {
|
||||||
|
// Decode returns io.EOF when the message is empty or all whitespace.
|
||||||
|
// Convert to io.ErrUnexpectedEOF so that application can distinguish
|
||||||
|
// between an error reading the JSON value and the connection closing.
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
56
json_test.go
56
json_test.go
|
@ -6,6 +6,8 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -36,6 +38,60 @@ func TestJSON(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPartialJsonRead(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
c := fakeNetConn{&buf, &buf}
|
||||||
|
wc := newConn(c, true, 1024, 1024)
|
||||||
|
rc := newConn(c, false, 1024, 1024)
|
||||||
|
|
||||||
|
var v struct {
|
||||||
|
A int
|
||||||
|
B string
|
||||||
|
}
|
||||||
|
v.A = 1
|
||||||
|
v.B = "hello"
|
||||||
|
|
||||||
|
messageCount := 0
|
||||||
|
|
||||||
|
// Partial JSON values.
|
||||||
|
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for i := len(data) - 1; i >= 0; i-- {
|
||||||
|
if err := wc.WriteMessage(TextMessage, data[:i]); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
messageCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Whitespace.
|
||||||
|
|
||||||
|
if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
messageCount++
|
||||||
|
|
||||||
|
// Close.
|
||||||
|
|
||||||
|
if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < messageCount; i++ {
|
||||||
|
err := rc.ReadJSON(&v)
|
||||||
|
if err != io.ErrUnexpectedEOF {
|
||||||
|
t.Error("read", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rc.ReadJSON(&v)
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Error("final", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDeprecatedJSON(t *testing.T) {
|
func TestDeprecatedJSON(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
c := fakeNetConn{&buf, &buf}
|
c := fakeNetConn{&buf, &buf}
|
||||||
|
|
Loading…
Reference in New Issue