Merge branch 'feature/add_extension_header' into feature/impl_decompressContextTakeover

This commit is contained in:
misu 2018-01-21 15:28:20 +09:00
commit a2b487748b
1 changed files with 26 additions and 8 deletions

View File

@ -142,14 +142,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
subprotocol := u.selectSubprotocol(r, responseHeader) subprotocol := u.selectSubprotocol(r, responseHeader)
// Negotiate PMCE // Negotiate PMCE
var compress bool var (
compress bool
contextTakeover bool
)
if u.EnableCompression { if u.EnableCompression {
for _, ext := range parseExtensions(r.Header) { for _, ext := range parseExtensions(r.Header) {
if ext[""] != "permessage-deflate" { // map[string]string{"":"permessage-deflate", "client_max_window_bits":""}
continue // detect context-takeover from client_max_window_bits
if ext[""] == "permessage-deflate" {
compress = true
}
if _, ok := ext["client_max_window_bits"]; ok {
contextTakeover = true
} }
compress = true
break
} }
} }
@ -177,8 +184,14 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c.subprotocol = subprotocol c.subprotocol = subprotocol
if compress { if compress {
c.newCompressionWriter = compressNoContextTakeover switch {
c.newDecompressionReader = decompressNoContextTakeover case contextTakeover:
c.newCompressionWriter = compressContextTakeover
c.newDecompressionReader = decompressContextTakeover
default:
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
} }
p := c.writeBuf[:0] p := c.writeBuf[:0]
@ -191,7 +204,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
} }
if compress { if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) switch {
case contextTakeover:
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_max_window_bits=15; client_max_window_bits=15\r\n"...)
default:
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
}
} }
for k, vs := range responseHeader { for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" { if k == "Sec-Websocket-Protocol" {