mirror of https://github.com/gorilla/websocket.git
Add support for fasthttp
This commit is contained in:
parent
844dd6d40e
commit
76d5f02c36
|
@ -0,0 +1,116 @@
|
||||||
|
// +build go1.4
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkSameOriginFastHTTP(ctx *fasthttp.RequestCtx) bool {
|
||||||
|
return checkSameOriginFromHeaderAndHost(string(ctx.Request.Header.Peek(originHeader)), string(ctx.Host()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FastHTTPUpgrader is used to upgrade a fasthttp request into a websocket
|
||||||
|
// connection. A Handler function must be provided to receive that connection.
|
||||||
|
type FastHTTPUpgrader struct {
|
||||||
|
// Handler receives a websocket connection after the handshake has been
|
||||||
|
// completed. This must be provided.
|
||||||
|
Handler func(*Conn)
|
||||||
|
|
||||||
|
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
|
||||||
|
// size is zero, then a default value of 4096 is used. The I/O buffer sizes
|
||||||
|
// do not limit the size of the messages that can be sent or received.
|
||||||
|
ReadBufferSize, WriteBufferSize int
|
||||||
|
|
||||||
|
// Subprotocols specifies the server's supported protocols in order of
|
||||||
|
// preference. If this field is set, then the Upgrade method negotiates a
|
||||||
|
// subprotocol by selecting the first match in this list with a protocol
|
||||||
|
// requested by the client.
|
||||||
|
Subprotocols []string
|
||||||
|
|
||||||
|
// CheckOrigin returns true if the request Origin header is acceptable. If
|
||||||
|
// CheckOrigin is nil, the host in the Origin header must not be set or
|
||||||
|
// must match the host of the request.
|
||||||
|
CheckOrigin func(ctx *fasthttp.RequestCtx) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var websocketVersionByte = []byte(websocketVersion)
|
||||||
|
|
||||||
|
// UpgradeHandler handles a request for a websocket connection and does all the
|
||||||
|
// checks necessary to ensure the request is valid. If a CheckOrigin function
|
||||||
|
// was provided, it will be called, otherwise the Origin will be checked against
|
||||||
|
// the request host value. If a subprotocol has not already been set, the best
|
||||||
|
// choice will be made from the values provided to the upgrader and from the
|
||||||
|
// client.
|
||||||
|
//
|
||||||
|
// Once the request has been verified and the response sent, the connection will
|
||||||
|
// be hijacked and the provided Handler will be called.
|
||||||
|
func (f *FastHTTPUpgrader) UpgradeHandler(ctx *fasthttp.RequestCtx) {
|
||||||
|
if f.Handler == nil {
|
||||||
|
panic("FastHTTPUpgrader does not have a Handler set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ctx.IsGet() {
|
||||||
|
ctx.Error("websocket: method not GET", fasthttp.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(ctx.Request.Header.Peek("Sec-Websocket-Version"), websocketVersionByte) {
|
||||||
|
ctx.Error("websocket: version != 13", fasthttp.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ctx.Request.Header.ConnectionUpgrade() {
|
||||||
|
ctx.Error("websocket: could not find connection header with token 'upgrade'", fasthttp.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tokenListContainsValue(string(ctx.Request.Header.Peek("Upgrade")), "websocket") {
|
||||||
|
ctx.Error("websocket: could not find upgrade header with token 'websocket'", fasthttp.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checkOrigin := f.CheckOrigin
|
||||||
|
if checkOrigin == nil {
|
||||||
|
checkOrigin = checkSameOriginFastHTTP
|
||||||
|
}
|
||||||
|
if !checkOrigin(ctx) {
|
||||||
|
ctx.Error("websocket: origin not allowed", fasthttp.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
challengeKey := ctx.Request.Header.Peek("Sec-Websocket-Key")
|
||||||
|
if len(challengeKey) == 0 {
|
||||||
|
ctx.Error("websocket: key missing or blank", fasthttp.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols)
|
||||||
|
ctx.Response.Header.Set("Upgrade", "websocket")
|
||||||
|
ctx.Response.Header.Set("Connection", "Upgrade")
|
||||||
|
ctx.Response.Header.Set("Sec-WebSocket-Accept", computeAcceptKeyByte(challengeKey))
|
||||||
|
|
||||||
|
// The subprotocol may have already been set in the response
|
||||||
|
subprotocol := string(ctx.Response.Header.Peek(protocolHeader))
|
||||||
|
if subprotocol == "" {
|
||||||
|
// Find the best protocol, if any
|
||||||
|
clientProtocols := subprotocolsFromHeader(string(ctx.Request.Header.Peek(protocolHeader)))
|
||||||
|
if len(clientProtocols) != 0 {
|
||||||
|
subprotocol = matchSubprotocol(clientProtocols, f.Subprotocols)
|
||||||
|
if subprotocol != "" {
|
||||||
|
ctx.Response.Header.Set(protocolHeader, subprotocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Hijack(func(conn net.Conn) {
|
||||||
|
c := newConn(conn, true, f.ReadBufferSize, f.WriteBufferSize)
|
||||||
|
if subprotocol != "" {
|
||||||
|
c.subprotocol = subprotocol
|
||||||
|
}
|
||||||
|
f.Handler(c)
|
||||||
|
})
|
||||||
|
}
|
56
server.go
56
server.go
|
@ -14,6 +14,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
originHeader = "Origin"
|
||||||
|
protocolHeader = "Sec-Websocket-Protocol"
|
||||||
|
websocketVersion = "13"
|
||||||
|
)
|
||||||
|
|
||||||
// HandshakeError describes an error with the handshake from the peer.
|
// HandshakeError describes an error with the handshake from the peer.
|
||||||
type HandshakeError struct {
|
type HandshakeError struct {
|
||||||
message string
|
message string
|
||||||
|
@ -60,30 +66,42 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in
|
||||||
|
|
||||||
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
|
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
|
||||||
func checkSameOrigin(r *http.Request) bool {
|
func checkSameOrigin(r *http.Request) bool {
|
||||||
origin := r.Header["Origin"]
|
origin := r.Header[originHeader]
|
||||||
if len(origin) == 0 {
|
if len(origin) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
u, err := url.Parse(origin[0])
|
return checkSameOriginFromHeaderAndHost(origin[0], r.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkSameOriginFromHeaderAndHost(origin, reqHost string) bool {
|
||||||
|
if len(origin) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
u, err := url.Parse(origin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return u.Host == r.Host
|
return u.Host == reqHost
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
||||||
if u.Subprotocols != nil {
|
if u.Subprotocols != nil {
|
||||||
clientProtocols := Subprotocols(r)
|
return matchSubprotocol(Subprotocols(r), u.Subprotocols)
|
||||||
for _, serverProtocol := range u.Subprotocols {
|
} else if responseHeader != nil {
|
||||||
for _, clientProtocol := range clientProtocols {
|
return responseHeader.Get(protocolHeader)
|
||||||
if clientProtocol == serverProtocol {
|
}
|
||||||
return clientProtocol
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func matchSubprotocol(clientProtocols, serverProtocols []string) string {
|
||||||
|
for _, serverProtocol := range serverProtocols {
|
||||||
|
for _, clientProtocol := range clientProtocols {
|
||||||
|
if clientProtocol == serverProtocol {
|
||||||
|
return clientProtocol
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if responseHeader != nil {
|
|
||||||
return responseHeader.Get("Sec-Websocket-Protocol")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,15 +114,15 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
if r.Method != "GET" {
|
if r.Method != "GET" {
|
||||||
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
|
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
|
||||||
}
|
}
|
||||||
if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" {
|
if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != websocketVersion {
|
||||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
|
return u.returnError(w, r, http.StatusBadRequest, "websocket: version !="+websocketVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
|
if !headerListContainsValue(r.Header, "Connection", "upgrade") {
|
||||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'")
|
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
|
if !headerListContainsValue(r.Header, "Upgrade", "websocket") {
|
||||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'")
|
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,7 +176,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
p = append(p, "\r\n"...)
|
p = append(p, "\r\n"...)
|
||||||
}
|
}
|
||||||
for k, vs := range responseHeader {
|
for k, vs := range responseHeader {
|
||||||
if k == "Sec-Websocket-Protocol" {
|
if k == protocolHeader {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, v := range vs {
|
for _, v := range vs {
|
||||||
|
@ -238,7 +256,11 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
|
||||||
// Subprotocols returns the subprotocols requested by the client in the
|
// Subprotocols returns the subprotocols requested by the client in the
|
||||||
// Sec-Websocket-Protocol header.
|
// Sec-Websocket-Protocol header.
|
||||||
func Subprotocols(r *http.Request) []string {
|
func Subprotocols(r *http.Request) []string {
|
||||||
h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
|
return subprotocolsFromHeader(r.Header.Get(protocolHeader))
|
||||||
|
}
|
||||||
|
|
||||||
|
func subprotocolsFromHeader(header string) []string {
|
||||||
|
h := strings.TrimSpace(header)
|
||||||
if h == "" {
|
if h == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
23
util.go
23
util.go
|
@ -13,14 +13,19 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tokenListContainsValue returns true if the 1#token header with the given
|
// headerListContainsValue returns true if the 1#token header with the given
|
||||||
// name contains token.
|
// name contains token.
|
||||||
func tokenListContainsValue(header http.Header, name string, value string) bool {
|
func headerListContainsValue(header http.Header, name string, value string) bool {
|
||||||
for _, v := range header[name] {
|
for _, v := range header[name] {
|
||||||
for _, s := range strings.Split(v, ",") {
|
return tokenListContainsValue(v, value)
|
||||||
if strings.EqualFold(value, strings.TrimSpace(s)) {
|
}
|
||||||
return true
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func tokenListContainsValue(list string, value string) bool {
|
||||||
|
for _, s := range strings.Split(list, ",") {
|
||||||
|
if strings.EqualFold(value, strings.TrimSpace(s)) {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
@ -29,8 +34,12 @@ func tokenListContainsValue(header http.Header, name string, value string) bool
|
||||||
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||||
|
|
||||||
func computeAcceptKey(challengeKey string) string {
|
func computeAcceptKey(challengeKey string) string {
|
||||||
|
return computeAcceptKeyByte([]byte(challengeKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeAcceptKeyByte(challengeKey []byte) string {
|
||||||
h := sha1.New()
|
h := sha1.New()
|
||||||
h.Write([]byte(challengeKey))
|
h.Write(challengeKey)
|
||||||
h.Write(keyGUID)
|
h.Write(keyGUID)
|
||||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
10
util_test.go
10
util_test.go
|
@ -9,7 +9,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
var tokenListContainsValueTests = []struct {
|
var headerListContainsValueTests = []struct {
|
||||||
value string
|
value string
|
||||||
ok bool
|
ok bool
|
||||||
}{
|
}{
|
||||||
|
@ -23,12 +23,12 @@ var tokenListContainsValueTests = []struct {
|
||||||
{"other, websocket, more", true},
|
{"other, websocket, more", true},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTokenListContainsValue(t *testing.T) {
|
func TestHeaderListContainsValue(t *testing.T) {
|
||||||
for _, tt := range tokenListContainsValueTests {
|
for _, tt := range headerListContainsValueTests {
|
||||||
h := http.Header{"Upgrade": {tt.value}}
|
h := http.Header{"Upgrade": {tt.value}}
|
||||||
ok := tokenListContainsValue(h, "Upgrade", "websocket")
|
ok := headerListContainsValue(h, "Upgrade", "websocket")
|
||||||
if ok != tt.ok {
|
if ok != tt.ok {
|
||||||
t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok)
|
t.Errorf("headerListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue