mirror of https://github.com/gorilla/websocket.git
Add support for fasthttp
This commit is contained in:
parent
844dd6d40e
commit
76d5f02c36
|
@ -0,0 +1,116 @@
|
|||
// +build go1.4
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func checkSameOriginFastHTTP(ctx *fasthttp.RequestCtx) bool {
|
||||
return checkSameOriginFromHeaderAndHost(string(ctx.Request.Header.Peek(originHeader)), string(ctx.Host()))
|
||||
}
|
||||
|
||||
// FastHTTPUpgrader is used to upgrade a fasthttp request into a websocket
|
||||
// connection. A Handler function must be provided to receive that connection.
|
||||
type FastHTTPUpgrader struct {
|
||||
// Handler receives a websocket connection after the handshake has been
|
||||
// completed. This must be provided.
|
||||
Handler func(*Conn)
|
||||
|
||||
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
|
||||
// size is zero, then a default value of 4096 is used. The I/O buffer sizes
|
||||
// do not limit the size of the messages that can be sent or received.
|
||||
ReadBufferSize, WriteBufferSize int
|
||||
|
||||
// Subprotocols specifies the server's supported protocols in order of
|
||||
// preference. If this field is set, then the Upgrade method negotiates a
|
||||
// subprotocol by selecting the first match in this list with a protocol
|
||||
// requested by the client.
|
||||
Subprotocols []string
|
||||
|
||||
// CheckOrigin returns true if the request Origin header is acceptable. If
|
||||
// CheckOrigin is nil, the host in the Origin header must not be set or
|
||||
// must match the host of the request.
|
||||
CheckOrigin func(ctx *fasthttp.RequestCtx) bool
|
||||
}
|
||||
|
||||
var websocketVersionByte = []byte(websocketVersion)
|
||||
|
||||
// UpgradeHandler handles a request for a websocket connection and does all the
|
||||
// checks necessary to ensure the request is valid. If a CheckOrigin function
|
||||
// was provided, it will be called, otherwise the Origin will be checked against
|
||||
// the request host value. If a subprotocol has not already been set, the best
|
||||
// choice will be made from the values provided to the upgrader and from the
|
||||
// client.
|
||||
//
|
||||
// Once the request has been verified and the response sent, the connection will
|
||||
// be hijacked and the provided Handler will be called.
|
||||
func (f *FastHTTPUpgrader) UpgradeHandler(ctx *fasthttp.RequestCtx) {
|
||||
if f.Handler == nil {
|
||||
panic("FastHTTPUpgrader does not have a Handler set")
|
||||
}
|
||||
|
||||
if !ctx.IsGet() {
|
||||
ctx.Error("websocket: method not GET", fasthttp.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(ctx.Request.Header.Peek("Sec-Websocket-Version"), websocketVersionByte) {
|
||||
ctx.Error("websocket: version != 13", fasthttp.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !ctx.Request.Header.ConnectionUpgrade() {
|
||||
ctx.Error("websocket: could not find connection header with token 'upgrade'", fasthttp.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !tokenListContainsValue(string(ctx.Request.Header.Peek("Upgrade")), "websocket") {
|
||||
ctx.Error("websocket: could not find upgrade header with token 'websocket'", fasthttp.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
checkOrigin := f.CheckOrigin
|
||||
if checkOrigin == nil {
|
||||
checkOrigin = checkSameOriginFastHTTP
|
||||
}
|
||||
if !checkOrigin(ctx) {
|
||||
ctx.Error("websocket: origin not allowed", fasthttp.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
challengeKey := ctx.Request.Header.Peek("Sec-Websocket-Key")
|
||||
if len(challengeKey) == 0 {
|
||||
ctx.Error("websocket: key missing or blank", fasthttp.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols)
|
||||
ctx.Response.Header.Set("Upgrade", "websocket")
|
||||
ctx.Response.Header.Set("Connection", "Upgrade")
|
||||
ctx.Response.Header.Set("Sec-WebSocket-Accept", computeAcceptKeyByte(challengeKey))
|
||||
|
||||
// The subprotocol may have already been set in the response
|
||||
subprotocol := string(ctx.Response.Header.Peek(protocolHeader))
|
||||
if subprotocol == "" {
|
||||
// Find the best protocol, if any
|
||||
clientProtocols := subprotocolsFromHeader(string(ctx.Request.Header.Peek(protocolHeader)))
|
||||
if len(clientProtocols) != 0 {
|
||||
subprotocol = matchSubprotocol(clientProtocols, f.Subprotocols)
|
||||
if subprotocol != "" {
|
||||
ctx.Response.Header.Set(protocolHeader, subprotocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Hijack(func(conn net.Conn) {
|
||||
c := newConn(conn, true, f.ReadBufferSize, f.WriteBufferSize)
|
||||
if subprotocol != "" {
|
||||
c.subprotocol = subprotocol
|
||||
}
|
||||
f.Handler(c)
|
||||
})
|
||||
}
|
56
server.go
56
server.go
|
@ -14,6 +14,12 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
originHeader = "Origin"
|
||||
protocolHeader = "Sec-Websocket-Protocol"
|
||||
websocketVersion = "13"
|
||||
)
|
||||
|
||||
// HandshakeError describes an error with the handshake from the peer.
|
||||
type HandshakeError struct {
|
||||
message string
|
||||
|
@ -60,30 +66,42 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in
|
|||
|
||||
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
|
||||
func checkSameOrigin(r *http.Request) bool {
|
||||
origin := r.Header["Origin"]
|
||||
origin := r.Header[originHeader]
|
||||
if len(origin) == 0 {
|
||||
return true
|
||||
}
|
||||
u, err := url.Parse(origin[0])
|
||||
return checkSameOriginFromHeaderAndHost(origin[0], r.Host)
|
||||
}
|
||||
|
||||
func checkSameOriginFromHeaderAndHost(origin, reqHost string) bool {
|
||||
if len(origin) == 0 {
|
||||
return true
|
||||
}
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return u.Host == r.Host
|
||||
return u.Host == reqHost
|
||||
}
|
||||
|
||||
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
||||
if u.Subprotocols != nil {
|
||||
clientProtocols := Subprotocols(r)
|
||||
for _, serverProtocol := range u.Subprotocols {
|
||||
for _, clientProtocol := range clientProtocols {
|
||||
if clientProtocol == serverProtocol {
|
||||
return clientProtocol
|
||||
}
|
||||
return matchSubprotocol(Subprotocols(r), u.Subprotocols)
|
||||
} else if responseHeader != nil {
|
||||
return responseHeader.Get(protocolHeader)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func matchSubprotocol(clientProtocols, serverProtocols []string) string {
|
||||
for _, serverProtocol := range serverProtocols {
|
||||
for _, clientProtocol := range clientProtocols {
|
||||
if clientProtocol == serverProtocol {
|
||||
return clientProtocol
|
||||
}
|
||||
}
|
||||
} else if responseHeader != nil {
|
||||
return responseHeader.Get("Sec-Websocket-Protocol")
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
|
@ -96,15 +114,15 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|||
if r.Method != "GET" {
|
||||
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
|
||||
}
|
||||
if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" {
|
||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
|
||||
if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != websocketVersion {
|
||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: version !="+websocketVersion)
|
||||
}
|
||||
|
||||
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
|
||||
if !headerListContainsValue(r.Header, "Connection", "upgrade") {
|
||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'")
|
||||
}
|
||||
|
||||
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
|
||||
if !headerListContainsValue(r.Header, "Upgrade", "websocket") {
|
||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'")
|
||||
}
|
||||
|
||||
|
@ -158,7 +176,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|||
p = append(p, "\r\n"...)
|
||||
}
|
||||
for k, vs := range responseHeader {
|
||||
if k == "Sec-Websocket-Protocol" {
|
||||
if k == protocolHeader {
|
||||
continue
|
||||
}
|
||||
for _, v := range vs {
|
||||
|
@ -238,7 +256,11 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
|
|||
// Subprotocols returns the subprotocols requested by the client in the
|
||||
// Sec-Websocket-Protocol header.
|
||||
func Subprotocols(r *http.Request) []string {
|
||||
h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
|
||||
return subprotocolsFromHeader(r.Header.Get(protocolHeader))
|
||||
}
|
||||
|
||||
func subprotocolsFromHeader(header string) []string {
|
||||
h := strings.TrimSpace(header)
|
||||
if h == "" {
|
||||
return nil
|
||||
}
|
||||
|
|
23
util.go
23
util.go
|
@ -13,14 +13,19 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// tokenListContainsValue returns true if the 1#token header with the given
|
||||
// headerListContainsValue returns true if the 1#token header with the given
|
||||
// name contains token.
|
||||
func tokenListContainsValue(header http.Header, name string, value string) bool {
|
||||
func headerListContainsValue(header http.Header, name string, value string) bool {
|
||||
for _, v := range header[name] {
|
||||
for _, s := range strings.Split(v, ",") {
|
||||
if strings.EqualFold(value, strings.TrimSpace(s)) {
|
||||
return true
|
||||
}
|
||||
return tokenListContainsValue(v, value)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tokenListContainsValue(list string, value string) bool {
|
||||
for _, s := range strings.Split(list, ",") {
|
||||
if strings.EqualFold(value, strings.TrimSpace(s)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
@ -29,8 +34,12 @@ func tokenListContainsValue(header http.Header, name string, value string) bool
|
|||
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
|
||||
func computeAcceptKey(challengeKey string) string {
|
||||
return computeAcceptKeyByte([]byte(challengeKey))
|
||||
}
|
||||
|
||||
func computeAcceptKeyByte(challengeKey []byte) string {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(challengeKey))
|
||||
h.Write(challengeKey)
|
||||
h.Write(keyGUID)
|
||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
|
10
util_test.go
10
util_test.go
|
@ -9,7 +9,7 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
var tokenListContainsValueTests = []struct {
|
||||
var headerListContainsValueTests = []struct {
|
||||
value string
|
||||
ok bool
|
||||
}{
|
||||
|
@ -23,12 +23,12 @@ var tokenListContainsValueTests = []struct {
|
|||
{"other, websocket, more", true},
|
||||
}
|
||||
|
||||
func TestTokenListContainsValue(t *testing.T) {
|
||||
for _, tt := range tokenListContainsValueTests {
|
||||
func TestHeaderListContainsValue(t *testing.T) {
|
||||
for _, tt := range headerListContainsValueTests {
|
||||
h := http.Header{"Upgrade": {tt.value}}
|
||||
ok := tokenListContainsValue(h, "Upgrade", "websocket")
|
||||
ok := headerListContainsValue(h, "Upgrade", "websocket")
|
||||
if ok != tt.ok {
|
||||
t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok)
|
||||
t.Errorf("headerListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue