Add write buffer pooling

Add WriteBufferPool to Dialer and Upgrader. This field specifies a pool
to use for write operations on a connection.  Use of the pool can reduce
memory use when there is a modest write volume over a large number of
connections.

Use larger of hijacked buffer and buffer allocated for connection (if
any) as buffer for building handshake response. This decreases possible
allocations when building the handshake response.

Modify bufio reuse test to call Upgrade instead of the internal
newConnBRW. Move the test from conn_test.go to server_test.go because
it's a serer test.

Update newConn and newConnBRW:

- Move the bufio "hacks" from newConnBRW to separate functions and call
these functions directly from Upgrade.
- Rename newConn to newTestConn and move to conn_test.go. Shorten
argument list to common use case.
- Rename newConnBRW to newConn.
- Add pool code to newConn.
This commit is contained in:
Steven Scott 2018-08-17 19:50:34 -07:00 committed by Gary Burd
parent 5fb94172f4
commit b378caee5b
9 changed files with 328 additions and 117 deletions

View File

@ -69,6 +69,17 @@ type Dialer struct {
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
@ -277,7 +288,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
}
}
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
if err := req.Write(netConn); err != nil {
return nil, nil, err

View File

@ -43,7 +43,7 @@ func textMessages(num int) [][]byte {
func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
c := newTestConn(nil, w, false)
messages := textMessages(100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -54,7 +54,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
c := newTestConn(nil, w, false)
messages := textMessages(100)
c.enableWriteCompression = true
c.newCompressionWriter = compressNoContextTakeover
@ -66,7 +66,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
}
func TestValidCompressionLevel(t *testing.T) {
c := newConn(fakeNetConn{}, false, 1024, 1024)
c := newTestConn(nil, nil, false)
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
if err := c.SetCompressionLevel(level); err == nil {
t.Errorf("no error for level %d", level)

89
conn.go
View File

@ -223,6 +223,20 @@ func isValidReceivedCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
}
// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
// interface. The type of the value stored in a pool is not specified.
type BufferPool interface {
// Get gets a value from the pool or returns nil if the pool is empty.
Get() interface{}
// Put adds a value to the pool.
Put(interface{})
}
// writePoolData is the type added to the write buffer pool. This wrapper is
// used to prevent applications from peeking at and depending on the values
// added to the pool.
type writePoolData struct{ buf []byte }
// The Conn type represents a WebSocket connection.
type Conn struct {
conn net.Conn
@ -232,6 +246,8 @@ type Conn struct {
// Write fields
mu chan bool // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer.
writePool BufferPool
writeBufSize int
writeDeadline time.Time
writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection
@ -263,64 +279,29 @@ type Conn struct {
newDecompressionReader func(io.Reader) io.ReadCloser
}
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
}
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
mu := make(chan bool, 1)
mu <- true
var br *bufio.Reader
if readBufferSize == 0 && brw != nil && brw.Reader != nil {
// Reuse the supplied bufio.Reader if the buffer has a useful size.
// This code assumes that peek on a reader returns
// bufio.Reader.buf[:0].
brw.Reader.Reset(conn)
if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 {
br = brw.Reader
}
}
if br == nil {
if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize
}
if readBufferSize < maxControlFramePayloadSize {
} else if readBufferSize < maxControlFramePayloadSize {
// must be large enough for control frame
readBufferSize = maxControlFramePayloadSize
}
br = bufio.NewReaderSize(conn, readBufferSize)
}
var writeBuf []byte
if writeBufferSize == 0 && brw != nil && brw.Writer != nil {
// Use the bufio.Writer's buffer if the buffer has a useful size. This
// code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
brw.Writer.Reset(&wh)
brw.Writer.WriteByte(0)
brw.Flush()
if cap(wh.p) >= maxFrameHeaderSize+256 {
writeBuf = wh.p[:cap(wh.p)]
}
}
if writeBuf == nil {
if writeBufferSize == 0 {
if writeBufferSize <= 0 {
writeBufferSize = defaultWriteBufferSize
}
writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize)
writeBufferSize += maxFrameHeaderSize
if writeBuf == nil && writeBufferPool == nil {
writeBuf = make([]byte, writeBufferSize)
}
mu := make(chan bool, 1)
mu <- true
c := &Conn{
isServer: isServer,
br: br,
@ -328,6 +309,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
mu: mu,
readFinal: true,
writeBuf: writeBuf,
writePool: writeBufferPool,
writeBufSize: writeBufferSize,
enableWriteCompression: true,
compressionLevel: defaultCompressionLevel,
}
@ -484,7 +467,19 @@ func (c *Conn) prepWrite(messageType int) error {
c.writeErrMu.Lock()
err := c.writeErr
c.writeErrMu.Unlock()
if err != nil {
return err
}
if c.writeBuf == nil {
wpd, ok := c.writePool.Get().(writePoolData)
if ok {
c.writeBuf = wpd.buf
} else {
c.writeBuf = make([]byte, c.writeBufSize)
}
}
return nil
}
// NextWriter returns a writer for the next message to send. The writer's Close
@ -610,6 +605,10 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
if final {
c.writer = nil
if c.writePool != nil {
c.writePool.Put(writePoolData{buf: c.writeBuf})
c.writeBuf = nil
}
return nil
}

View File

@ -70,7 +70,7 @@ func (b *broadcastBench) makeConns(numConns int) {
conns := make([]*broadcastConn, numConns)
for i := 0; i < numConns; i++ {
c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024)
c := newTestConn(nil, b.w, true)
if b.compression {
c.enableWriteCompression = true
c.newCompressionWriter = compressNoContextTakeover

View File

@ -13,6 +13,7 @@ import (
"io/ioutil"
"net"
"reflect"
"sync"
"testing"
"testing/iotest"
"time"
@ -47,6 +48,12 @@ func (a fakeAddr) String() string {
return "str"
}
// newTestConn creates a connnection backed by a fake network connection using
// default values for buffering.
func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
}
func TestFraming(t *testing.T) {
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
var readChunkers = []struct {
@ -82,8 +89,8 @@ func TestFraming(t *testing.T) {
for _, chunker := range readChunkers {
var connBuf bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
wc := newTestConn(nil, &connBuf, isServer)
rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
if compress {
wc.newCompressionWriter = compressNoContextTakeover
rc.newDecompressionReader = decompressNoContextTakeover
@ -143,8 +150,8 @@ func TestControl(t *testing.T) {
for _, isWriteControl := range []bool{true, false} {
name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
var connBuf bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
wc := newTestConn(nil, &connBuf, isServer)
rc := newTestConn(&connBuf, nil, !isServer)
if isWriteControl {
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
} else {
@ -173,14 +180,124 @@ func TestControl(t *testing.T) {
}
}
// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
type simpleBufferPool struct {
v interface{}
}
func (p *simpleBufferPool) Get() interface{} {
v := p.v
p.v = nil
return v
}
func (p *simpleBufferPool) Put(v interface{}) {
p.v = v
}
func TestWriteBufferPool(t *testing.T) {
var buf bytes.Buffer
var pool simpleBufferPool
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
rc := newTestConn(&buf, nil, false)
if wc.writeBuf != nil {
t.Fatal("writeBuf not nil after create")
}
// Part 1: test NextWriter/Write/Close
w, err := wc.NextWriter(TextMessage)
if err != nil {
t.Fatalf("wc.NextWriter() returned %v", err)
}
if wc.writeBuf == nil {
t.Fatal("writeBuf is nil after NextWriter")
}
writeBufAddr := &wc.writeBuf[0]
const message = "Hello World!"
if _, err := io.WriteString(w, message); err != nil {
t.Fatalf("io.WriteString(w, message) returned %v", err)
}
if err := w.Close(); err != nil {
t.Fatalf("w.Close() returned %v", err)
}
if wc.writeBuf != nil {
t.Fatal("writeBuf not nil after w.Close()")
}
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
t.Fatal("writeBuf not returned to pool")
}
opCode, p, err := rc.ReadMessage()
if opCode != TextMessage || err != nil {
t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
}
if s := string(p); s != message {
t.Fatalf("message is %s, want %s", s, message)
}
// Part 2: Test WriteMessage.
if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
t.Fatalf("wc.WriteMessage() returned %v", err)
}
if wc.writeBuf != nil {
t.Fatal("writeBuf not nil after wc.WriteMessage()")
}
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
t.Fatal("writeBuf not returned to pool after WriteMessage")
}
opCode, p, err = rc.ReadMessage()
if opCode != TextMessage || err != nil {
t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
}
if s := string(p); s != message {
t.Fatalf("message is %s, want %s", s, message)
}
}
func TestWriteBufferPoolSync(t *testing.T) {
var buf bytes.Buffer
var pool sync.Pool
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
rc := newTestConn(&buf, nil, false)
const message = "Hello World!"
for i := 0; i < 3; i++ {
if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
t.Fatalf("wc.WriteMessage() returned %v", err)
}
opCode, p, err := rc.ReadMessage()
if opCode != TextMessage || err != nil {
t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
}
if s := string(p); s != message {
t.Fatalf("message is %s, want %s", s, message)
}
}
}
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
const bufSize = 512
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
rc := newTestConn(&b1, &b2, true)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2))
@ -206,8 +323,8 @@ func TestEOFWithinFrame(t *testing.T) {
for n := 0; ; n++ {
var b bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
wc := newTestConn(nil, &b, false)
rc := newTestConn(&b, nil, true)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize))
@ -240,8 +357,8 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
const bufSize = 512
var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
rc := newTestConn(&b1, &b2, true)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2))
@ -261,7 +378,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
}
func TestWriteAfterMessageWriterClose(t *testing.T) {
wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
wc := newTestConn(nil, &bytes.Buffer{}, false)
w, _ := wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello")
if err := w.Close(); err != nil {
@ -292,8 +409,8 @@ func TestReadLimit(t *testing.T) {
message := make([]byte, readLimit+1)
var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
rc := newTestConn(&b1, &b2, true)
rc.SetReadLimit(readLimit)
// Send message at the limit with interleaved pong.
@ -321,7 +438,7 @@ func TestReadLimit(t *testing.T) {
}
func TestAddrs(t *testing.T) {
c := newConn(&fakeNetConn{}, true, 1024, 1024)
c := newTestConn(nil, nil, true)
if c.LocalAddr() != localAddr {
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
}
@ -333,7 +450,7 @@ func TestAddrs(t *testing.T) {
func TestUnderlyingConn(t *testing.T) {
var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2}
c := newConn(fc, true, 1024, 1024)
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
ul := c.UnderlyingConn()
if ul != fc {
t.Fatalf("Underlying conn is not what it should be.")
@ -347,8 +464,8 @@ func TestBufioReadBytes(t *testing.T) {
m[len(m)-1] = '\n'
var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(m)
@ -423,7 +540,7 @@ func (w blockingWriter) Write(p []byte) (int, error) {
func TestConcurrentWritePanic(t *testing.T) {
w := blockingWriter{make(chan struct{}), make(chan struct{})}
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
c := newTestConn(nil, w, false)
go func() {
c.WriteMessage(TextMessage, []byte{})
}()
@ -449,7 +566,7 @@ func (r failingReader) Read(p []byte) (int, error) {
}
func TestFailedConnectionReadPanic(t *testing.T) {
c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
c := newTestConn(failingReader{}, nil, false)
defer func() {
if v := recover(); v != nil {
@ -462,35 +579,3 @@ func TestFailedConnectionReadPanic(t *testing.T) {
}
t.Fatal("should not get here")
}
func TestBufioReuse(t *testing.T) {
brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil))
c := newConnBRW(nil, false, 0, 0, brw)
if c.br != brw.Reader {
t.Error("connection did not reuse bufio.Reader")
}
var wh writeHook
brw.Writer.Reset(&wh)
brw.WriteByte(0)
brw.Flush()
if &c.writeBuf[0] != &wh.p[0] {
t.Error("connection did not reuse bufio.Writer")
}
brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0))
c = newConnBRW(nil, false, 0, 0, brw)
if c.br == brw.Reader {
t.Error("connection used bufio.Reader with small size")
}
brw.Writer.Reset(&wh)
brw.WriteByte(0)
brw.Flush()
if &c.writeBuf[0] != &wh.p[0] {
t.Error("connection used bufio.Writer with small size")
}
}

View File

@ -14,9 +14,8 @@ import (
func TestJSON(t *testing.T) {
var buf bytes.Buffer
c := fakeNetConn{&buf, &buf}
wc := newConn(c, true, 1024, 1024)
rc := newConn(c, false, 1024, 1024)
wc := newTestConn(nil, &buf, true)
rc := newTestConn(&buf, nil, false)
var actual, expect struct {
A int
@ -39,10 +38,9 @@ func TestJSON(t *testing.T) {
}
func TestPartialJSONRead(t *testing.T) {
var buf bytes.Buffer
c := fakeNetConn{&buf, &buf}
wc := newConn(c, true, 1024, 1024)
rc := newConn(c, false, 1024, 1024)
var buf0, buf1 bytes.Buffer
wc := newTestConn(nil, &buf0, true)
rc := newTestConn(&buf0, &buf1, false)
var v struct {
A int
@ -94,9 +92,8 @@ func TestPartialJSONRead(t *testing.T) {
func TestDeprecatedJSON(t *testing.T) {
var buf bytes.Buffer
c := fakeNetConn{&buf, &buf}
wc := newConn(c, true, 1024, 1024)
rc := newConn(c, false, 1024, 1024)
wc := newTestConn(nil, &buf, true)
rc := newTestConn(&buf, nil, false)
var actual, expect struct {
A int

View File

@ -36,7 +36,7 @@ func TestPreparedMessage(t *testing.T) {
for _, tt := range preparedMessageTests {
var data = []byte("this is a test")
var buf bytes.Buffer
c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024)
c := newTestConn(nil, &buf, tt.isServer)
if tt.enableWriteCompression {
c.newCompressionWriter = compressNoContextTakeover
}

View File

@ -7,6 +7,7 @@ package websocket
import (
"bufio"
"errors"
"io"
"net"
"net/http"
"net/url"
@ -33,6 +34,17 @@ type Upgrader struct {
// or received.
ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is not nil, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
@ -179,7 +191,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return nil, errors.New("websocket: client sent data before handshake is complete")
}
c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
var br *bufio.Reader
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
// Reuse hijacked buffered reader as connection reader.
br = brw.Reader
}
buf := bufioWriterBuffer(netConn, brw.Writer)
var writeBuf []byte
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
// Reuse hijacked write buffer as connection buffer.
writeBuf = buf
}
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
c.subprotocol = subprotocol
if compress {
@ -187,7 +213,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c.newDecompressionReader = decompressNoContextTakeover
}
p := c.writeBuf[:0]
// Use larger of hijacked buffer and connection write buffer for header.
p := buf
if len(c.writeBuf) > len(p) {
p = c.writeBuf
}
p = p[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...)
@ -298,3 +330,40 @@ func IsWebSocketUpgrade(r *http.Request) bool {
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
tokenListContainsValue(r.Header, "Upgrade", "websocket")
}
// bufioReader size returns the size of a bufio.Reader.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
// This code assumes that peek on a reset reader returns
// bufio.Reader.buf[:0].
// TODO: Use bufio.Reader.Size() after Go 1.10
br.Reset(originalReader)
if p, err := br.Peek(0); err == nil {
return cap(p)
}
return 0
}
// writeHook is an io.Writer that records the last slice passed to it vio
// io.Writer.Write.
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// This code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
bw.Reset(&wh)
bw.WriteByte(0)
bw.Flush()
bw.Reset(originalWriter)
return wh.p[:cap(wh.p)]
}

View File

@ -5,8 +5,12 @@
package websocket
import (
"bufio"
"bytes"
"net"
"net/http"
"reflect"
"strings"
"testing"
)
@ -67,3 +71,49 @@ func TestCheckSameOrigin(t *testing.T) {
}
}
}
type reuseTestResponseWriter struct {
brw *bufio.ReadWriter
http.ResponseWriter
}
func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil
}
var bufioReuseTests = []struct {
n int
reuse bool
}{
{4096, true},
{128, false},
}
func TestBufioReuse(t *testing.T) {
for i, tt := range bufioReuseTests {
br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
resp := &reuseTestResponseWriter{
brw: bufio.NewReadWriter(br, bw),
}
upgrader := Upgrader{}
c, err := upgrader.Upgrade(resp, &http.Request{
Method: "GET",
Header: http.Header{
"Upgrade": []string{"websocket"},
"Connection": []string{"upgrade"},
"Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="},
"Sec-Websocket-Version": []string{"13"},
}}, nil)
if err != nil {
t.Fatal(err)
}
if reuse := c.br == br; reuse != tt.reuse {
t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
}
writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw)
if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
}
}
}