mirror of https://github.com/gorilla/websocket.git
Add write buffer pooling
Add WriteBufferPool to Dialer and Upgrader. This field specifies a pool to use for write operations on a connection. Use of the pool can reduce memory use when there is a modest write volume over a large number of connections. Use larger of hijacked buffer and buffer allocated for connection (if any) as buffer for building handshake response. This decreases possible allocations when building the handshake response. Modify bufio reuse test to call Upgrade instead of the internal newConnBRW. Move the test from conn_test.go to server_test.go because it's a serer test. Update newConn and newConnBRW: - Move the bufio "hacks" from newConnBRW to separate functions and call these functions directly from Upgrade. - Rename newConn to newTestConn and move to conn_test.go. Shorten argument list to common use case. - Rename newConnBRW to newConn. - Add pool code to newConn.
This commit is contained in:
parent
5fb94172f4
commit
b378caee5b
13
client.go
13
client.go
|
@ -69,6 +69,17 @@ type Dialer struct {
|
||||||
// do not limit the size of the messages that can be sent or received.
|
// do not limit the size of the messages that can be sent or received.
|
||||||
ReadBufferSize, WriteBufferSize int
|
ReadBufferSize, WriteBufferSize int
|
||||||
|
|
||||||
|
// WriteBufferPool is a pool of buffers for write operations. If the value
|
||||||
|
// is not set, then write buffers are allocated to the connection for the
|
||||||
|
// lifetime of the connection.
|
||||||
|
//
|
||||||
|
// A pool is most useful when the application has a modest volume of writes
|
||||||
|
// across a large number of connections.
|
||||||
|
//
|
||||||
|
// Applications should use a single pool for each unique value of
|
||||||
|
// WriteBufferSize.
|
||||||
|
WriteBufferPool BufferPool
|
||||||
|
|
||||||
// Subprotocols specifies the client's requested subprotocols.
|
// Subprotocols specifies the client's requested subprotocols.
|
||||||
Subprotocols []string
|
Subprotocols []string
|
||||||
|
|
||||||
|
@ -277,7 +288,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
|
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
|
||||||
|
|
||||||
if err := req.Write(netConn); err != nil {
|
if err := req.Write(netConn); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|
|
@ -43,7 +43,7 @@ func textMessages(num int) [][]byte {
|
||||||
|
|
||||||
func BenchmarkWriteNoCompression(b *testing.B) {
|
func BenchmarkWriteNoCompression(b *testing.B) {
|
||||||
w := ioutil.Discard
|
w := ioutil.Discard
|
||||||
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
c := newTestConn(nil, w, false)
|
||||||
messages := textMessages(100)
|
messages := textMessages(100)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
@ -54,7 +54,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
|
||||||
|
|
||||||
func BenchmarkWriteWithCompression(b *testing.B) {
|
func BenchmarkWriteWithCompression(b *testing.B) {
|
||||||
w := ioutil.Discard
|
w := ioutil.Discard
|
||||||
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
c := newTestConn(nil, w, false)
|
||||||
messages := textMessages(100)
|
messages := textMessages(100)
|
||||||
c.enableWriteCompression = true
|
c.enableWriteCompression = true
|
||||||
c.newCompressionWriter = compressNoContextTakeover
|
c.newCompressionWriter = compressNoContextTakeover
|
||||||
|
@ -66,7 +66,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidCompressionLevel(t *testing.T) {
|
func TestValidCompressionLevel(t *testing.T) {
|
||||||
c := newConn(fakeNetConn{}, false, 1024, 1024)
|
c := newTestConn(nil, nil, false)
|
||||||
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
|
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
|
||||||
if err := c.SetCompressionLevel(level); err == nil {
|
if err := c.SetCompressionLevel(level); err == nil {
|
||||||
t.Errorf("no error for level %d", level)
|
t.Errorf("no error for level %d", level)
|
||||||
|
|
89
conn.go
89
conn.go
|
@ -223,6 +223,20 @@ func isValidReceivedCloseCode(code int) bool {
|
||||||
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
|
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
|
||||||
|
// interface. The type of the value stored in a pool is not specified.
|
||||||
|
type BufferPool interface {
|
||||||
|
// Get gets a value from the pool or returns nil if the pool is empty.
|
||||||
|
Get() interface{}
|
||||||
|
// Put adds a value to the pool.
|
||||||
|
Put(interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// writePoolData is the type added to the write buffer pool. This wrapper is
|
||||||
|
// used to prevent applications from peeking at and depending on the values
|
||||||
|
// added to the pool.
|
||||||
|
type writePoolData struct{ buf []byte }
|
||||||
|
|
||||||
// The Conn type represents a WebSocket connection.
|
// The Conn type represents a WebSocket connection.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
@ -232,6 +246,8 @@ type Conn struct {
|
||||||
// Write fields
|
// Write fields
|
||||||
mu chan bool // used as mutex to protect write to conn
|
mu chan bool // used as mutex to protect write to conn
|
||||||
writeBuf []byte // frame is constructed in this buffer.
|
writeBuf []byte // frame is constructed in this buffer.
|
||||||
|
writePool BufferPool
|
||||||
|
writeBufSize int
|
||||||
writeDeadline time.Time
|
writeDeadline time.Time
|
||||||
writer io.WriteCloser // the current writer returned to the application
|
writer io.WriteCloser // the current writer returned to the application
|
||||||
isWriting bool // for best-effort concurrent write detection
|
isWriting bool // for best-effort concurrent write detection
|
||||||
|
@ -263,64 +279,29 @@ type Conn struct {
|
||||||
newDecompressionReader func(io.Reader) io.ReadCloser
|
newDecompressionReader func(io.Reader) io.ReadCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
|
||||||
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
type writeHook struct {
|
|
||||||
p []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (wh *writeHook) Write(p []byte) (int, error) {
|
|
||||||
wh.p = p
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
|
|
||||||
mu := make(chan bool, 1)
|
|
||||||
mu <- true
|
|
||||||
|
|
||||||
var br *bufio.Reader
|
|
||||||
if readBufferSize == 0 && brw != nil && brw.Reader != nil {
|
|
||||||
// Reuse the supplied bufio.Reader if the buffer has a useful size.
|
|
||||||
// This code assumes that peek on a reader returns
|
|
||||||
// bufio.Reader.buf[:0].
|
|
||||||
brw.Reader.Reset(conn)
|
|
||||||
if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 {
|
|
||||||
br = brw.Reader
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if br == nil {
|
if br == nil {
|
||||||
if readBufferSize == 0 {
|
if readBufferSize == 0 {
|
||||||
readBufferSize = defaultReadBufferSize
|
readBufferSize = defaultReadBufferSize
|
||||||
}
|
} else if readBufferSize < maxControlFramePayloadSize {
|
||||||
if readBufferSize < maxControlFramePayloadSize {
|
// must be large enough for control frame
|
||||||
readBufferSize = maxControlFramePayloadSize
|
readBufferSize = maxControlFramePayloadSize
|
||||||
}
|
}
|
||||||
br = bufio.NewReaderSize(conn, readBufferSize)
|
br = bufio.NewReaderSize(conn, readBufferSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
var writeBuf []byte
|
if writeBufferSize <= 0 {
|
||||||
if writeBufferSize == 0 && brw != nil && brw.Writer != nil {
|
|
||||||
// Use the bufio.Writer's buffer if the buffer has a useful size. This
|
|
||||||
// code assumes that bufio.Writer.buf[:1] is passed to the
|
|
||||||
// bufio.Writer's underlying writer.
|
|
||||||
var wh writeHook
|
|
||||||
brw.Writer.Reset(&wh)
|
|
||||||
brw.Writer.WriteByte(0)
|
|
||||||
brw.Flush()
|
|
||||||
if cap(wh.p) >= maxFrameHeaderSize+256 {
|
|
||||||
writeBuf = wh.p[:cap(wh.p)]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if writeBuf == nil {
|
|
||||||
if writeBufferSize == 0 {
|
|
||||||
writeBufferSize = defaultWriteBufferSize
|
writeBufferSize = defaultWriteBufferSize
|
||||||
}
|
}
|
||||||
writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize)
|
writeBufferSize += maxFrameHeaderSize
|
||||||
|
|
||||||
|
if writeBuf == nil && writeBufferPool == nil {
|
||||||
|
writeBuf = make([]byte, writeBufferSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mu := make(chan bool, 1)
|
||||||
|
mu <- true
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
isServer: isServer,
|
isServer: isServer,
|
||||||
br: br,
|
br: br,
|
||||||
|
@ -328,6 +309,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
|
||||||
mu: mu,
|
mu: mu,
|
||||||
readFinal: true,
|
readFinal: true,
|
||||||
writeBuf: writeBuf,
|
writeBuf: writeBuf,
|
||||||
|
writePool: writeBufferPool,
|
||||||
|
writeBufSize: writeBufferSize,
|
||||||
enableWriteCompression: true,
|
enableWriteCompression: true,
|
||||||
compressionLevel: defaultCompressionLevel,
|
compressionLevel: defaultCompressionLevel,
|
||||||
}
|
}
|
||||||
|
@ -484,7 +467,19 @@ func (c *Conn) prepWrite(messageType int) error {
|
||||||
c.writeErrMu.Lock()
|
c.writeErrMu.Lock()
|
||||||
err := c.writeErr
|
err := c.writeErr
|
||||||
c.writeErrMu.Unlock()
|
c.writeErrMu.Unlock()
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.writeBuf == nil {
|
||||||
|
wpd, ok := c.writePool.Get().(writePoolData)
|
||||||
|
if ok {
|
||||||
|
c.writeBuf = wpd.buf
|
||||||
|
} else {
|
||||||
|
c.writeBuf = make([]byte, c.writeBufSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NextWriter returns a writer for the next message to send. The writer's Close
|
// NextWriter returns a writer for the next message to send. The writer's Close
|
||||||
|
@ -610,6 +605,10 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
||||||
|
|
||||||
if final {
|
if final {
|
||||||
c.writer = nil
|
c.writer = nil
|
||||||
|
if c.writePool != nil {
|
||||||
|
c.writePool.Put(writePoolData{buf: c.writeBuf})
|
||||||
|
c.writeBuf = nil
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -70,7 +70,7 @@ func (b *broadcastBench) makeConns(numConns int) {
|
||||||
conns := make([]*broadcastConn, numConns)
|
conns := make([]*broadcastConn, numConns)
|
||||||
|
|
||||||
for i := 0; i < numConns; i++ {
|
for i := 0; i < numConns; i++ {
|
||||||
c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024)
|
c := newTestConn(nil, b.w, true)
|
||||||
if b.compression {
|
if b.compression {
|
||||||
c.enableWriteCompression = true
|
c.enableWriteCompression = true
|
||||||
c.newCompressionWriter = compressNoContextTakeover
|
c.newCompressionWriter = compressNoContextTakeover
|
||||||
|
|
187
conn_test.go
187
conn_test.go
|
@ -13,6 +13,7 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"testing/iotest"
|
"testing/iotest"
|
||||||
"time"
|
"time"
|
||||||
|
@ -47,6 +48,12 @@ func (a fakeAddr) String() string {
|
||||||
return "str"
|
return "str"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newTestConn creates a connnection backed by a fake network connection using
|
||||||
|
// default values for buffering.
|
||||||
|
func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
|
||||||
|
return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
func TestFraming(t *testing.T) {
|
func TestFraming(t *testing.T) {
|
||||||
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
|
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
|
||||||
var readChunkers = []struct {
|
var readChunkers = []struct {
|
||||||
|
@ -82,8 +89,8 @@ func TestFraming(t *testing.T) {
|
||||||
for _, chunker := range readChunkers {
|
for _, chunker := range readChunkers {
|
||||||
|
|
||||||
var connBuf bytes.Buffer
|
var connBuf bytes.Buffer
|
||||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
|
wc := newTestConn(nil, &connBuf, isServer)
|
||||||
rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
|
rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
|
||||||
if compress {
|
if compress {
|
||||||
wc.newCompressionWriter = compressNoContextTakeover
|
wc.newCompressionWriter = compressNoContextTakeover
|
||||||
rc.newDecompressionReader = decompressNoContextTakeover
|
rc.newDecompressionReader = decompressNoContextTakeover
|
||||||
|
@ -143,8 +150,8 @@ func TestControl(t *testing.T) {
|
||||||
for _, isWriteControl := range []bool{true, false} {
|
for _, isWriteControl := range []bool{true, false} {
|
||||||
name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
|
name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
|
||||||
var connBuf bytes.Buffer
|
var connBuf bytes.Buffer
|
||||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
|
wc := newTestConn(nil, &connBuf, isServer)
|
||||||
rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
|
rc := newTestConn(&connBuf, nil, !isServer)
|
||||||
if isWriteControl {
|
if isWriteControl {
|
||||||
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
|
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
|
||||||
} else {
|
} else {
|
||||||
|
@ -173,14 +180,124 @@ func TestControl(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
|
||||||
|
type simpleBufferPool struct {
|
||||||
|
v interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *simpleBufferPool) Get() interface{} {
|
||||||
|
v := p.v
|
||||||
|
p.v = nil
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *simpleBufferPool) Put(v interface{}) {
|
||||||
|
p.v = v
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteBufferPool(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
var pool simpleBufferPool
|
||||||
|
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
|
||||||
|
rc := newTestConn(&buf, nil, false)
|
||||||
|
|
||||||
|
if wc.writeBuf != nil {
|
||||||
|
t.Fatal("writeBuf not nil after create")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Part 1: test NextWriter/Write/Close
|
||||||
|
|
||||||
|
w, err := wc.NextWriter(TextMessage)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("wc.NextWriter() returned %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if wc.writeBuf == nil {
|
||||||
|
t.Fatal("writeBuf is nil after NextWriter")
|
||||||
|
}
|
||||||
|
|
||||||
|
writeBufAddr := &wc.writeBuf[0]
|
||||||
|
|
||||||
|
const message = "Hello World!"
|
||||||
|
|
||||||
|
if _, err := io.WriteString(w, message); err != nil {
|
||||||
|
t.Fatalf("io.WriteString(w, message) returned %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatalf("w.Close() returned %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if wc.writeBuf != nil {
|
||||||
|
t.Fatal("writeBuf not nil after w.Close()")
|
||||||
|
}
|
||||||
|
|
||||||
|
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
|
||||||
|
t.Fatal("writeBuf not returned to pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
opCode, p, err := rc.ReadMessage()
|
||||||
|
if opCode != TextMessage || err != nil {
|
||||||
|
t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s := string(p); s != message {
|
||||||
|
t.Fatalf("message is %s, want %s", s, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Part 2: Test WriteMessage.
|
||||||
|
|
||||||
|
if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
|
||||||
|
t.Fatalf("wc.WriteMessage() returned %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if wc.writeBuf != nil {
|
||||||
|
t.Fatal("writeBuf not nil after wc.WriteMessage()")
|
||||||
|
}
|
||||||
|
|
||||||
|
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
|
||||||
|
t.Fatal("writeBuf not returned to pool after WriteMessage")
|
||||||
|
}
|
||||||
|
|
||||||
|
opCode, p, err = rc.ReadMessage()
|
||||||
|
if opCode != TextMessage || err != nil {
|
||||||
|
t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s := string(p); s != message {
|
||||||
|
t.Fatalf("message is %s, want %s", s, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteBufferPoolSync(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
var pool sync.Pool
|
||||||
|
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
|
||||||
|
rc := newTestConn(&buf, nil, false)
|
||||||
|
|
||||||
|
const message = "Hello World!"
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
|
||||||
|
t.Fatalf("wc.WriteMessage() returned %v", err)
|
||||||
|
}
|
||||||
|
opCode, p, err := rc.ReadMessage()
|
||||||
|
if opCode != TextMessage || err != nil {
|
||||||
|
t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
|
||||||
|
}
|
||||||
|
if s := string(p); s != message {
|
||||||
|
t.Fatalf("message is %s, want %s", s, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||||
const bufSize = 512
|
const bufSize = 512
|
||||||
|
|
||||||
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
|
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
|
||||||
|
|
||||||
var b1, b2 bytes.Buffer
|
var b1, b2 bytes.Buffer
|
||||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
|
wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
|
||||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
rc := newTestConn(&b1, &b2, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize+bufSize/2))
|
w.Write(make([]byte, bufSize+bufSize/2))
|
||||||
|
@ -206,8 +323,8 @@ func TestEOFWithinFrame(t *testing.T) {
|
||||||
|
|
||||||
for n := 0; ; n++ {
|
for n := 0; ; n++ {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
|
wc := newTestConn(nil, &b, false)
|
||||||
rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
|
rc := newTestConn(&b, nil, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize))
|
w.Write(make([]byte, bufSize))
|
||||||
|
@ -240,8 +357,8 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
const bufSize = 512
|
const bufSize = 512
|
||||||
|
|
||||||
var b1, b2 bytes.Buffer
|
var b1, b2 bytes.Buffer
|
||||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
|
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
|
||||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
rc := newTestConn(&b1, &b2, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize+bufSize/2))
|
w.Write(make([]byte, bufSize+bufSize/2))
|
||||||
|
@ -261,7 +378,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
|
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
io.WriteString(w, "hello")
|
io.WriteString(w, "hello")
|
||||||
if err := w.Close(); err != nil {
|
if err := w.Close(); err != nil {
|
||||||
|
@ -292,8 +409,8 @@ func TestReadLimit(t *testing.T) {
|
||||||
message := make([]byte, readLimit+1)
|
message := make([]byte, readLimit+1)
|
||||||
|
|
||||||
var b1, b2 bytes.Buffer
|
var b1, b2 bytes.Buffer
|
||||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
|
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
|
||||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
rc := newTestConn(&b1, &b2, true)
|
||||||
rc.SetReadLimit(readLimit)
|
rc.SetReadLimit(readLimit)
|
||||||
|
|
||||||
// Send message at the limit with interleaved pong.
|
// Send message at the limit with interleaved pong.
|
||||||
|
@ -321,7 +438,7 @@ func TestReadLimit(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddrs(t *testing.T) {
|
func TestAddrs(t *testing.T) {
|
||||||
c := newConn(&fakeNetConn{}, true, 1024, 1024)
|
c := newTestConn(nil, nil, true)
|
||||||
if c.LocalAddr() != localAddr {
|
if c.LocalAddr() != localAddr {
|
||||||
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
|
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
|
||||||
}
|
}
|
||||||
|
@ -333,7 +450,7 @@ func TestAddrs(t *testing.T) {
|
||||||
func TestUnderlyingConn(t *testing.T) {
|
func TestUnderlyingConn(t *testing.T) {
|
||||||
var b1, b2 bytes.Buffer
|
var b1, b2 bytes.Buffer
|
||||||
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
||||||
c := newConn(fc, true, 1024, 1024)
|
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
||||||
ul := c.UnderlyingConn()
|
ul := c.UnderlyingConn()
|
||||||
if ul != fc {
|
if ul != fc {
|
||||||
t.Fatalf("Underlying conn is not what it should be.")
|
t.Fatalf("Underlying conn is not what it should be.")
|
||||||
|
@ -347,8 +464,8 @@ func TestBufioReadBytes(t *testing.T) {
|
||||||
m[len(m)-1] = '\n'
|
m[len(m)-1] = '\n'
|
||||||
|
|
||||||
var b1, b2 bytes.Buffer
|
var b1, b2 bytes.Buffer
|
||||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
|
wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
|
||||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
|
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(m)
|
w.Write(m)
|
||||||
|
@ -423,7 +540,7 @@ func (w blockingWriter) Write(p []byte) (int, error) {
|
||||||
|
|
||||||
func TestConcurrentWritePanic(t *testing.T) {
|
func TestConcurrentWritePanic(t *testing.T) {
|
||||||
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
||||||
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
c := newTestConn(nil, w, false)
|
||||||
go func() {
|
go func() {
|
||||||
c.WriteMessage(TextMessage, []byte{})
|
c.WriteMessage(TextMessage, []byte{})
|
||||||
}()
|
}()
|
||||||
|
@ -449,7 +566,7 @@ func (r failingReader) Read(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFailedConnectionReadPanic(t *testing.T) {
|
func TestFailedConnectionReadPanic(t *testing.T) {
|
||||||
c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
|
c := newTestConn(failingReader{}, nil, false)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if v := recover(); v != nil {
|
if v := recover(); v != nil {
|
||||||
|
@ -462,35 +579,3 @@ func TestFailedConnectionReadPanic(t *testing.T) {
|
||||||
}
|
}
|
||||||
t.Fatal("should not get here")
|
t.Fatal("should not get here")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBufioReuse(t *testing.T) {
|
|
||||||
brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil))
|
|
||||||
c := newConnBRW(nil, false, 0, 0, brw)
|
|
||||||
|
|
||||||
if c.br != brw.Reader {
|
|
||||||
t.Error("connection did not reuse bufio.Reader")
|
|
||||||
}
|
|
||||||
|
|
||||||
var wh writeHook
|
|
||||||
brw.Writer.Reset(&wh)
|
|
||||||
brw.WriteByte(0)
|
|
||||||
brw.Flush()
|
|
||||||
if &c.writeBuf[0] != &wh.p[0] {
|
|
||||||
t.Error("connection did not reuse bufio.Writer")
|
|
||||||
}
|
|
||||||
|
|
||||||
brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0))
|
|
||||||
c = newConnBRW(nil, false, 0, 0, brw)
|
|
||||||
|
|
||||||
if c.br == brw.Reader {
|
|
||||||
t.Error("connection used bufio.Reader with small size")
|
|
||||||
}
|
|
||||||
|
|
||||||
brw.Writer.Reset(&wh)
|
|
||||||
brw.WriteByte(0)
|
|
||||||
brw.Flush()
|
|
||||||
if &c.writeBuf[0] != &wh.p[0] {
|
|
||||||
t.Error("connection used bufio.Writer with small size")
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
17
json_test.go
17
json_test.go
|
@ -14,9 +14,8 @@ import (
|
||||||
|
|
||||||
func TestJSON(t *testing.T) {
|
func TestJSON(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
c := fakeNetConn{&buf, &buf}
|
wc := newTestConn(nil, &buf, true)
|
||||||
wc := newConn(c, true, 1024, 1024)
|
rc := newTestConn(&buf, nil, false)
|
||||||
rc := newConn(c, false, 1024, 1024)
|
|
||||||
|
|
||||||
var actual, expect struct {
|
var actual, expect struct {
|
||||||
A int
|
A int
|
||||||
|
@ -39,10 +38,9 @@ func TestJSON(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPartialJSONRead(t *testing.T) {
|
func TestPartialJSONRead(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
var buf0, buf1 bytes.Buffer
|
||||||
c := fakeNetConn{&buf, &buf}
|
wc := newTestConn(nil, &buf0, true)
|
||||||
wc := newConn(c, true, 1024, 1024)
|
rc := newTestConn(&buf0, &buf1, false)
|
||||||
rc := newConn(c, false, 1024, 1024)
|
|
||||||
|
|
||||||
var v struct {
|
var v struct {
|
||||||
A int
|
A int
|
||||||
|
@ -94,9 +92,8 @@ func TestPartialJSONRead(t *testing.T) {
|
||||||
|
|
||||||
func TestDeprecatedJSON(t *testing.T) {
|
func TestDeprecatedJSON(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
c := fakeNetConn{&buf, &buf}
|
wc := newTestConn(nil, &buf, true)
|
||||||
wc := newConn(c, true, 1024, 1024)
|
rc := newTestConn(&buf, nil, false)
|
||||||
rc := newConn(c, false, 1024, 1024)
|
|
||||||
|
|
||||||
var actual, expect struct {
|
var actual, expect struct {
|
||||||
A int
|
A int
|
||||||
|
|
|
@ -36,7 +36,7 @@ func TestPreparedMessage(t *testing.T) {
|
||||||
for _, tt := range preparedMessageTests {
|
for _, tt := range preparedMessageTests {
|
||||||
var data = []byte("this is a test")
|
var data = []byte("this is a test")
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024)
|
c := newTestConn(nil, &buf, tt.isServer)
|
||||||
if tt.enableWriteCompression {
|
if tt.enableWriteCompression {
|
||||||
c.newCompressionWriter = compressNoContextTakeover
|
c.newCompressionWriter = compressNoContextTakeover
|
||||||
}
|
}
|
||||||
|
|
73
server.go
73
server.go
|
@ -7,6 +7,7 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -33,6 +34,17 @@ type Upgrader struct {
|
||||||
// or received.
|
// or received.
|
||||||
ReadBufferSize, WriteBufferSize int
|
ReadBufferSize, WriteBufferSize int
|
||||||
|
|
||||||
|
// WriteBufferPool is a pool of buffers for write operations. If the value
|
||||||
|
// is not set, then write buffers are allocated to the connection for the
|
||||||
|
// lifetime of the connection.
|
||||||
|
//
|
||||||
|
// A pool is most useful when the application has a modest volume of writes
|
||||||
|
// across a large number of connections.
|
||||||
|
//
|
||||||
|
// Applications should use a single pool for each unique value of
|
||||||
|
// WriteBufferSize.
|
||||||
|
WriteBufferPool BufferPool
|
||||||
|
|
||||||
// Subprotocols specifies the server's supported protocols in order of
|
// Subprotocols specifies the server's supported protocols in order of
|
||||||
// preference. If this field is not nil, then the Upgrade method negotiates a
|
// preference. If this field is not nil, then the Upgrade method negotiates a
|
||||||
// subprotocol by selecting the first match in this list with a protocol
|
// subprotocol by selecting the first match in this list with a protocol
|
||||||
|
@ -179,7 +191,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
return nil, errors.New("websocket: client sent data before handshake is complete")
|
return nil, errors.New("websocket: client sent data before handshake is complete")
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
|
var br *bufio.Reader
|
||||||
|
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
|
||||||
|
// Reuse hijacked buffered reader as connection reader.
|
||||||
|
br = brw.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := bufioWriterBuffer(netConn, brw.Writer)
|
||||||
|
|
||||||
|
var writeBuf []byte
|
||||||
|
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
|
||||||
|
// Reuse hijacked write buffer as connection buffer.
|
||||||
|
writeBuf = buf
|
||||||
|
}
|
||||||
|
|
||||||
|
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
|
||||||
c.subprotocol = subprotocol
|
c.subprotocol = subprotocol
|
||||||
|
|
||||||
if compress {
|
if compress {
|
||||||
|
@ -187,7 +213,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
c.newDecompressionReader = decompressNoContextTakeover
|
c.newDecompressionReader = decompressNoContextTakeover
|
||||||
}
|
}
|
||||||
|
|
||||||
p := c.writeBuf[:0]
|
// Use larger of hijacked buffer and connection write buffer for header.
|
||||||
|
p := buf
|
||||||
|
if len(c.writeBuf) > len(p) {
|
||||||
|
p = c.writeBuf
|
||||||
|
}
|
||||||
|
p = p[:0]
|
||||||
|
|
||||||
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
|
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
|
||||||
p = append(p, computeAcceptKey(challengeKey)...)
|
p = append(p, computeAcceptKey(challengeKey)...)
|
||||||
p = append(p, "\r\n"...)
|
p = append(p, "\r\n"...)
|
||||||
|
@ -298,3 +330,40 @@ func IsWebSocketUpgrade(r *http.Request) bool {
|
||||||
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
|
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
|
||||||
tokenListContainsValue(r.Header, "Upgrade", "websocket")
|
tokenListContainsValue(r.Header, "Upgrade", "websocket")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// bufioReader size returns the size of a bufio.Reader.
|
||||||
|
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
|
||||||
|
// This code assumes that peek on a reset reader returns
|
||||||
|
// bufio.Reader.buf[:0].
|
||||||
|
// TODO: Use bufio.Reader.Size() after Go 1.10
|
||||||
|
br.Reset(originalReader)
|
||||||
|
if p, err := br.Peek(0); err == nil {
|
||||||
|
return cap(p)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeHook is an io.Writer that records the last slice passed to it vio
|
||||||
|
// io.Writer.Write.
|
||||||
|
type writeHook struct {
|
||||||
|
p []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wh *writeHook) Write(p []byte) (int, error) {
|
||||||
|
wh.p = p
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
|
||||||
|
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
|
||||||
|
// This code assumes that bufio.Writer.buf[:1] is passed to the
|
||||||
|
// bufio.Writer's underlying writer.
|
||||||
|
var wh writeHook
|
||||||
|
bw.Reset(&wh)
|
||||||
|
bw.WriteByte(0)
|
||||||
|
bw.Flush()
|
||||||
|
|
||||||
|
bw.Reset(originalWriter)
|
||||||
|
|
||||||
|
return wh.p[:cap(wh.p)]
|
||||||
|
}
|
||||||
|
|
|
@ -5,8 +5,12 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -67,3 +71,49 @@ func TestCheckSameOrigin(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type reuseTestResponseWriter struct {
|
||||||
|
brw *bufio.ReadWriter
|
||||||
|
http.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var bufioReuseTests = []struct {
|
||||||
|
n int
|
||||||
|
reuse bool
|
||||||
|
}{
|
||||||
|
{4096, true},
|
||||||
|
{128, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufioReuse(t *testing.T) {
|
||||||
|
for i, tt := range bufioReuseTests {
|
||||||
|
br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
|
||||||
|
bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
|
||||||
|
resp := &reuseTestResponseWriter{
|
||||||
|
brw: bufio.NewReadWriter(br, bw),
|
||||||
|
}
|
||||||
|
upgrader := Upgrader{}
|
||||||
|
c, err := upgrader.Upgrade(resp, &http.Request{
|
||||||
|
Method: "GET",
|
||||||
|
Header: http.Header{
|
||||||
|
"Upgrade": []string{"websocket"},
|
||||||
|
"Connection": []string{"upgrade"},
|
||||||
|
"Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="},
|
||||||
|
"Sec-Websocket-Version": []string{"13"},
|
||||||
|
}}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reuse := c.br == br; reuse != tt.reuse {
|
||||||
|
t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
|
||||||
|
}
|
||||||
|
writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw)
|
||||||
|
if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
|
||||||
|
t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue