Add support for fasthttp

This commit is contained in:
Ryan Leavengood 2015-12-15 11:37:40 -05:00
parent 844dd6d40e
commit 76d5f02c36
4 changed files with 176 additions and 29 deletions

116
fasthttp.go Normal file
View File

@ -0,0 +1,116 @@
// +build go1.4
package websocket
import (
"bytes"
"net"
"github.com/valyala/fasthttp"
)
func checkSameOriginFastHTTP(ctx *fasthttp.RequestCtx) bool {
return checkSameOriginFromHeaderAndHost(string(ctx.Request.Header.Peek(originHeader)), string(ctx.Host()))
}
// FastHTTPUpgrader is used to upgrade a fasthttp request into a websocket
// connection. A Handler function must be provided to receive that connection.
type FastHTTPUpgrader struct {
// Handler receives a websocket connection after the handshake has been
// completed. This must be provided.
Handler func(*Conn)
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// size is zero, then a default value of 4096 is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is set, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
// requested by the client.
Subprotocols []string
// CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, the host in the Origin header must not be set or
// must match the host of the request.
CheckOrigin func(ctx *fasthttp.RequestCtx) bool
}
var websocketVersionByte = []byte(websocketVersion)
// UpgradeHandler handles a request for a websocket connection and does all the
// checks necessary to ensure the request is valid. If a CheckOrigin function
// was provided, it will be called, otherwise the Origin will be checked against
// the request host value. If a subprotocol has not already been set, the best
// choice will be made from the values provided to the upgrader and from the
// client.
//
// Once the request has been verified and the response sent, the connection will
// be hijacked and the provided Handler will be called.
func (f *FastHTTPUpgrader) UpgradeHandler(ctx *fasthttp.RequestCtx) {
if f.Handler == nil {
panic("FastHTTPUpgrader does not have a Handler set")
}
if !ctx.IsGet() {
ctx.Error("websocket: method not GET", fasthttp.StatusMethodNotAllowed)
return
}
if !bytes.Equal(ctx.Request.Header.Peek("Sec-Websocket-Version"), websocketVersionByte) {
ctx.Error("websocket: version != 13", fasthttp.StatusBadRequest)
return
}
if !ctx.Request.Header.ConnectionUpgrade() {
ctx.Error("websocket: could not find connection header with token 'upgrade'", fasthttp.StatusBadRequest)
return
}
if !tokenListContainsValue(string(ctx.Request.Header.Peek("Upgrade")), "websocket") {
ctx.Error("websocket: could not find upgrade header with token 'websocket'", fasthttp.StatusBadRequest)
return
}
checkOrigin := f.CheckOrigin
if checkOrigin == nil {
checkOrigin = checkSameOriginFastHTTP
}
if !checkOrigin(ctx) {
ctx.Error("websocket: origin not allowed", fasthttp.StatusForbidden)
return
}
challengeKey := ctx.Request.Header.Peek("Sec-Websocket-Key")
if len(challengeKey) == 0 {
ctx.Error("websocket: key missing or blank", fasthttp.StatusBadRequest)
return
}
ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols)
ctx.Response.Header.Set("Upgrade", "websocket")
ctx.Response.Header.Set("Connection", "Upgrade")
ctx.Response.Header.Set("Sec-WebSocket-Accept", computeAcceptKeyByte(challengeKey))
// The subprotocol may have already been set in the response
subprotocol := string(ctx.Response.Header.Peek(protocolHeader))
if subprotocol == "" {
// Find the best protocol, if any
clientProtocols := subprotocolsFromHeader(string(ctx.Request.Header.Peek(protocolHeader)))
if len(clientProtocols) != 0 {
subprotocol = matchSubprotocol(clientProtocols, f.Subprotocols)
if subprotocol != "" {
ctx.Response.Header.Set(protocolHeader, subprotocol)
}
}
}
ctx.Hijack(func(conn net.Conn) {
c := newConn(conn, true, f.ReadBufferSize, f.WriteBufferSize)
if subprotocol != "" {
c.subprotocol = subprotocol
}
f.Handler(c)
})
}

View File

@ -14,6 +14,12 @@ import (
"time" "time"
) )
const (
originHeader = "Origin"
protocolHeader = "Sec-Websocket-Protocol"
websocketVersion = "13"
)
// HandshakeError describes an error with the handshake from the peer. // HandshakeError describes an error with the handshake from the peer.
type HandshakeError struct { type HandshakeError struct {
message string message string
@ -60,30 +66,42 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in
// checkSameOrigin returns true if the origin is not set or is equal to the request host. // checkSameOrigin returns true if the origin is not set or is equal to the request host.
func checkSameOrigin(r *http.Request) bool { func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"] origin := r.Header[originHeader]
if len(origin) == 0 { if len(origin) == 0 {
return true return true
} }
u, err := url.Parse(origin[0]) return checkSameOriginFromHeaderAndHost(origin[0], r.Host)
}
func checkSameOriginFromHeaderAndHost(origin, reqHost string) bool {
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin)
if err != nil { if err != nil {
return false return false
} }
return u.Host == r.Host return u.Host == reqHost
} }
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil { if u.Subprotocols != nil {
clientProtocols := Subprotocols(r) return matchSubprotocol(Subprotocols(r), u.Subprotocols)
for _, serverProtocol := range u.Subprotocols { } else if responseHeader != nil {
for _, clientProtocol := range clientProtocols { return responseHeader.Get(protocolHeader)
if clientProtocol == serverProtocol { }
return clientProtocol return ""
} }
func matchSubprotocol(clientProtocols, serverProtocols []string) string {
for _, serverProtocol := range serverProtocols {
for _, clientProtocol := range clientProtocols {
if clientProtocol == serverProtocol {
return clientProtocol
} }
} }
} else if responseHeader != nil {
return responseHeader.Get("Sec-Websocket-Protocol")
} }
return "" return ""
} }
@ -96,15 +114,15 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
if r.Method != "GET" { if r.Method != "GET" {
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET") return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
} }
if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" { if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != websocketVersion {
return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13") return u.returnError(w, r, http.StatusBadRequest, "websocket: version !="+websocketVersion)
} }
if !tokenListContainsValue(r.Header, "Connection", "upgrade") { if !headerListContainsValue(r.Header, "Connection", "upgrade") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'") return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'")
} }
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { if !headerListContainsValue(r.Header, "Upgrade", "websocket") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'") return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'")
} }
@ -158,7 +176,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
} }
for k, vs := range responseHeader { for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" { if k == protocolHeader {
continue continue
} }
for _, v := range vs { for _, v := range vs {
@ -238,7 +256,11 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
// Subprotocols returns the subprotocols requested by the client in the // Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header. // Sec-Websocket-Protocol header.
func Subprotocols(r *http.Request) []string { func Subprotocols(r *http.Request) []string {
h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) return subprotocolsFromHeader(r.Header.Get(protocolHeader))
}
func subprotocolsFromHeader(header string) []string {
h := strings.TrimSpace(header)
if h == "" { if h == "" {
return nil return nil
} }

23
util.go
View File

@ -13,14 +13,19 @@ import (
"strings" "strings"
) )
// tokenListContainsValue returns true if the 1#token header with the given // headerListContainsValue returns true if the 1#token header with the given
// name contains token. // name contains token.
func tokenListContainsValue(header http.Header, name string, value string) bool { func headerListContainsValue(header http.Header, name string, value string) bool {
for _, v := range header[name] { for _, v := range header[name] {
for _, s := range strings.Split(v, ",") { return tokenListContainsValue(v, value)
if strings.EqualFold(value, strings.TrimSpace(s)) { }
return true return false
} }
func tokenListContainsValue(list string, value string) bool {
for _, s := range strings.Split(list, ",") {
if strings.EqualFold(value, strings.TrimSpace(s)) {
return true
} }
} }
return false return false
@ -29,8 +34,12 @@ func tokenListContainsValue(header http.Header, name string, value string) bool
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string { func computeAcceptKey(challengeKey string) string {
return computeAcceptKeyByte([]byte(challengeKey))
}
func computeAcceptKeyByte(challengeKey []byte) string {
h := sha1.New() h := sha1.New()
h.Write([]byte(challengeKey)) h.Write(challengeKey)
h.Write(keyGUID) h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil)) return base64.StdEncoding.EncodeToString(h.Sum(nil))
} }

View File

@ -9,7 +9,7 @@ import (
"testing" "testing"
) )
var tokenListContainsValueTests = []struct { var headerListContainsValueTests = []struct {
value string value string
ok bool ok bool
}{ }{
@ -23,12 +23,12 @@ var tokenListContainsValueTests = []struct {
{"other, websocket, more", true}, {"other, websocket, more", true},
} }
func TestTokenListContainsValue(t *testing.T) { func TestHeaderListContainsValue(t *testing.T) {
for _, tt := range tokenListContainsValueTests { for _, tt := range headerListContainsValueTests {
h := http.Header{"Upgrade": {tt.value}} h := http.Header{"Upgrade": {tt.value}}
ok := tokenListContainsValue(h, "Upgrade", "websocket") ok := headerListContainsValue(h, "Upgrade", "websocket")
if ok != tt.ok { if ok != tt.ok {
t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok) t.Errorf("headerListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok)
} }
} }
} }