Use bufio.Reader returned from hijack in upgrade

Use the bufio.Reader returned from hijack if the reader's buffer size is
equal to the buffer size specified in Upgrader.ReadBufferSize.
This commit is contained in:
Gary Burd 2017-03-01 09:36:54 -08:00
parent 3f3e394da2
commit 286b5c9371
3 changed files with 38 additions and 7 deletions

21
conn.go
View File

@ -265,6 +265,10 @@ type Conn struct {
} }
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
}
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
mu := make(chan bool, 1) mu := make(chan bool, 1)
mu <- true mu <- true
@ -274,13 +278,28 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
if readBufferSize < maxControlFramePayloadSize { if readBufferSize < maxControlFramePayloadSize {
readBufferSize = maxControlFramePayloadSize readBufferSize = maxControlFramePayloadSize
} }
// Reuse the supplied brw.Reader if brw.Reader's buf is the requested size.
var br *bufio.Reader
if brw != nil && brw.Reader != nil {
// This code assumes that peek on a reset reader returns
// bufio.Reader.buf[:0].
brw.Reader.Reset(conn)
if p, err := brw.Reader.Peek(0); err == nil && cap(p) == readBufferSize {
br = brw.Reader
}
}
if br == nil {
br = bufio.NewReaderSize(conn, readBufferSize)
}
if writeBufferSize == 0 { if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize writeBufferSize = defaultWriteBufferSize
} }
c := &Conn{ c := &Conn{
isServer: isServer, isServer: isServer,
br: bufio.NewReaderSize(conn, readBufferSize), br: br,
conn: conn, conn: conn,
mu: mu, mu: mu,
readFinal: true, readFinal: true,

View File

@ -463,3 +463,17 @@ func TestFailedConnectionReadPanic(t *testing.T) {
} }
t.Fatal("should not get here") t.Fatal("should not get here")
} }
func TestBufioReaderReuse(t *testing.T) {
brw := bufio.NewReadWriter(bufio.NewReader(nil), nil)
c := newConnBRW(nil, false, 0, 0, brw)
if c.br != brw.Reader {
t.Error("connection did not reuse bufio.Reader")
}
brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 1234), nil) // size must not equal bufio.defaultBufSize
c = newConnBRW(nil, false, 0, 0, brw)
if c.br == brw.Reader {
t.Error("connection reuse bufio.Reader with wrong size")
}
}

View File

@ -152,7 +152,6 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
var ( var (
netConn net.Conn netConn net.Conn
br *bufio.Reader
err error err error
) )
@ -160,19 +159,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
if !ok { if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
} }
var rw *bufio.ReadWriter var brw *bufio.ReadWriter
netConn, rw, err = h.Hijack() netConn, brw, err = h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error()) return u.returnError(w, r, http.StatusInternalServerError, err.Error())
} }
br = rw.Reader
if br.Buffered() > 0 { if brw.Reader.Buffered() > 0 {
netConn.Close() netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete") return nil, errors.New("websocket: client sent data before handshake is complete")
} }
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
c.subprotocol = subprotocol c.subprotocol = subprotocol
if compress { if compress {