diff --git a/server.go b/server.go index 85616c7..8d7137d 100644 --- a/server.go +++ b/server.go @@ -251,3 +251,10 @@ func Subprotocols(r *http.Request) []string { } return protocols } + +// IsWebSocketUpgrade returns true if the client requested upgrade to the +// WebSocket protocol. +func IsWebSocketUpgrade(r *http.Request) bool { + return tokenListContainsValue(r.Header, "Connection", "upgrade") && + tokenListContainsValue(r.Header, "Upgrade", "websocket") +} diff --git a/server_test.go b/server_test.go index ead0776..0a28141 100644 --- a/server_test.go +++ b/server_test.go @@ -31,3 +31,21 @@ func TestSubprotocols(t *testing.T) { } } } + +var isWebSocketUpgradeTests = []struct { + ok bool + h http.Header +}{ + {false, http.Header{"Upgrade": {"websocket"}}}, + {false, http.Header{"Connection": {"upgrade"}}}, + {true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}}, +} + +func TestIsWebSocketUpgrade(t *testing.T) { + for _, tt := range isWebSocketUpgradeTests { + ok := IsWebSocketUpgrade(&http.Request{Header: tt.h}) + if tt.ok != ok { + t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok) + } + } +}