mirror of https://github.com/gorilla/websocket.git
Add write buffer pooling
Add WriteBufferPool to Dialer and Upgrader. This field specifies a pool to use for write operations on a connection. Use of the pool can reduce memory use when there is a modest write volume over a large number of connections. Use larger of hijacked buffer and buffer allocated for connection (if any) as buffer for building handshake response. This decreases possible allocations when building the handshake response. Modify bufio reuse test to call Upgrade instead of the internal newConnBRW. Move the test from conn_test.go to server_test.go because it's a serer test. Update newConn and newConnBRW: - Move the bufio "hacks" from newConnBRW to separate functions and call these functions directly from Upgrade. - Rename newConn to newTestConn and move to conn_test.go. Shorten argument list to common use case. - Rename newConnBRW to newConn. - Add pool code to newConn.
This commit is contained in:
parent
5fb94172f4
commit
b378caee5b
13
client.go
13
client.go
|
@ -69,6 +69,17 @@ type Dialer struct {
|
|||
// do not limit the size of the messages that can be sent or received.
|
||||
ReadBufferSize, WriteBufferSize int
|
||||
|
||||
// WriteBufferPool is a pool of buffers for write operations. If the value
|
||||
// is not set, then write buffers are allocated to the connection for the
|
||||
// lifetime of the connection.
|
||||
//
|
||||
// A pool is most useful when the application has a modest volume of writes
|
||||
// across a large number of connections.
|
||||
//
|
||||
// Applications should use a single pool for each unique value of
|
||||
// WriteBufferSize.
|
||||
WriteBufferPool BufferPool
|
||||
|
||||
// Subprotocols specifies the client's requested subprotocols.
|
||||
Subprotocols []string
|
||||
|
||||
|
@ -277,7 +288,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
}
|
||||
}
|
||||
|
||||
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
|
||||
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
|
||||
|
||||
if err := req.Write(netConn); err != nil {
|
||||
return nil, nil, err
|
||||
|
|
|
@ -43,7 +43,7 @@ func textMessages(num int) [][]byte {
|
|||
|
||||
func BenchmarkWriteNoCompression(b *testing.B) {
|
||||
w := ioutil.Discard
|
||||
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
||||
c := newTestConn(nil, w, false)
|
||||
messages := textMessages(100)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
@ -54,7 +54,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
|
|||
|
||||
func BenchmarkWriteWithCompression(b *testing.B) {
|
||||
w := ioutil.Discard
|
||||
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
||||
c := newTestConn(nil, w, false)
|
||||
messages := textMessages(100)
|
||||
c.enableWriteCompression = true
|
||||
c.newCompressionWriter = compressNoContextTakeover
|
||||
|
@ -66,7 +66,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
|
|||
}
|
||||
|
||||
func TestValidCompressionLevel(t *testing.T) {
|
||||
c := newConn(fakeNetConn{}, false, 1024, 1024)
|
||||
c := newTestConn(nil, nil, false)
|
||||
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
|
||||
if err := c.SetCompressionLevel(level); err == nil {
|
||||
t.Errorf("no error for level %d", level)
|
||||
|
|
89
conn.go
89
conn.go
|
@ -223,6 +223,20 @@ func isValidReceivedCloseCode(code int) bool {
|
|||
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
|
||||
}
|
||||
|
||||
// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
|
||||
// interface. The type of the value stored in a pool is not specified.
|
||||
type BufferPool interface {
|
||||
// Get gets a value from the pool or returns nil if the pool is empty.
|
||||
Get() interface{}
|
||||
// Put adds a value to the pool.
|
||||
Put(interface{})
|
||||
}
|
||||
|
||||
// writePoolData is the type added to the write buffer pool. This wrapper is
|
||||
// used to prevent applications from peeking at and depending on the values
|
||||
// added to the pool.
|
||||
type writePoolData struct{ buf []byte }
|
||||
|
||||
// The Conn type represents a WebSocket connection.
|
||||
type Conn struct {
|
||||
conn net.Conn
|
||||
|
@ -232,6 +246,8 @@ type Conn struct {
|
|||
// Write fields
|
||||
mu chan bool // used as mutex to protect write to conn
|
||||
writeBuf []byte // frame is constructed in this buffer.
|
||||
writePool BufferPool
|
||||
writeBufSize int
|
||||
writeDeadline time.Time
|
||||
writer io.WriteCloser // the current writer returned to the application
|
||||
isWriting bool // for best-effort concurrent write detection
|
||||
|
@ -263,64 +279,29 @@ type Conn struct {
|
|||
newDecompressionReader func(io.Reader) io.ReadCloser
|
||||
}
|
||||
|
||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
||||
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
|
||||
}
|
||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
|
||||
|
||||
type writeHook struct {
|
||||
p []byte
|
||||
}
|
||||
|
||||
func (wh *writeHook) Write(p []byte) (int, error) {
|
||||
wh.p = p
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
|
||||
mu := make(chan bool, 1)
|
||||
mu <- true
|
||||
|
||||
var br *bufio.Reader
|
||||
if readBufferSize == 0 && brw != nil && brw.Reader != nil {
|
||||
// Reuse the supplied bufio.Reader if the buffer has a useful size.
|
||||
// This code assumes that peek on a reader returns
|
||||
// bufio.Reader.buf[:0].
|
||||
brw.Reader.Reset(conn)
|
||||
if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 {
|
||||
br = brw.Reader
|
||||
}
|
||||
}
|
||||
if br == nil {
|
||||
if readBufferSize == 0 {
|
||||
readBufferSize = defaultReadBufferSize
|
||||
}
|
||||
if readBufferSize < maxControlFramePayloadSize {
|
||||
} else if readBufferSize < maxControlFramePayloadSize {
|
||||
// must be large enough for control frame
|
||||
readBufferSize = maxControlFramePayloadSize
|
||||
}
|
||||
br = bufio.NewReaderSize(conn, readBufferSize)
|
||||
}
|
||||
|
||||
var writeBuf []byte
|
||||
if writeBufferSize == 0 && brw != nil && brw.Writer != nil {
|
||||
// Use the bufio.Writer's buffer if the buffer has a useful size. This
|
||||
// code assumes that bufio.Writer.buf[:1] is passed to the
|
||||
// bufio.Writer's underlying writer.
|
||||
var wh writeHook
|
||||
brw.Writer.Reset(&wh)
|
||||
brw.Writer.WriteByte(0)
|
||||
brw.Flush()
|
||||
if cap(wh.p) >= maxFrameHeaderSize+256 {
|
||||
writeBuf = wh.p[:cap(wh.p)]
|
||||
}
|
||||
}
|
||||
|
||||
if writeBuf == nil {
|
||||
if writeBufferSize == 0 {
|
||||
if writeBufferSize <= 0 {
|
||||
writeBufferSize = defaultWriteBufferSize
|
||||
}
|
||||
writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize)
|
||||
writeBufferSize += maxFrameHeaderSize
|
||||
|
||||
if writeBuf == nil && writeBufferPool == nil {
|
||||
writeBuf = make([]byte, writeBufferSize)
|
||||
}
|
||||
|
||||
mu := make(chan bool, 1)
|
||||
mu <- true
|
||||
c := &Conn{
|
||||
isServer: isServer,
|
||||
br: br,
|
||||
|
@ -328,6 +309,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
|
|||
mu: mu,
|
||||
readFinal: true,
|
||||
writeBuf: writeBuf,
|
||||
writePool: writeBufferPool,
|
||||
writeBufSize: writeBufferSize,
|
||||
enableWriteCompression: true,
|
||||
compressionLevel: defaultCompressionLevel,
|
||||
}
|
||||
|
@ -484,7 +467,19 @@ func (c *Conn) prepWrite(messageType int) error {
|
|||
c.writeErrMu.Lock()
|
||||
err := c.writeErr
|
||||
c.writeErrMu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.writeBuf == nil {
|
||||
wpd, ok := c.writePool.Get().(writePoolData)
|
||||
if ok {
|
||||
c.writeBuf = wpd.buf
|
||||
} else {
|
||||
c.writeBuf = make([]byte, c.writeBufSize)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextWriter returns a writer for the next message to send. The writer's Close
|
||||
|
@ -610,6 +605,10 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
|||
|
||||
if final {
|
||||
c.writer = nil
|
||||
if c.writePool != nil {
|
||||
c.writePool.Put(writePoolData{buf: c.writeBuf})
|
||||
c.writeBuf = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ func (b *broadcastBench) makeConns(numConns int) {
|
|||
conns := make([]*broadcastConn, numConns)
|
||||
|
||||
for i := 0; i < numConns; i++ {
|
||||
c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024)
|
||||
c := newTestConn(nil, b.w, true)
|
||||
if b.compression {
|
||||
c.enableWriteCompression = true
|
||||
c.newCompressionWriter = compressNoContextTakeover
|
||||
|
|
187
conn_test.go
187
conn_test.go
|
@ -13,6 +13,7 @@ import (
|
|||
"io/ioutil"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
"time"
|
||||
|
@ -47,6 +48,12 @@ func (a fakeAddr) String() string {
|
|||
return "str"
|
||||
}
|
||||
|
||||
// newTestConn creates a connnection backed by a fake network connection using
|
||||
// default values for buffering.
|
||||
func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
|
||||
return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
|
||||
}
|
||||
|
||||
func TestFraming(t *testing.T) {
|
||||
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
|
||||
var readChunkers = []struct {
|
||||
|
@ -82,8 +89,8 @@ func TestFraming(t *testing.T) {
|
|||
for _, chunker := range readChunkers {
|
||||
|
||||
var connBuf bytes.Buffer
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
|
||||
rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
|
||||
wc := newTestConn(nil, &connBuf, isServer)
|
||||
rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
|
||||
if compress {
|
||||
wc.newCompressionWriter = compressNoContextTakeover
|
||||
rc.newDecompressionReader = decompressNoContextTakeover
|
||||
|
@ -143,8 +150,8 @@ func TestControl(t *testing.T) {
|
|||
for _, isWriteControl := range []bool{true, false} {
|
||||
name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
|
||||
var connBuf bytes.Buffer
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
|
||||
rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
|
||||
wc := newTestConn(nil, &connBuf, isServer)
|
||||
rc := newTestConn(&connBuf, nil, !isServer)
|
||||
if isWriteControl {
|
||||
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
|
||||
} else {
|
||||
|
@ -173,14 +180,124 @@ func TestControl(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
|
||||
type simpleBufferPool struct {
|
||||
v interface{}
|
||||
}
|
||||
|
||||
func (p *simpleBufferPool) Get() interface{} {
|
||||
v := p.v
|
||||
p.v = nil
|
||||
return v
|
||||
}
|
||||
|
||||
func (p *simpleBufferPool) Put(v interface{}) {
|
||||
p.v = v
|
||||
}
|
||||
|
||||
func TestWriteBufferPool(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
var pool simpleBufferPool
|
||||
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
|
||||
rc := newTestConn(&buf, nil, false)
|
||||
|
||||
if wc.writeBuf != nil {
|
||||
t.Fatal("writeBuf not nil after create")
|
||||
}
|
||||
|
||||
// Part 1: test NextWriter/Write/Close
|
||||
|
||||
w, err := wc.NextWriter(TextMessage)
|
||||
if err != nil {
|
||||
t.Fatalf("wc.NextWriter() returned %v", err)
|
||||
}
|
||||
|
||||
if wc.writeBuf == nil {
|
||||
t.Fatal("writeBuf is nil after NextWriter")
|
||||
}
|
||||
|
||||
writeBufAddr := &wc.writeBuf[0]
|
||||
|
||||
const message = "Hello World!"
|
||||
|
||||
if _, err := io.WriteString(w, message); err != nil {
|
||||
t.Fatalf("io.WriteString(w, message) returned %v", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("w.Close() returned %v", err)
|
||||
}
|
||||
|
||||
if wc.writeBuf != nil {
|
||||
t.Fatal("writeBuf not nil after w.Close()")
|
||||
}
|
||||
|
||||
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
|
||||
t.Fatal("writeBuf not returned to pool")
|
||||
}
|
||||
|
||||
opCode, p, err := rc.ReadMessage()
|
||||
if opCode != TextMessage || err != nil {
|
||||
t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
|
||||
}
|
||||
|
||||
if s := string(p); s != message {
|
||||
t.Fatalf("message is %s, want %s", s, message)
|
||||
}
|
||||
|
||||
// Part 2: Test WriteMessage.
|
||||
|
||||
if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
|
||||
t.Fatalf("wc.WriteMessage() returned %v", err)
|
||||
}
|
||||
|
||||
if wc.writeBuf != nil {
|
||||
t.Fatal("writeBuf not nil after wc.WriteMessage()")
|
||||
}
|
||||
|
||||
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
|
||||
t.Fatal("writeBuf not returned to pool after WriteMessage")
|
||||
}
|
||||
|
||||
opCode, p, err = rc.ReadMessage()
|
||||
if opCode != TextMessage || err != nil {
|
||||
t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
|
||||
}
|
||||
|
||||
if s := string(p); s != message {
|
||||
t.Fatalf("message is %s, want %s", s, message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteBufferPoolSync(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
var pool sync.Pool
|
||||
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
|
||||
rc := newTestConn(&buf, nil, false)
|
||||
|
||||
const message = "Hello World!"
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
|
||||
t.Fatalf("wc.WriteMessage() returned %v", err)
|
||||
}
|
||||
opCode, p, err := rc.ReadMessage()
|
||||
if opCode != TextMessage || err != nil {
|
||||
t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
|
||||
}
|
||||
if s := string(p); s != message {
|
||||
t.Fatalf("message is %s, want %s", s, message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||
const bufSize = 512
|
||||
|
||||
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
|
||||
|
||||
var b1, b2 bytes.Buffer
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
|
||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
||||
wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
|
||||
rc := newTestConn(&b1, &b2, true)
|
||||
|
||||
w, _ := wc.NextWriter(BinaryMessage)
|
||||
w.Write(make([]byte, bufSize+bufSize/2))
|
||||
|
@ -206,8 +323,8 @@ func TestEOFWithinFrame(t *testing.T) {
|
|||
|
||||
for n := 0; ; n++ {
|
||||
var b bytes.Buffer
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
|
||||
rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
|
||||
wc := newTestConn(nil, &b, false)
|
||||
rc := newTestConn(&b, nil, true)
|
||||
|
||||
w, _ := wc.NextWriter(BinaryMessage)
|
||||
w.Write(make([]byte, bufSize))
|
||||
|
@ -240,8 +357,8 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
|||
const bufSize = 512
|
||||
|
||||
var b1, b2 bytes.Buffer
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
|
||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
||||
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
|
||||
rc := newTestConn(&b1, &b2, true)
|
||||
|
||||
w, _ := wc.NextWriter(BinaryMessage)
|
||||
w.Write(make([]byte, bufSize+bufSize/2))
|
||||
|
@ -261,7 +378,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
|
||||
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
||||
w, _ := wc.NextWriter(BinaryMessage)
|
||||
io.WriteString(w, "hello")
|
||||
if err := w.Close(); err != nil {
|
||||
|
@ -292,8 +409,8 @@ func TestReadLimit(t *testing.T) {
|
|||
message := make([]byte, readLimit+1)
|
||||
|
||||
var b1, b2 bytes.Buffer
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
|
||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
||||
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
|
||||
rc := newTestConn(&b1, &b2, true)
|
||||
rc.SetReadLimit(readLimit)
|
||||
|
||||
// Send message at the limit with interleaved pong.
|
||||
|
@ -321,7 +438,7 @@ func TestReadLimit(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAddrs(t *testing.T) {
|
||||
c := newConn(&fakeNetConn{}, true, 1024, 1024)
|
||||
c := newTestConn(nil, nil, true)
|
||||
if c.LocalAddr() != localAddr {
|
||||
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
|
||||
}
|
||||
|
@ -333,7 +450,7 @@ func TestAddrs(t *testing.T) {
|
|||
func TestUnderlyingConn(t *testing.T) {
|
||||
var b1, b2 bytes.Buffer
|
||||
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
||||
c := newConn(fc, true, 1024, 1024)
|
||||
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
||||
ul := c.UnderlyingConn()
|
||||
if ul != fc {
|
||||
t.Fatalf("Underlying conn is not what it should be.")
|
||||
|
@ -347,8 +464,8 @@ func TestBufioReadBytes(t *testing.T) {
|
|||
m[len(m)-1] = '\n'
|
||||
|
||||
var b1, b2 bytes.Buffer
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
|
||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
|
||||
wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
|
||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
|
||||
|
||||
w, _ := wc.NextWriter(BinaryMessage)
|
||||
w.Write(m)
|
||||
|
@ -423,7 +540,7 @@ func (w blockingWriter) Write(p []byte) (int, error) {
|
|||
|
||||
func TestConcurrentWritePanic(t *testing.T) {
|
||||
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
||||
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
||||
c := newTestConn(nil, w, false)
|
||||
go func() {
|
||||
c.WriteMessage(TextMessage, []byte{})
|
||||
}()
|
||||
|
@ -449,7 +566,7 @@ func (r failingReader) Read(p []byte) (int, error) {
|
|||
}
|
||||
|
||||
func TestFailedConnectionReadPanic(t *testing.T) {
|
||||
c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
|
||||
c := newTestConn(failingReader{}, nil, false)
|
||||
|
||||
defer func() {
|
||||
if v := recover(); v != nil {
|
||||
|
@ -462,35 +579,3 @@ func TestFailedConnectionReadPanic(t *testing.T) {
|
|||
}
|
||||
t.Fatal("should not get here")
|
||||
}
|
||||
|
||||
func TestBufioReuse(t *testing.T) {
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil))
|
||||
c := newConnBRW(nil, false, 0, 0, brw)
|
||||
|
||||
if c.br != brw.Reader {
|
||||
t.Error("connection did not reuse bufio.Reader")
|
||||
}
|
||||
|
||||
var wh writeHook
|
||||
brw.Writer.Reset(&wh)
|
||||
brw.WriteByte(0)
|
||||
brw.Flush()
|
||||
if &c.writeBuf[0] != &wh.p[0] {
|
||||
t.Error("connection did not reuse bufio.Writer")
|
||||
}
|
||||
|
||||
brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0))
|
||||
c = newConnBRW(nil, false, 0, 0, brw)
|
||||
|
||||
if c.br == brw.Reader {
|
||||
t.Error("connection used bufio.Reader with small size")
|
||||
}
|
||||
|
||||
brw.Writer.Reset(&wh)
|
||||
brw.WriteByte(0)
|
||||
brw.Flush()
|
||||
if &c.writeBuf[0] != &wh.p[0] {
|
||||
t.Error("connection used bufio.Writer with small size")
|
||||
}
|
||||
|
||||
}
|
||||
|
|
17
json_test.go
17
json_test.go
|
@ -14,9 +14,8 @@ import (
|
|||
|
||||
func TestJSON(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := fakeNetConn{&buf, &buf}
|
||||
wc := newConn(c, true, 1024, 1024)
|
||||
rc := newConn(c, false, 1024, 1024)
|
||||
wc := newTestConn(nil, &buf, true)
|
||||
rc := newTestConn(&buf, nil, false)
|
||||
|
||||
var actual, expect struct {
|
||||
A int
|
||||
|
@ -39,10 +38,9 @@ func TestJSON(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPartialJSONRead(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := fakeNetConn{&buf, &buf}
|
||||
wc := newConn(c, true, 1024, 1024)
|
||||
rc := newConn(c, false, 1024, 1024)
|
||||
var buf0, buf1 bytes.Buffer
|
||||
wc := newTestConn(nil, &buf0, true)
|
||||
rc := newTestConn(&buf0, &buf1, false)
|
||||
|
||||
var v struct {
|
||||
A int
|
||||
|
@ -94,9 +92,8 @@ func TestPartialJSONRead(t *testing.T) {
|
|||
|
||||
func TestDeprecatedJSON(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := fakeNetConn{&buf, &buf}
|
||||
wc := newConn(c, true, 1024, 1024)
|
||||
rc := newConn(c, false, 1024, 1024)
|
||||
wc := newTestConn(nil, &buf, true)
|
||||
rc := newTestConn(&buf, nil, false)
|
||||
|
||||
var actual, expect struct {
|
||||
A int
|
||||
|
|
|
@ -36,7 +36,7 @@ func TestPreparedMessage(t *testing.T) {
|
|||
for _, tt := range preparedMessageTests {
|
||||
var data = []byte("this is a test")
|
||||
var buf bytes.Buffer
|
||||
c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024)
|
||||
c := newTestConn(nil, &buf, tt.isServer)
|
||||
if tt.enableWriteCompression {
|
||||
c.newCompressionWriter = compressNoContextTakeover
|
||||
}
|
||||
|
|
73
server.go
73
server.go
|
@ -7,6 +7,7 @@ package websocket
|
|||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -33,6 +34,17 @@ type Upgrader struct {
|
|||
// or received.
|
||||
ReadBufferSize, WriteBufferSize int
|
||||
|
||||
// WriteBufferPool is a pool of buffers for write operations. If the value
|
||||
// is not set, then write buffers are allocated to the connection for the
|
||||
// lifetime of the connection.
|
||||
//
|
||||
// A pool is most useful when the application has a modest volume of writes
|
||||
// across a large number of connections.
|
||||
//
|
||||
// Applications should use a single pool for each unique value of
|
||||
// WriteBufferSize.
|
||||
WriteBufferPool BufferPool
|
||||
|
||||
// Subprotocols specifies the server's supported protocols in order of
|
||||
// preference. If this field is not nil, then the Upgrade method negotiates a
|
||||
// subprotocol by selecting the first match in this list with a protocol
|
||||
|
@ -179,7 +191,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|||
return nil, errors.New("websocket: client sent data before handshake is complete")
|
||||
}
|
||||
|
||||
c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
|
||||
var br *bufio.Reader
|
||||
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
|
||||
// Reuse hijacked buffered reader as connection reader.
|
||||
br = brw.Reader
|
||||
}
|
||||
|
||||
buf := bufioWriterBuffer(netConn, brw.Writer)
|
||||
|
||||
var writeBuf []byte
|
||||
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
|
||||
// Reuse hijacked write buffer as connection buffer.
|
||||
writeBuf = buf
|
||||
}
|
||||
|
||||
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
|
||||
c.subprotocol = subprotocol
|
||||
|
||||
if compress {
|
||||
|
@ -187,7 +213,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|||
c.newDecompressionReader = decompressNoContextTakeover
|
||||
}
|
||||
|
||||
p := c.writeBuf[:0]
|
||||
// Use larger of hijacked buffer and connection write buffer for header.
|
||||
p := buf
|
||||
if len(c.writeBuf) > len(p) {
|
||||
p = c.writeBuf
|
||||
}
|
||||
p = p[:0]
|
||||
|
||||
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
|
||||
p = append(p, computeAcceptKey(challengeKey)...)
|
||||
p = append(p, "\r\n"...)
|
||||
|
@ -298,3 +330,40 @@ func IsWebSocketUpgrade(r *http.Request) bool {
|
|||
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
|
||||
tokenListContainsValue(r.Header, "Upgrade", "websocket")
|
||||
}
|
||||
|
||||
// bufioReader size returns the size of a bufio.Reader.
|
||||
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
|
||||
// This code assumes that peek on a reset reader returns
|
||||
// bufio.Reader.buf[:0].
|
||||
// TODO: Use bufio.Reader.Size() after Go 1.10
|
||||
br.Reset(originalReader)
|
||||
if p, err := br.Peek(0); err == nil {
|
||||
return cap(p)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// writeHook is an io.Writer that records the last slice passed to it vio
|
||||
// io.Writer.Write.
|
||||
type writeHook struct {
|
||||
p []byte
|
||||
}
|
||||
|
||||
func (wh *writeHook) Write(p []byte) (int, error) {
|
||||
wh.p = p
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
|
||||
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
|
||||
// This code assumes that bufio.Writer.buf[:1] is passed to the
|
||||
// bufio.Writer's underlying writer.
|
||||
var wh writeHook
|
||||
bw.Reset(&wh)
|
||||
bw.WriteByte(0)
|
||||
bw.Flush()
|
||||
|
||||
bw.Reset(originalWriter)
|
||||
|
||||
return wh.p[:cap(wh.p)]
|
||||
}
|
||||
|
|
|
@ -5,8 +5,12 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -67,3 +71,49 @@ func TestCheckSameOrigin(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
type reuseTestResponseWriter struct {
|
||||
brw *bufio.ReadWriter
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil
|
||||
}
|
||||
|
||||
var bufioReuseTests = []struct {
|
||||
n int
|
||||
reuse bool
|
||||
}{
|
||||
{4096, true},
|
||||
{128, false},
|
||||
}
|
||||
|
||||
func TestBufioReuse(t *testing.T) {
|
||||
for i, tt := range bufioReuseTests {
|
||||
br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
|
||||
bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
|
||||
resp := &reuseTestResponseWriter{
|
||||
brw: bufio.NewReadWriter(br, bw),
|
||||
}
|
||||
upgrader := Upgrader{}
|
||||
c, err := upgrader.Upgrade(resp, &http.Request{
|
||||
Method: "GET",
|
||||
Header: http.Header{
|
||||
"Upgrade": []string{"websocket"},
|
||||
"Connection": []string{"upgrade"},
|
||||
"Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
"Sec-Websocket-Version": []string{"13"},
|
||||
}}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if reuse := c.br == br; reuse != tt.reuse {
|
||||
t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
|
||||
}
|
||||
writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw)
|
||||
if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
|
||||
t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue