Add convenience method to check if websockets required (#779)

* Add convenience method to check if websockets required

* Add tests

* Fix up tests for develop branch
This commit is contained in:
David Irvine 2017-01-02 03:05:30 -05:00 committed by Bo-Yi Wu
parent ff17a8dd75
commit ebe3580daf
2 changed files with 32 additions and 0 deletions

View File

@ -383,6 +383,16 @@ func (c *Context) ContentType() string {
return filterFlags(c.requestHeader("Content-Type")) return filterFlags(c.requestHeader("Content-Type"))
} }
// IsWebsocket returns true if the request headers indicate that a websocket
// handshake is being initiated by the client.
func (c *Context) IsWebsocket() bool {
if strings.Contains(strings.ToLower(c.requestHeader("Connection")), "upgrade") &&
strings.ToLower(c.requestHeader("Upgrade")) == "websocket" {
return true
}
return false
}
func (c *Context) requestHeader(key string) string { func (c *Context) requestHeader(key string) string {
if values, _ := c.Request.Header[key]; len(values) > 0 { if values, _ := c.Request.Header[key]; len(values) > 0 {
return values[0] return values[0]

View File

@ -814,3 +814,25 @@ func TestContextGolangContext(t *testing.T) {
assert.Equal(t, c.Value("foo"), "bar") assert.Equal(t, c.Value("foo"), "bar")
assert.Nil(t, c.Value(1)) assert.Nil(t, c.Value(1))
} }
func TestWebsocketsRequired(t *testing.T) {
// Example request from spec: https://tools.ietf.org/html/rfc6455#section-1.2
c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("GET", "/chat", nil)
c.Request.Header.Set("Host", "server.example.com")
c.Request.Header.Set("Upgrade", "websocket")
c.Request.Header.Set("Connection", "Upgrade")
c.Request.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
c.Request.Header.Set("Origin", "http://example.com")
c.Request.Header.Set("Sec-WebSocket-Protocol", "chat, superchat")
c.Request.Header.Set("Sec-WebSocket-Version", "13")
assert.True(t, c.IsWebsocket())
// Normal request, no websocket required.
c, _ = CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("GET", "/chat", nil)
c.Request.Header.Set("Host", "server.example.com")
assert.False(t, c.IsWebsocket())
}