diff --git a/client.go b/client.go index 3b5cac4..c25d24f 100644 --- a/client.go +++ b/client.go @@ -215,16 +215,6 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } } - readBufferSize := d.ReadBufferSize - if readBufferSize == 0 { - readBufferSize = 4096 - } - - writeBufferSize := d.WriteBufferSize - if writeBufferSize == 0 { - writeBufferSize = 4096 - } - if len(d.Subprotocols) > 0 { h := http.Header{} for k, v := range requestHeader { @@ -234,7 +224,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re requestHeader = h } - conn, resp, err := NewClient(netConn, u, requestHeader, readBufferSize, writeBufferSize) + conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize) if err != nil { return nil, resp, err } diff --git a/conn.go b/conn.go index dc7d111..005b6d8 100644 --- a/conn.go +++ b/conn.go @@ -16,6 +16,20 @@ import ( "time" ) +const ( + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maxControlFramePayloadSize = 125 + finalBit = 1 << 7 + maskBit = 1 << 7 + writeWait = time.Second + + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + + continuationFrame = 0 + noFrame = -1 +) + // Close codes defined in RFC 6455, section 11.7. const ( CloseNormalClosure = 1000 @@ -55,20 +69,13 @@ const ( PongMessage = 10 ) -var ( - continuationFrame = 0 - noFrame = -1 -) +// ErrCloseSent is returned when the application writes a message to the +// connection after sending a close message. +var ErrCloseSent = errors.New("websocket: close sent") -var ( - // ErrCloseSent is returned when the application writes a message to the - // connection after sending a close message. - ErrCloseSent = errors.New("websocket: close sent") - - // ErrReadLimit is returned when reading a message that is larger than the - // read limit set for the connection. - ErrReadLimit = errors.New("websocket: read limit exceeded") -) +// ErrReadLimit is returned when reading a message that is larger than the +// read limit set for the connection. +var ErrReadLimit = errors.New("websocket: read limit exceeded") // netError satisfies the net Error interface. type netError struct { @@ -99,14 +106,6 @@ var ( errInvalidControlFrame = errors.New("websocket: invalid control frame") ) -const ( - maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask - maxControlFramePayloadSize = 125 - finalBit = 1 << 7 - maskBit = 1 << 7 - writeWait = time.Second -) - func hideTempErr(err error) error { if e, ok := err.(net.Error); ok && e.Temporary() { err = struct{ error }{err} @@ -167,17 +166,24 @@ type Conn struct { handlePing func(string) error } -func newConn(conn net.Conn, isServer bool, readBufSize, writeBufSize int) *Conn { +func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { mu := make(chan bool, 1) mu <- true + if readBufferSize == 0 { + readBufferSize = defaultReadBufferSize + } + if writeBufferSize == 0 { + writeBufferSize = defaultWriteBufferSize + } + c := &Conn{ isServer: isServer, - br: bufio.NewReaderSize(conn, readBufSize), + br: bufio.NewReaderSize(conn, readBufferSize), conn: conn, mu: mu, readFinal: true, - writeBuf: make([]byte, writeBufSize+maxFrameHeaderSize), + writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), writeFrameType: noFrame, writePos: maxFrameHeaderSize, } diff --git a/server.go b/server.go index c24c410..349e5b9 100644 --- a/server.go +++ b/server.go @@ -21,11 +21,6 @@ type HandshakeError struct { func (e HandshakeError) Error() string { return e.message } -const ( - defaultReadBufferSize = 4096 - defaultWriteBufferSize = 4096 -) - // Upgrader specifies parameters for upgrading an HTTP connection to a // WebSocket connection. type Upgrader struct { @@ -147,15 +142,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade return nil, errors.New("websocket: client sent data before handshake is complete") } - readBufSize := u.ReadBufferSize - if readBufSize == 0 { - readBufSize = defaultReadBufferSize - } - writeBufSize := u.WriteBufferSize - if writeBufSize == 0 { - writeBufSize = defaultWriteBufferSize - } - c := newConn(netConn, true, readBufSize, writeBufSize) + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c.subprotocol = subprotocol p := c.writeBuf[:0]