mod: modification of review

This commit is contained in:
misu 2018-03-05 17:57:15 +09:00
parent 416a1d5b7b
commit af46d4fe1f
7 changed files with 93 additions and 193 deletions

View File

@ -73,9 +73,8 @@ type Dialer struct {
// Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
// EnableCompression specifies if the client should attempt to negotiate
// per message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported.
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692).
EnableCompression bool
// Jar specifies the cookie jar.
@ -83,13 +82,10 @@ type Dialer struct {
// in responses.
Jar http.CookieJar
// CompressionLevel is passed to conn when the compression setting is true.
CompressionLevel int
// EnableContextTakeover specifies specifies if the client should attempt to negotiate
// per message compression with context-takeover (RFC 7692).
// but window bits is allowed only 15, because go's flate library support 15 bits only.
EnableContextTakeover bool
// AllowClientContextTakeover specifies whether the server will negotiate client context
// takeover for per message compression. Context takeover improves compression at the
// the cost of using more memory.
AllowClientContextTakeover bool
}
var errMalformedURL = errors.New("malformed ws or wss URL")
@ -205,7 +201,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
}
switch {
case d.EnableCompression && d.EnableContextTakeover:
case d.EnableCompression && d.AllowClientContextTakeover:
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_max_window_bits=15; client_max_window_bits=15")
case d.EnableCompression:
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
@ -286,13 +282,6 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
if d.EnableCompression {
if !isValidCompressionLevel(d.CompressionLevel) {
return nil, nil, errors.New("websocket: invalid compression level")
}
conn.compressionLevel = d.CompressionLevel
}
if err := req.Write(netConn); err != nil {
return nil, nil, err
}
@ -331,10 +320,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
switch {
case cmwb && smwb:
conn.contextTakeover = true
var wf contextTakeoverWriterFactory
wf.fw, _ = flate.NewWriter(&wf.tw, d.CompressionLevel)
conn.newCompressionWriter = wf.newCompressionWriter
var rf contextTakeoverReaderFactory

View File

@ -40,9 +40,14 @@ var cstDialer = Dialer{
HandshakeTimeout: 30 * time.Second,
}
type cstHandler struct{ *testing.T }
type cstHandlerConfig struct {
contextTakeover bool
}
type cstContextTakeoverHandler struct{ *testing.T }
type cstHandler struct {
*testing.T
cstHandlerConfig
}
type cstServer struct {
*httptest.Server
@ -55,25 +60,17 @@ const (
cstRequestURI = cstPath + "?" + cstRawQuery
)
func newServer(t *testing.T) *cstServer {
func newServer(t *testing.T, c cstHandlerConfig) *cstServer {
var s cstServer
s.Server = httptest.NewServer(cstHandler{t})
s.Server = httptest.NewServer(cstHandler{t, c})
s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL)
return &s
}
func newContextTakeoverServer(t *testing.T) *cstServer {
func newTLSServer(t *testing.T, c cstHandlerConfig) *cstServer {
var s cstServer
s.Server = httptest.NewServer(cstContextTakeoverHandler{t})
s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL)
return &s
}
func newTLSServer(t *testing.T) *cstServer {
var s cstServer
s.Server = httptest.NewTLSServer(cstHandler{t})
s.Server = httptest.NewTLSServer(cstHandler{t, c})
s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL)
return &s
@ -96,6 +93,9 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad protocol", http.StatusBadRequest)
return
}
if t.contextTakeover {
cstUpgrader.AllowServerContextTakeover = true
}
ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
if err != nil {
t.Logf("Upgrade: %v", err)
@ -126,79 +126,27 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
t.Logf("Close: %v", err)
return
}
}
func (t cstContextTakeoverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != cstPath {
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
http.Error(w, "bad path", 400)
return
}
if r.URL.RawQuery != cstRawQuery {
t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
http.Error(w, "bad path", 400)
return
}
subprotos := Subprotocols(r)
if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
http.Error(w, "bad protocol", 400)
return
}
cu := cstUpgrader
cu.CompressionLevel = defaultCompressionLevel
cu.EnableContextTakeover = true
ws, err := cu.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
if err != nil {
t.Logf("Upgrade: %v", err)
return
}
defer ws.Close()
if ws.Subprotocol() != "p1" {
t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
ws.Close()
return
}
// first message
op, rd, err := ws.NextReader()
if err != nil {
t.Logf("NextReader: %v", err)
return
}
wr, err := ws.NextWriter(op)
if err != nil {
t.Logf("NextWriter: %v", err)
return
}
if _, err = io.Copy(wr, rd); err != nil {
t.Logf("NextWriter: %v", err)
return
}
if err := wr.Close(); err != nil {
t.Logf("Close: %v", err)
return
}
// second message
op, rd, err = ws.NextReader()
if err != nil {
t.Logf("NextReader: %v", err)
return
}
wr, err = ws.NextWriter(op)
if err != nil {
t.Logf("NextWriter: %v", err)
return
}
if _, err = io.Copy(wr, rd); err != nil {
t.Logf("NextWriter: %v", err)
return
}
if err := wr.Close(); err != nil {
t.Logf("Close: %v", err)
return
// for multipleSendRecv when context takeover.
if t.contextTakeover {
op, rd, err := ws.NextReader()
if err != nil {
t.Logf("NextReader: %v", err)
return
}
wr, err := ws.NextWriter(op)
if err != nil {
t.Logf("NextWriter: %v", err)
return
}
if _, err = io.Copy(wr, rd); err != nil {
t.Logf("NextWriter: %v", err)
return
}
if err := wr.Close(); err != nil {
t.Logf("Close: %v", err)
return
}
}
}
@ -227,47 +175,29 @@ func sendRecv(t *testing.T, ws *Conn) {
}
func multipleSendRecv(t *testing.T, ws *Conn) {
message := "Hello World!"
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("SetWriteDeadline: %v", err)
}
if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
t.Fatalf("WriteMessage: %v", err)
}
if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("SetReadDeadline: %v", err)
}
_, p, err := ws.ReadMessage()
if err != nil {
t.Fatalf("ReadMessage: %v", err)
}
if string(p) != message {
t.Fatalf("message=%s, want %s", p, message)
}
nextMessage := "Can you read message?"
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("SetWriteDeadline: %v", err)
}
if err := ws.WriteMessage(TextMessage, []byte(nextMessage)); err != nil {
t.Fatalf("_WriteMessage: %v", err)
}
if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("_SetReadDeadline: %v", err)
}
_, p, err = ws.ReadMessage()
if err != nil {
t.Fatalf("_ReadMessage: %v", err)
}
if string(p) != nextMessage {
t.Fatalf("_message=%s, want %s", p, nextMessage)
for _, message := range []string{"Hello World", "Can you read message?"} {
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("SetWriteDeadline: %v", err)
}
if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
t.Fatalf("WriteMessage: %v", err)
}
if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("SetReadDeadline: %v", err)
}
_, p, err := ws.ReadMessage()
if err != nil {
t.Fatalf("ReadMessage: %v", err)
}
if string(p) != message {
t.Fatalf("message=%s, want %s", p, message)
}
}
}
func TestProxyDial(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
surl, _ := url.Parse(s.Server.URL)
@ -304,7 +234,7 @@ func TestProxyDial(t *testing.T) {
}
func TestProxyAuthorizationDial(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
surl, _ := url.Parse(s.Server.URL)
@ -344,7 +274,7 @@ func TestProxyAuthorizationDial(t *testing.T) {
}
func TestDial(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
ws, _, err := cstDialer.Dial(s.URL, nil)
@ -356,7 +286,7 @@ func TestDial(t *testing.T) {
}
func TestDialCookieJar(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
jar, _ := cookiejar.New(nil)
@ -404,7 +334,7 @@ func TestDialCookieJar(t *testing.T) {
}
func TestDialTLS(t *testing.T) {
s := newTLSServer(t)
s := newTLSServer(t, cstHandlerConfig{})
defer s.Close()
certs := x509.NewCertPool()
@ -430,7 +360,7 @@ func TestDialTLS(t *testing.T) {
func xTestDialTLSBadCert(t *testing.T) {
// This test is deactivated because of noisy logging from the net/http package.
s := newTLSServer(t)
s := newTLSServer(t, cstHandlerConfig{})
defer s.Close()
ws, _, err := cstDialer.Dial(s.URL, nil)
@ -441,7 +371,7 @@ func xTestDialTLSBadCert(t *testing.T) {
}
func TestDialTLSNoVerify(t *testing.T) {
s := newTLSServer(t)
s := newTLSServer(t, cstHandlerConfig{})
defer s.Close()
d := cstDialer
@ -455,7 +385,7 @@ func TestDialTLSNoVerify(t *testing.T) {
}
func TestDialTimeout(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
d := cstDialer
@ -468,7 +398,7 @@ func TestDialTimeout(t *testing.T) {
}
func TestDialBadScheme(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
ws, _, err := cstDialer.Dial(s.Server.URL, nil)
@ -479,7 +409,7 @@ func TestDialBadScheme(t *testing.T) {
}
func TestDialBadOrigin(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
@ -496,7 +426,7 @@ func TestDialBadOrigin(t *testing.T) {
}
func TestDialBadHeader(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
for _, k := range []string{"Upgrade",
@ -543,7 +473,7 @@ func TestBadMethod(t *testing.T) {
}
func TestHandshake(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
@ -605,7 +535,7 @@ func TestRespOnBadHandshake(t *testing.T) {
// TestHostHeader confirms that the host header provided in the call to Dial is
// sent to the server.
func TestHostHeader(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
specifiedHost := make(chan string, 1)
@ -632,7 +562,7 @@ func TestHostHeader(t *testing.T) {
}
func TestDialCompression(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
dialer := cstDialer
@ -646,13 +576,12 @@ func TestDialCompression(t *testing.T) {
}
func TestDialCompressionOfContextTakeover(t *testing.T) {
s := newContextTakeoverServer(t)
s := newServer(t, cstHandlerConfig{true})
defer s.Close()
dialer := cstDialer
dialer.EnableCompression = true
dialer.EnableContextTakeover = true
dialer.CompressionLevel = 2
dialer.AllowClientContextTakeover = true
ws, _, err := dialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
@ -663,7 +592,7 @@ func TestDialCompressionOfContextTakeover(t *testing.T) {
}
func TestSocksProxyDial(t *testing.T) {
s := newServer(t)
s := newServer(t, cstHandlerConfig{})
defer s.Close()
proxyListener, err := net.Listen("tcp", "127.0.0.1:0")

View File

@ -172,10 +172,17 @@ type (
}
)
func (f *contextTakeoverWriterFactory) newCompressionWriter(w io.WriteCloser, level int) io.WriteCloser {
f.tw.w = w
f.tw.n = 0
return &flateTakeoverWriteWrapper{f}
func (wf *contextTakeoverWriterFactory) newCompressionWriter(w io.WriteCloser, level int) io.WriteCloser {
// Set writer on first write.
// In order to guarantee the consistency of compression with the client,
// do not reassign later.
if wf.fw == nil {
wf.fw, _ = flate.NewWriter(&wf.tw, level)
}
wf.tw.w = w
wf.tw.n = 0
return &flateTakeoverWriteWrapper{wf}
}
func (w *flateTakeoverWriteWrapper) Write(p []byte) (int, error) {

View File

@ -2,7 +2,6 @@ package websocket
import (
"bytes"
"compress/flate"
"fmt"
"io"
"io/ioutil"
@ -71,9 +70,7 @@ func BenchmarkWriteWithCompressionOfContextTakeover(b *testing.B) {
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
messages := textMessages(100)
c.enableWriteCompression = true
c.contextTakeover = true
var f contextTakeoverWriterFactory
f.fw, _ = flate.NewWriter(&f.tw, 2) // level is specified in Dialer, Upgrader
c.newCompressionWriter = f.newCompressionWriter
b.ResetTimer()
for i := 0; i < b.N; i++ {

View File

@ -263,8 +263,6 @@ type Conn struct {
readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.ReadCloser // arges may flateReadWrapper struct
contextTakeover bool
}
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@ -1145,8 +1143,6 @@ func (c *Conn) EnableWriteCompression(enable bool) {
// binary messages. This function is a noop if compression was not negotiated
// with the peer. See the compress/flate package for a description of
// compression levels.
// If you use context-takeover, do not specify a compression level from this method.
// Please set it to Dialer or Upgrader in advance.
func (c *Conn) SetCompressionLevel(level int) error {
if !isValidCompressionLevel(level) {
return errors.New("websocket: invalid compression level")

View File

@ -109,14 +109,12 @@ func TestFraming(t *testing.T) {
var wf contextTakeoverWriterFactory
wf.fw, _ = flate.NewWriter(&wf.tw, defaultCompressionLevel)
wc.newCompressionWriter = wf.newCompressionWriter
wc.contextTakeover = true
var rf contextTakeoverReaderFactory
fr := flate.NewReader(nil)
rf.fr = fr
rc.newDecompressionReader = rf.newDeCompressionReader
rc.contextTakeover = true
case compressCondition.compress:
wc.newCompressionWriter = compressNoContextTakeover
rc.newDecompressionReader = decompressNoContextTakeover

View File

@ -54,17 +54,13 @@ type Upgrader struct {
CheckOrigin func(r *http.Request) bool
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported.
// message compression (RFC 7692).
EnableCompression bool
// CompressionLevel is passed to conn when the compression setting is true.
CompressionLevel int
// EnableContextTakeover specifies specifies if the client should attempt to negotiate
// per message compression with context-takeover (RFC 7692).
// but window bits is allowed only 15, because go's flate library support 15 bits only.
EnableContextTakeover bool
// AllowServerContextTakeover specifies whether the server will negotiate server context
// takeover for per message compression. Context takeover improves compression at the
// cost of using more memory.
AllowServerContextTakeover bool
}
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
@ -196,18 +192,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c.subprotocol = subprotocol
if compress {
if !isValidCompressionLevel(u.CompressionLevel) {
return nil, errors.New("websocket: invalid compression level")
}
c.compressionLevel = u.CompressionLevel
switch {
case contextTakeover && u.EnableContextTakeover:
c.contextTakeover = contextTakeover
case contextTakeover && u.AllowServerContextTakeover:
var wf contextTakeoverWriterFactory
wf.fw, _ = flate.NewWriter(&wf.tw, u.CompressionLevel)
c.newCompressionWriter = wf.newCompressionWriter
var rf contextTakeoverReaderFactory
@ -231,7 +218,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}
if compress {
switch {
case contextTakeover && u.EnableContextTakeover:
case contextTakeover && u.AllowServerContextTakeover:
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_max_window_bits=15; client_max_window_bits=15\r\n"...)
default:
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)