From b378caee5b2a673abe3897c40fde529bff7b986e Mon Sep 17 00:00:00 2001 From: Steven Scott Date: Fri, 17 Aug 2018 19:50:34 -0700 Subject: [PATCH] Add write buffer pooling Add WriteBufferPool to Dialer and Upgrader. This field specifies a pool to use for write operations on a connection. Use of the pool can reduce memory use when there is a modest write volume over a large number of connections. Use larger of hijacked buffer and buffer allocated for connection (if any) as buffer for building handshake response. This decreases possible allocations when building the handshake response. Modify bufio reuse test to call Upgrade instead of the internal newConnBRW. Move the test from conn_test.go to server_test.go because it's a serer test. Update newConn and newConnBRW: - Move the bufio "hacks" from newConnBRW to separate functions and call these functions directly from Upgrade. - Rename newConn to newTestConn and move to conn_test.go. Shorten argument list to common use case. - Rename newConnBRW to newConn. - Add pool code to newConn. --- client.go | 13 ++- compression_test.go | 6 +- conn.go | 95 +++++++++++---------- conn_broadcast_test.go | 2 +- conn_test.go | 187 ++++++++++++++++++++++++++++++----------- json_test.go | 17 ++-- prepared_test.go | 2 +- server.go | 73 +++++++++++++++- server_test.go | 50 +++++++++++ 9 files changed, 328 insertions(+), 117 deletions(-) diff --git a/client.go b/client.go index 41f8ed5..7a24309 100644 --- a/client.go +++ b/client.go @@ -69,6 +69,17 @@ type Dialer struct { // do not limit the size of the messages that can be sent or received. ReadBufferSize, WriteBufferSize int + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + // Subprotocols specifies the client's requested subprotocols. Subprotocols []string @@ -277,7 +288,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } } - conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) if err := req.Write(netConn); err != nil { return nil, nil, err diff --git a/compression_test.go b/compression_test.go index 659cf42..8a26b30 100644 --- a/compression_test.go +++ b/compression_test.go @@ -43,7 +43,7 @@ func textMessages(num int) [][]byte { func BenchmarkWriteNoCompression(b *testing.B) { w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newTestConn(nil, w, false) messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -54,7 +54,7 @@ func BenchmarkWriteNoCompression(b *testing.B) { func BenchmarkWriteWithCompression(b *testing.B) { w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newTestConn(nil, w, false) messages := textMessages(100) c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover @@ -66,7 +66,7 @@ func BenchmarkWriteWithCompression(b *testing.B) { } func TestValidCompressionLevel(t *testing.T) { - c := newConn(fakeNetConn{}, false, 1024, 1024) + c := newTestConn(nil, nil, false) for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { if err := c.SetCompressionLevel(level); err == nil { t.Errorf("no error for level %d", level) diff --git a/conn.go b/conn.go index 78f972b..d2a21c1 100644 --- a/conn.go +++ b/conn.go @@ -223,6 +223,20 @@ func isValidReceivedCloseCode(code int) bool { return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) } +// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this +// interface. The type of the value stored in a pool is not specified. +type BufferPool interface { + // Get gets a value from the pool or returns nil if the pool is empty. + Get() interface{} + // Put adds a value to the pool. + Put(interface{}) +} + +// writePoolData is the type added to the write buffer pool. This wrapper is +// used to prevent applications from peeking at and depending on the values +// added to the pool. +type writePoolData struct{ buf []byte } + // The Conn type represents a WebSocket connection. type Conn struct { conn net.Conn @@ -232,6 +246,8 @@ type Conn struct { // Write fields mu chan bool // used as mutex to protect write to conn writeBuf []byte // frame is constructed in this buffer. + writePool BufferPool + writeBufSize int writeDeadline time.Time writer io.WriteCloser // the current writer returned to the application isWriting bool // for best-effort concurrent write detection @@ -263,64 +279,29 @@ type Conn struct { newDecompressionReader func(io.Reader) io.ReadCloser } -func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { - return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) -} +func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn { -type writeHook struct { - p []byte -} - -func (wh *writeHook) Write(p []byte) (int, error) { - wh.p = p - return len(p), nil -} - -func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn { - mu := make(chan bool, 1) - mu <- true - - var br *bufio.Reader - if readBufferSize == 0 && brw != nil && brw.Reader != nil { - // Reuse the supplied bufio.Reader if the buffer has a useful size. - // This code assumes that peek on a reader returns - // bufio.Reader.buf[:0]. - brw.Reader.Reset(conn) - if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 { - br = brw.Reader - } - } if br == nil { if readBufferSize == 0 { readBufferSize = defaultReadBufferSize - } - if readBufferSize < maxControlFramePayloadSize { + } else if readBufferSize < maxControlFramePayloadSize { + // must be large enough for control frame readBufferSize = maxControlFramePayloadSize } br = bufio.NewReaderSize(conn, readBufferSize) } - var writeBuf []byte - if writeBufferSize == 0 && brw != nil && brw.Writer != nil { - // Use the bufio.Writer's buffer if the buffer has a useful size. This - // code assumes that bufio.Writer.buf[:1] is passed to the - // bufio.Writer's underlying writer. - var wh writeHook - brw.Writer.Reset(&wh) - brw.Writer.WriteByte(0) - brw.Flush() - if cap(wh.p) >= maxFrameHeaderSize+256 { - writeBuf = wh.p[:cap(wh.p)] - } - } - - if writeBuf == nil { - if writeBufferSize == 0 { - writeBufferSize = defaultWriteBufferSize - } - writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) + if writeBufferSize <= 0 { + writeBufferSize = defaultWriteBufferSize + } + writeBufferSize += maxFrameHeaderSize + + if writeBuf == nil && writeBufferPool == nil { + writeBuf = make([]byte, writeBufferSize) } + mu := make(chan bool, 1) + mu <- true c := &Conn{ isServer: isServer, br: br, @@ -328,6 +309,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in mu: mu, readFinal: true, writeBuf: writeBuf, + writePool: writeBufferPool, + writeBufSize: writeBufferSize, enableWriteCompression: true, compressionLevel: defaultCompressionLevel, } @@ -484,7 +467,19 @@ func (c *Conn) prepWrite(messageType int) error { c.writeErrMu.Lock() err := c.writeErr c.writeErrMu.Unlock() - return err + if err != nil { + return err + } + + if c.writeBuf == nil { + wpd, ok := c.writePool.Get().(writePoolData) + if ok { + c.writeBuf = wpd.buf + } else { + c.writeBuf = make([]byte, c.writeBufSize) + } + } + return nil } // NextWriter returns a writer for the next message to send. The writer's Close @@ -610,6 +605,10 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { if final { c.writer = nil + if c.writePool != nil { + c.writePool.Put(writePoolData{buf: c.writeBuf}) + c.writeBuf = nil + } return nil } diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go index 45038e4..2e71c07 100644 --- a/conn_broadcast_test.go +++ b/conn_broadcast_test.go @@ -70,7 +70,7 @@ func (b *broadcastBench) makeConns(numConns int) { conns := make([]*broadcastConn, numConns) for i := 0; i < numConns; i++ { - c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024) + c := newTestConn(nil, b.w, true) if b.compression { c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover diff --git a/conn_test.go b/conn_test.go index 5fda7b5..5ca0673 100644 --- a/conn_test.go +++ b/conn_test.go @@ -13,6 +13,7 @@ import ( "io/ioutil" "net" "reflect" + "sync" "testing" "testing/iotest" "time" @@ -47,6 +48,12 @@ func (a fakeAddr) String() string { return "str" } +// newTestConn creates a connnection backed by a fake network connection using +// default values for buffering. +func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { + return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) +} + func TestFraming(t *testing.T) { frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537} var readChunkers = []struct { @@ -82,8 +89,8 @@ func TestFraming(t *testing.T) { for _, chunker := range readChunkers { var connBuf bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) - rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(chunker.f(&connBuf), nil, !isServer) if compress { wc.newCompressionWriter = compressNoContextTakeover rc.newDecompressionReader = decompressNoContextTakeover @@ -143,8 +150,8 @@ func TestControl(t *testing.T) { for _, isWriteControl := range []bool{true, false} { name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) var connBuf bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) - rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024) + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(&connBuf, nil, !isServer) if isWriteControl { wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) } else { @@ -173,14 +180,124 @@ func TestControl(t *testing.T) { } } +// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. +type simpleBufferPool struct { + v interface{} +} + +func (p *simpleBufferPool) Get() interface{} { + v := p.v + p.v = nil + return v +} + +func (p *simpleBufferPool) Put(v interface{}) { + p.v = v +} + +func TestWriteBufferPool(t *testing.T) { + var buf bytes.Buffer + var pool simpleBufferPool + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) + rc := newTestConn(&buf, nil, false) + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after create") + } + + // Part 1: test NextWriter/Write/Close + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + const message = "Hello World!" + + if _, err := io.WriteString(w, message); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err != nil { + t.Fatalf("w.Close() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after w.Close()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + + // Part 2: Test WriteMessage. + + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after wc.WriteMessage()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool after WriteMessage") + } + + opCode, p, err = rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } +} + +func TestWriteBufferPoolSync(t *testing.T) { + var buf bytes.Buffer + var pool sync.Pool + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) + rc := newTestConn(&buf, nil, false) + + const message = "Hello World!" + for i := 0; i < 3; i++ { + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + } +} + func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { const bufSize = 512 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize+bufSize/2)) @@ -206,8 +323,8 @@ func TestEOFWithinFrame(t *testing.T) { for n := 0; ; n++ { var b bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024) - rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024) + wc := newTestConn(nil, &b, false) + rc := newTestConn(&b, nil, true) w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize)) @@ -240,8 +357,8 @@ func TestEOFBeforeFinalFrame(t *testing.T) { const bufSize = 512 var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize+bufSize/2)) @@ -261,7 +378,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) { } func TestWriteAfterMessageWriterClose(t *testing.T) { - wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024) + wc := newTestConn(nil, &bytes.Buffer{}, false) w, _ := wc.NextWriter(BinaryMessage) io.WriteString(w, "hello") if err := w.Close(); err != nil { @@ -292,8 +409,8 @@ func TestReadLimit(t *testing.T) { message := make([]byte, readLimit+1) var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) rc.SetReadLimit(readLimit) // Send message at the limit with interleaved pong. @@ -321,7 +438,7 @@ func TestReadLimit(t *testing.T) { } func TestAddrs(t *testing.T) { - c := newConn(&fakeNetConn{}, true, 1024, 1024) + c := newTestConn(nil, nil, true) if c.LocalAddr() != localAddr { t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) } @@ -333,7 +450,7 @@ func TestAddrs(t *testing.T) { func TestUnderlyingConn(t *testing.T) { var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2} - c := newConn(fc, true, 1024, 1024) + c := newConn(fc, true, 1024, 1024, nil, nil, nil) ul := c.UnderlyingConn() if ul != fc { t.Fatalf("Underlying conn is not what it should be.") @@ -347,8 +464,8 @@ func TestBufioReadBytes(t *testing.T) { m[len(m)-1] = '\n' var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64) + wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) w, _ := wc.NextWriter(BinaryMessage) w.Write(m) @@ -423,7 +540,7 @@ func (w blockingWriter) Write(p []byte) (int, error) { func TestConcurrentWritePanic(t *testing.T) { w := blockingWriter{make(chan struct{}), make(chan struct{})} - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newTestConn(nil, w, false) go func() { c.WriteMessage(TextMessage, []byte{}) }() @@ -449,7 +566,7 @@ func (r failingReader) Read(p []byte) (int, error) { } func TestFailedConnectionReadPanic(t *testing.T) { - c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024) + c := newTestConn(failingReader{}, nil, false) defer func() { if v := recover(); v != nil { @@ -462,35 +579,3 @@ func TestFailedConnectionReadPanic(t *testing.T) { } t.Fatal("should not get here") } - -func TestBufioReuse(t *testing.T) { - brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil)) - c := newConnBRW(nil, false, 0, 0, brw) - - if c.br != brw.Reader { - t.Error("connection did not reuse bufio.Reader") - } - - var wh writeHook - brw.Writer.Reset(&wh) - brw.WriteByte(0) - brw.Flush() - if &c.writeBuf[0] != &wh.p[0] { - t.Error("connection did not reuse bufio.Writer") - } - - brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0)) - c = newConnBRW(nil, false, 0, 0, brw) - - if c.br == brw.Reader { - t.Error("connection used bufio.Reader with small size") - } - - brw.Writer.Reset(&wh) - brw.WriteByte(0) - brw.Flush() - if &c.writeBuf[0] != &wh.p[0] { - t.Error("connection used bufio.Writer with small size") - } - -} diff --git a/json_test.go b/json_test.go index 61100e4..e4c4bdf 100644 --- a/json_test.go +++ b/json_test.go @@ -14,9 +14,8 @@ import ( func TestJSON(t *testing.T) { var buf bytes.Buffer - c := fakeNetConn{&buf, &buf} - wc := newConn(c, true, 1024, 1024) - rc := newConn(c, false, 1024, 1024) + wc := newTestConn(nil, &buf, true) + rc := newTestConn(&buf, nil, false) var actual, expect struct { A int @@ -39,10 +38,9 @@ func TestJSON(t *testing.T) { } func TestPartialJSONRead(t *testing.T) { - var buf bytes.Buffer - c := fakeNetConn{&buf, &buf} - wc := newConn(c, true, 1024, 1024) - rc := newConn(c, false, 1024, 1024) + var buf0, buf1 bytes.Buffer + wc := newTestConn(nil, &buf0, true) + rc := newTestConn(&buf0, &buf1, false) var v struct { A int @@ -94,9 +92,8 @@ func TestPartialJSONRead(t *testing.T) { func TestDeprecatedJSON(t *testing.T) { var buf bytes.Buffer - c := fakeNetConn{&buf, &buf} - wc := newConn(c, true, 1024, 1024) - rc := newConn(c, false, 1024, 1024) + wc := newTestConn(nil, &buf, true) + rc := newTestConn(&buf, nil, false) var actual, expect struct { A int diff --git a/prepared_test.go b/prepared_test.go index cf98c6c..2297802 100644 --- a/prepared_test.go +++ b/prepared_test.go @@ -36,7 +36,7 @@ func TestPreparedMessage(t *testing.T) { for _, tt := range preparedMessageTests { var data = []byte("this is a test") var buf bytes.Buffer - c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024) + c := newTestConn(nil, &buf, tt.isServer) if tt.enableWriteCompression { c.newCompressionWriter = compressNoContextTakeover } diff --git a/server.go b/server.go index 4834c38..3bd627b 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ package websocket import ( "bufio" "errors" + "io" "net" "net/http" "net/url" @@ -33,6 +34,17 @@ type Upgrader struct { // or received. ReadBufferSize, WriteBufferSize int + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + // Subprotocols specifies the server's supported protocols in order of // preference. If this field is not nil, then the Upgrade method negotiates a // subprotocol by selecting the first match in this list with a protocol @@ -179,7 +191,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade return nil, errors.New("websocket: client sent data before handshake is complete") } - c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw) + var br *bufio.Reader + if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 { + // Reuse hijacked buffered reader as connection reader. + br = brw.Reader + } + + buf := bufioWriterBuffer(netConn, brw.Writer) + + var writeBuf []byte + if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { + // Reuse hijacked write buffer as connection buffer. + writeBuf = buf + } + + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf) c.subprotocol = subprotocol if compress { @@ -187,7 +213,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade c.newDecompressionReader = decompressNoContextTakeover } - p := c.writeBuf[:0] + // Use larger of hijacked buffer and connection write buffer for header. + p := buf + if len(c.writeBuf) > len(p) { + p = c.writeBuf + } + p = p[:0] + p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, computeAcceptKey(challengeKey)...) p = append(p, "\r\n"...) @@ -298,3 +330,40 @@ func IsWebSocketUpgrade(r *http.Request) bool { return tokenListContainsValue(r.Header, "Connection", "upgrade") && tokenListContainsValue(r.Header, "Upgrade", "websocket") } + +// bufioReader size returns the size of a bufio.Reader. +func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int { + // This code assumes that peek on a reset reader returns + // bufio.Reader.buf[:0]. + // TODO: Use bufio.Reader.Size() after Go 1.10 + br.Reset(originalReader) + if p, err := br.Peek(0); err == nil { + return cap(p) + } + return 0 +} + +// writeHook is an io.Writer that records the last slice passed to it vio +// io.Writer.Write. +type writeHook struct { + p []byte +} + +func (wh *writeHook) Write(p []byte) (int, error) { + wh.p = p + return len(p), nil +} + +// bufioWriterBuffer grabs the buffer from a bufio.Writer. +func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte { + // This code assumes that bufio.Writer.buf[:1] is passed to the + // bufio.Writer's underlying writer. + var wh writeHook + bw.Reset(&wh) + bw.WriteByte(0) + bw.Flush() + + bw.Reset(originalWriter) + + return wh.p[:cap(wh.p)] +} diff --git a/server_test.go b/server_test.go index c43dbb2..0d64075 100644 --- a/server_test.go +++ b/server_test.go @@ -5,8 +5,12 @@ package websocket import ( + "bufio" + "bytes" + "net" "net/http" "reflect" + "strings" "testing" ) @@ -67,3 +71,49 @@ func TestCheckSameOrigin(t *testing.T) { } } } + +type reuseTestResponseWriter struct { + brw *bufio.ReadWriter + http.ResponseWriter +} + +func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil +} + +var bufioReuseTests = []struct { + n int + reuse bool +}{ + {4096, true}, + {128, false}, +} + +func TestBufioReuse(t *testing.T) { + for i, tt := range bufioReuseTests { + br := bufio.NewReaderSize(strings.NewReader(""), tt.n) + bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n) + resp := &reuseTestResponseWriter{ + brw: bufio.NewReadWriter(br, bw), + } + upgrader := Upgrader{} + c, err := upgrader.Upgrade(resp, &http.Request{ + Method: "GET", + Header: http.Header{ + "Upgrade": []string{"websocket"}, + "Connection": []string{"upgrade"}, + "Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="}, + "Sec-Websocket-Version": []string{"13"}, + }}, nil) + if err != nil { + t.Fatal(err) + } + if reuse := c.br == br; reuse != tt.reuse { + t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) + } + writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw) + if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { + t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) + } + } +}