mirror of https://github.com/gorilla/websocket.git
fix-issue-480
This commit is contained in:
parent
b65e62901f
commit
33da5b06cf
32
server.go
32
server.go
|
@ -7,6 +7,7 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -44,6 +45,7 @@ type Upgrader struct {
|
||||||
// WriteBufferSize.
|
// WriteBufferSize.
|
||||||
WriteBufferPool BufferPool
|
WriteBufferPool BufferPool
|
||||||
|
|
||||||
|
// Subprotocols have lower priority than NegotiateSuprotocol.
|
||||||
// Subprotocols specifies the server's supported protocols in order of
|
// Subprotocols specifies the server's supported protocols in order of
|
||||||
// preference. If this field is not nil, then the Upgrade method negotiates a
|
// preference. If this field is not nil, then the Upgrade method negotiates a
|
||||||
// subprotocol by selecting the first match in this list with a protocol
|
// subprotocol by selecting the first match in this list with a protocol
|
||||||
|
@ -70,6 +72,14 @@ type Upgrader struct {
|
||||||
// guarantee that compression will be supported. Currently only "no context
|
// guarantee that compression will be supported. Currently only "no context
|
||||||
// takeover" modes are supported.
|
// takeover" modes are supported.
|
||||||
EnableCompression bool
|
EnableCompression bool
|
||||||
|
|
||||||
|
// NegotiateSubprotocol has higher priority than Subprotocols.
|
||||||
|
// NegotiateSubprotocol returns the negotiated subprotocol for the handshake
|
||||||
|
// request. If the returned string is "", then the the Sec-Websocket-Protocol header
|
||||||
|
// is not included in the handshake response. If the function returns an error, then
|
||||||
|
// Upgrade responds to the client with http.StatusBadRequest.
|
||||||
|
// If this function is not nil, then the Upgrader.Subportocols field is ignored.
|
||||||
|
NegotiateSubprotocol func(r *http.Request) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
|
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
|
||||||
|
@ -96,7 +106,7 @@ func checkSameOrigin(r *http.Request) bool {
|
||||||
return equalASCIIFold(u.Host, r.Host)
|
return equalASCIIFold(u.Host, r.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
func (u *Upgrader) selectSubprotocol(r *http.Request) string {
|
||||||
if u.Subprotocols != nil {
|
if u.Subprotocols != nil {
|
||||||
clientProtocols := Subprotocols(r)
|
clientProtocols := Subprotocols(r)
|
||||||
for _, serverProtocol := range u.Subprotocols {
|
for _, serverProtocol := range u.Subprotocols {
|
||||||
|
@ -106,8 +116,6 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if responseHeader != nil {
|
|
||||||
return responseHeader.Get("Sec-Websocket-Protocol")
|
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -115,11 +123,14 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
|
||||||
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
|
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
|
||||||
//
|
//
|
||||||
// The responseHeader is included in the response to the client's upgrade
|
// The responseHeader is included in the response to the client's upgrade
|
||||||
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
|
// request. Use the responseHeader to specify cookies (Set-Cookie).
|
||||||
// application negotiated subprotocol (Sec-WebSocket-Protocol).
|
|
||||||
//
|
//
|
||||||
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
|
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
|
||||||
// response.
|
// response.
|
||||||
|
//
|
||||||
|
// The responseHeader does not support negotiated subprotocol(Sec-Websocket-Protocol)
|
||||||
|
// IF necessary,please use Upgrader.NegotiateSubprotocol and Upgrader.Subprotocols
|
||||||
|
// Use the method to view the Upgrader struct.
|
||||||
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
|
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
|
||||||
const badHandshake = "websocket: the client is not using the websocket protocol: "
|
const badHandshake = "websocket: the client is not using the websocket protocol: "
|
||||||
|
|
||||||
|
@ -156,7 +167,16 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
|
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
|
||||||
}
|
}
|
||||||
|
|
||||||
subprotocol := u.selectSubprotocol(r, responseHeader)
|
subprotocol := ""
|
||||||
|
if u.NegotiateSubprotocol != nil {
|
||||||
|
str, err := u.NegotiateSubprotocol(r)
|
||||||
|
if err != nil {
|
||||||
|
return u.returnError(w, r, http.StatusBadRequest, fmt.Sprintf("websocket:handshake negotiation protocol error:%s", err))
|
||||||
|
}
|
||||||
|
subprotocol = str
|
||||||
|
} else {
|
||||||
|
subprotocol = u.selectSubprotocol(r)
|
||||||
|
}
|
||||||
|
|
||||||
// Negotiate PMCE
|
// Negotiate PMCE
|
||||||
var compress bool
|
var compress bool
|
||||||
|
|
|
@ -7,8 +7,10 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -117,3 +119,74 @@ func TestBufioReuse(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var negotiateSubprotocolTests = []struct {
|
||||||
|
*Upgrader
|
||||||
|
match bool
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
&Upgrader{
|
||||||
|
NegotiateSubprotocol: func(r *http.Request) (s string, err error) { return "json", nil },
|
||||||
|
}, true, false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
&Upgrader{
|
||||||
|
Subprotocols: []string{"json"},
|
||||||
|
}, true, false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
&Upgrader{
|
||||||
|
Subprotocols: []string{"not-match"},
|
||||||
|
}, false, false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
&Upgrader{
|
||||||
|
NegotiateSubprotocol: func(r *http.Request) (s string, err error) { return "", errors.New("not-match") },
|
||||||
|
}, false, true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNegotiateSubprotocol(t *testing.T) {
|
||||||
|
for i := range negotiateSubprotocolTests {
|
||||||
|
upgrade := negotiateSubprotocolTests[i].Upgrader
|
||||||
|
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
upgrade.Upgrade(w, r, nil)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", s.URL, strings.NewReader(""))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest retuened error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Connection", "upgrade")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
req.Header.Set("Sec-Websocket-Version", "13")
|
||||||
|
req.Header.Set("Sec-Websocket-Protocol", "json")
|
||||||
|
req.Header.Set("Sec-Websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do returned error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if negotiateSubprotocolTests[i].shouldErr && resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("The expecred status code is %d,actual status code is %d", http.StatusBadRequest, resp.StatusCode)
|
||||||
|
} else {
|
||||||
|
if negotiateSubprotocolTests[i].match {
|
||||||
|
protocol := resp.Header.Get("Sec-Websocket-Protocol")
|
||||||
|
if protocol != "json" {
|
||||||
|
t.Errorf("Negotiation protocol failed,request protocol is json,reponese protocol is %s", protocol)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, ok := resp.Header["Sec-Websocket-Protocol"]; ok {
|
||||||
|
t.Errorf("Negotiation protocol failed,Sec-Websocket-Protocol field should be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.Close()
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue