diff --git a/client.go b/client.go index 73eada1..24bd7ff 100644 --- a/client.go +++ b/client.go @@ -305,9 +305,15 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h }) } + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. defer func() { if netConn != nil { - netConn.Close() + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() } }() @@ -398,8 +404,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h resp.Body = io.NopCloser(bytes.NewReader([]byte{})) conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") - netConn.SetDeadline(time.Time{}) - netConn = nil // to avoid close in defer. + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, resp, err + } + + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + return conn, resp, nil } diff --git a/client_server_test.go b/client_server_test.go index ec555b4..7de9e88 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -578,7 +578,7 @@ func TestRespOnBadHandshake(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(expectedStatus) - io.WriteString(w, expectedBody) + _, _ = io.WriteString(w, expectedBody) })) defer s.Close() @@ -828,7 +828,7 @@ func TestSocksProxyDial(t *testing.T) { } defer c1.Close() - c1.SetDeadline(time.Now().Add(30 * time.Second)) + _ = c1.SetDeadline(time.Now().Add(30 * time.Second)) buf := make([]byte, 32) if _, err := io.ReadFull(c1, buf[:3]); err != nil { @@ -867,10 +867,10 @@ func TestSocksProxyDial(t *testing.T) { defer c2.Close() done := make(chan struct{}) go func() { - io.Copy(c1, c2) + _, _ = io.Copy(c1, c2) close(done) }() - io.Copy(c2, c1) + _, _ = io.Copy(c2, c1) <-done }() diff --git a/compression.go b/compression.go index 813ffb1..fe1079e 100644 --- a/compression.go +++ b/compression.go @@ -33,7 +33,11 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser { "\x01\x00\x00\xff\xff" fr, _ := flateReaderPool.Get().(io.ReadCloser) - fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + mr := io.MultiReader(r, strings.NewReader(tail)) + if err := fr.(flate.Resetter).Reset(mr, nil); err != nil { + // Reset never fails, but handle error in case that changes. + fr = flate.NewReader(mr) + } return &flateReadWrapper{fr} } diff --git a/compression_test.go b/compression_test.go index 23591c4..00ae42f 100644 --- a/compression_test.go +++ b/compression_test.go @@ -22,7 +22,7 @@ func TestTruncWriter(t *testing.T) { if m > n { m = n } - w.Write(p[:m]) + _, _ = w.Write(p[:m]) p = p[m:] } if b.String() != data[:len(data)-len(w.p)] { @@ -46,7 +46,7 @@ func BenchmarkWriteNoCompression(b *testing.B) { messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(TextMessage, messages[i%len(messages)]) + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } @@ -59,7 +59,7 @@ func BenchmarkWriteWithCompression(b *testing.B) { c.newCompressionWriter = compressNoContextTakeover b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(TextMessage, messages[i%len(messages)]) + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } diff --git a/conn.go b/conn.go index 1bc4d3d..476616a 100644 --- a/conn.go +++ b/conn.go @@ -370,7 +370,9 @@ func (c *Conn) read(n int) ([]byte, error) { if err == io.EOF { err = errUnexpectedEOF } - c.br.Discard(len(p)) + // Discard is guaranteed to succeed because the number of bytes to discard + // is less than or equal to the number of bytes buffered. + _, _ = c.br.Discard(len(p)) return p, err } @@ -385,7 +387,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return err } - c.conn.SetWriteDeadline(deadline) + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } if len(buf1) == 0 { _, err = c.conn.Write(buf0) } else { @@ -395,7 +399,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return c.writeFatal(err) } if frameType == CloseMessage { - c.writeFatal(ErrCloseSent) + _ = c.writeFatal(ErrCloseSent) } return nil } @@ -458,13 +462,14 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er return err } - c.conn.SetWriteDeadline(deadline) - _, err = c.conn.Write(buf) - if err != nil { + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } + if _, err = c.conn.Write(buf); err != nil { return c.writeFatal(err) } if messageType == CloseMessage { - c.writeFatal(ErrCloseSent) + _ = c.writeFatal(ErrCloseSent) } return err } @@ -628,7 +633,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { } if final { - w.endMessage(errWriteClosed) + _ = w.endMessage(errWriteClosed) return nil } @@ -815,7 +820,7 @@ func (c *Conn) advanceFrame() (int, error) { rsv2 := p[0]&rsv2Bit != 0 rsv3 := p[0]&rsv3Bit != 0 mask := p[1]&maskBit != 0 - c.setReadRemaining(int64(p[1] & 0x7f)) + _ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0 c.readDecompress = false if rsv1 { @@ -920,7 +925,8 @@ func (c *Conn) advanceFrame() (int, error) { } if c.readLimit > 0 && c.readLength > c.readLimit { - c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + // Make a best effort to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) return noFrame, ErrReadLimit } @@ -932,7 +938,7 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { payload, err = c.read(int(c.readRemaining)) - c.setReadRemaining(0) + _ = c.setReadRemaining(0) // will not fail because argument is >= 0 if err != nil { return noFrame, err } @@ -979,7 +985,8 @@ func (c *Conn) handleProtocolError(message string) error { if len(data) > maxControlFramePayloadSize { data = data[:maxControlFramePayloadSize] } - c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) + // Make a best effor to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) return errors.New("websocket: " + message) } @@ -1052,7 +1059,7 @@ func (r *messageReader) Read(b []byte) (int, error) { } rem := c.readRemaining rem -= int64(n) - c.setReadRemaining(rem) + _ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0 if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } @@ -1134,7 +1141,8 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) { if h == nil { h = func(code int, text string) error { message := FormatCloseMessage(code, "") - c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + // Make a best effor to send the close message. + _ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) return nil } } diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go index d8a6492..540be6b 100644 --- a/conn_broadcast_test.go +++ b/conn_broadcast_test.go @@ -69,9 +69,9 @@ func (b *broadcastBench) makeConns(numConns int) { select { case msg := <-c.msgCh: if msg.prepared != nil { - c.conn.WritePreparedMessage(msg.prepared) + _ = c.conn.WritePreparedMessage(msg.prepared) } else { - c.conn.WriteMessage(TextMessage, msg.payload) + _ = c.conn.WriteMessage(TextMessage, msg.payload) } val := atomic.AddInt32(&b.count, 1) if val%int32(numConns) == 0 { diff --git a/conn_test.go b/conn_test.go index e9f5441..3b244a9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -157,7 +157,7 @@ func TestControl(t *testing.T) { wc := newTestConn(nil, &connBuf, isServer) rc := newTestConn(&connBuf, nil, !isServer) if isWriteControl { - wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) } else { w, err := wc.NextWriter(PongMessage) if err != nil { @@ -174,7 +174,7 @@ func TestControl(t *testing.T) { } var actualMessage string rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) - rc.NextReader() + _, _, _ = rc.NextReader() if actualMessage != message { t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) continue @@ -358,8 +358,8 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) - wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) + _ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) w.Close() op, r, err := rc.NextReader() @@ -385,7 +385,7 @@ func TestEOFWithinFrame(t *testing.T) { rc := newTestConn(&b, nil, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize)) + _, _ = w.Write(make([]byte, bufSize)) w.Close() if n >= b.Len() { @@ -419,7 +419,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) { rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { @@ -438,7 +438,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) { func TestWriteAfterMessageWriterClose(t *testing.T) { wc := newTestConn(nil, &bytes.Buffer{}, false) w, _ := wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") + _, _ = io.WriteString(w, "hello") if err := w.Close(); err != nil { t.Fatalf("unxpected error closing message writer, %v", err) } @@ -448,7 +448,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } w, _ = wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") + _, _ = io.WriteString(w, "hello") // close w by getting next writer _, err := wc.NextWriter(BinaryMessage) @@ -473,13 +473,13 @@ func TestReadLimit(t *testing.T) { // Send message at the limit with interleaved pong. w, _ := wc.NextWriter(BinaryMessage) - w.Write(message[:readLimit-1]) - wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) - w.Write(message[:1]) + _, _ = w.Write(message[:readLimit-1]) + _ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + _, _ = w.Write(message[:1]) w.Close() // Send message larger than the limit. - wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + _ = wc.WriteMessage(BinaryMessage, message[:readLimit+1]) op, _, err := rc.NextReader() if op != BinaryMessage || err != nil { @@ -592,7 +592,7 @@ func TestBufioReadBytes(t *testing.T) { rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) w, _ := wc.NextWriter(BinaryMessage) - w.Write(m) + _, _ = w.Write(m) w.Close() op, r, err := rc.NextReader() @@ -666,7 +666,7 @@ func TestConcurrentWritePanic(t *testing.T) { w := blockingWriter{make(chan struct{}), make(chan struct{})} c := newTestConn(nil, w, false) go func() { - c.WriteMessage(TextMessage, []byte{}) + _ = c.WriteMessage(TextMessage, []byte{}) }() // wait for goroutine to block in write. @@ -679,7 +679,7 @@ func TestConcurrentWritePanic(t *testing.T) { } }() - c.WriteMessage(TextMessage, []byte{}) + _ = c.WriteMessage(TextMessage, []byte{}) t.Fatal("should not get here") } @@ -699,7 +699,7 @@ func TestFailedConnectionReadPanic(t *testing.T) { }() for i := 0; i < 20000; i++ { - c.ReadMessage() + _, _, _ = c.ReadMessage() } t.Fatal("should not get here") } diff --git a/join_test.go b/join_test.go index 961ac04..37bb30f 100644 --- a/join_test.go +++ b/join_test.go @@ -19,7 +19,7 @@ func TestJoinMessages(t *testing.T) { wc := newTestConn(nil, &connBuf, true) rc := newTestConn(&connBuf, nil, false) for _, m := range messages { - wc.WriteMessage(BinaryMessage, []byte(m)) + _ = wc.WriteMessage(BinaryMessage, []byte(m)) } var result bytes.Buffer diff --git a/prepared_test.go b/prepared_test.go index 536d58d..50d065e 100644 --- a/prepared_test.go +++ b/prepared_test.go @@ -45,7 +45,9 @@ func TestPreparedMessage(t *testing.T) { if tt.enableWriteCompression { c.newCompressionWriter = compressNoContextTakeover } - c.SetCompressionLevel(tt.compressionLevel) + if err := c.SetCompressionLevel(tt.compressionLevel); err != nil { + t.Fatal(err) + } // Seed random number generator for consistent frame mask. testRand.Seed(1234) diff --git a/server.go b/server.go index b76131d..02ea01f 100644 --- a/server.go +++ b/server.go @@ -178,6 +178,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade "websocket: hijack: "+err.Error()) } + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. + defer func() { + if netConn != nil { + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() + } + }() + var br *bufio.Reader if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 { // Use hijacked buffered reader as the connection reader. @@ -244,20 +256,30 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } p = append(p, "\r\n"...) - // Clear deadlines set by HTTP server. - netConn.SetDeadline(time.Time{}) - if u.HandshakeTimeout > 0 { - netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) + if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil { + return nil, err + } + } else { + // Clear deadlines set by HTTP server. + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, err + } } + if _, err = netConn.Write(p); err != nil { - netConn.Close() return nil, err } if u.HandshakeTimeout > 0 { - netConn.SetWriteDeadline(time.Time{}) + if err := netConn.SetWriteDeadline(time.Time{}); err != nil { + return nil, err + } } + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + return c, nil }