diff --git a/client.go b/client.go index 04fdafe..a81ddb1 100644 --- a/client.go +++ b/client.go @@ -328,9 +328,15 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h }) } + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. defer func() { if netConn != nil { - netConn.Close() + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() } }() @@ -421,8 +427,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") - netConn.SetDeadline(time.Time{}) - netConn = nil // to avoid close in defer. + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, resp, err + } + + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + return conn, resp, nil } diff --git a/client_server_test.go b/client_server_test.go index a47df48..ae3c638 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -531,7 +531,7 @@ func TestRespOnBadHandshake(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(expectedStatus) - io.WriteString(w, expectedBody) + _, _ = io.WriteString(w, expectedBody) })) defer s.Close() @@ -781,7 +781,7 @@ func TestSocksProxyDial(t *testing.T) { } defer c1.Close() - c1.SetDeadline(time.Now().Add(30 * time.Second)) + _ = c1.SetDeadline(time.Now().Add(30 * time.Second)) buf := make([]byte, 32) if _, err := io.ReadFull(c1, buf[:3]); err != nil { @@ -820,10 +820,10 @@ func TestSocksProxyDial(t *testing.T) { defer c2.Close() done := make(chan struct{}) go func() { - io.Copy(c1, c2) + _, _ = io.Copy(c1, c2) close(done) }() - io.Copy(c2, c1) + _, _ = io.Copy(c2, c1) <-done }() diff --git a/compression.go b/compression.go index 813ffb1..fe1079e 100644 --- a/compression.go +++ b/compression.go @@ -33,7 +33,11 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser { "\x01\x00\x00\xff\xff" fr, _ := flateReaderPool.Get().(io.ReadCloser) - fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + mr := io.MultiReader(r, strings.NewReader(tail)) + if err := fr.(flate.Resetter).Reset(mr, nil); err != nil { + // Reset never fails, but handle error in case that changes. + fr = flate.NewReader(mr) + } return &flateReadWrapper{fr} } diff --git a/compression_test.go b/compression_test.go index 8a26b30..13b0860 100644 --- a/compression_test.go +++ b/compression_test.go @@ -23,7 +23,7 @@ func TestTruncWriter(t *testing.T) { if m > n { m = n } - w.Write(p[:m]) + _, _ = w.Write(p[:m]) p = p[m:] } if b.String() != data[:len(data)-len(w.p)] { @@ -47,7 +47,7 @@ func BenchmarkWriteNoCompression(b *testing.B) { messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(TextMessage, messages[i%len(messages)]) + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } @@ -60,7 +60,7 @@ func BenchmarkWriteWithCompression(b *testing.B) { c.newCompressionWriter = compressNoContextTakeover b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(TextMessage, messages[i%len(messages)]) + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } diff --git a/conn.go b/conn.go index 5161ef8..7f3e16c 100644 --- a/conn.go +++ b/conn.go @@ -372,7 +372,9 @@ func (c *Conn) read(n int) ([]byte, error) { if err == io.EOF { err = errUnexpectedEOF } - c.br.Discard(len(p)) + // Discard is guaranteed to succeed because the number of bytes to discard + // is less than or equal to the number of bytes buffered. + _, _ = c.br.Discard(len(p)) return p, err } @@ -387,7 +389,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return err } - c.conn.SetWriteDeadline(deadline) + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } if len(buf1) == 0 { _, err = c.conn.Write(buf0) } else { @@ -397,7 +401,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return c.writeFatal(err) } if frameType == CloseMessage { - c.writeFatal(ErrCloseSent) + _ = c.writeFatal(ErrCloseSent) } return nil } @@ -460,13 +464,14 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er return err } - c.conn.SetWriteDeadline(deadline) - _, err = c.conn.Write(buf) - if err != nil { + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } + if _, err = c.conn.Write(buf); err != nil { return c.writeFatal(err) } if messageType == CloseMessage { - c.writeFatal(ErrCloseSent) + _ = c.writeFatal(ErrCloseSent) } return err } @@ -630,7 +635,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { } if final { - w.endMessage(errWriteClosed) + _ = w.endMessage(errWriteClosed) return nil } @@ -817,7 +822,7 @@ func (c *Conn) advanceFrame() (int, error) { rsv2 := p[0]&rsv2Bit != 0 rsv3 := p[0]&rsv3Bit != 0 mask := p[1]&maskBit != 0 - c.setReadRemaining(int64(p[1] & 0x7f)) + _ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0 c.readDecompress = false if rsv1 { @@ -922,7 +927,8 @@ func (c *Conn) advanceFrame() (int, error) { } if c.readLimit > 0 && c.readLength > c.readLimit { - c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + // Make a best effort to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) return noFrame, ErrReadLimit } @@ -934,7 +940,7 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { payload, err = c.read(int(c.readRemaining)) - c.setReadRemaining(0) + _ = c.setReadRemaining(0) // will not fail because argument is >= 0 if err != nil { return noFrame, err } @@ -981,7 +987,8 @@ func (c *Conn) handleProtocolError(message string) error { if len(data) > maxControlFramePayloadSize { data = data[:maxControlFramePayloadSize] } - c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) + // Make a best effor to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) return errors.New("websocket: " + message) } @@ -1054,7 +1061,7 @@ func (r *messageReader) Read(b []byte) (int, error) { } rem := c.readRemaining rem -= int64(n) - c.setReadRemaining(rem) + _ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0 if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } @@ -1136,7 +1143,8 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) { if h == nil { h = func(code int, text string) error { message := FormatCloseMessage(code, "") - c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + // Make a best effor to send the close message. + _ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) return nil } } diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go index 6e744fc..2cada79 100644 --- a/conn_broadcast_test.go +++ b/conn_broadcast_test.go @@ -70,9 +70,9 @@ func (b *broadcastBench) makeConns(numConns int) { select { case msg := <-c.msgCh: if msg.prepared != nil { - c.conn.WritePreparedMessage(msg.prepared) + _ = c.conn.WritePreparedMessage(msg.prepared) } else { - c.conn.WriteMessage(TextMessage, msg.payload) + _ = c.conn.WriteMessage(TextMessage, msg.payload) } val := atomic.AddInt32(&b.count, 1) if val%int32(numConns) == 0 { diff --git a/conn_test.go b/conn_test.go index 06e5184..eb5ab91 100644 --- a/conn_test.go +++ b/conn_test.go @@ -158,7 +158,7 @@ func TestControl(t *testing.T) { wc := newTestConn(nil, &connBuf, isServer) rc := newTestConn(&connBuf, nil, !isServer) if isWriteControl { - wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) } else { w, err := wc.NextWriter(PongMessage) if err != nil { @@ -175,7 +175,7 @@ func TestControl(t *testing.T) { } var actualMessage string rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) - rc.NextReader() + _, _, _ = rc.NextReader() if actualMessage != message { t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) continue @@ -359,8 +359,8 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) - wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) + _ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) w.Close() op, r, err := rc.NextReader() @@ -386,7 +386,7 @@ func TestEOFWithinFrame(t *testing.T) { rc := newTestConn(&b, nil, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize)) + _, _ = w.Write(make([]byte, bufSize)) w.Close() if n >= b.Len() { @@ -420,7 +420,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) { rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { @@ -439,7 +439,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) { func TestWriteAfterMessageWriterClose(t *testing.T) { wc := newTestConn(nil, &bytes.Buffer{}, false) w, _ := wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") + _, _ = io.WriteString(w, "hello") if err := w.Close(); err != nil { t.Fatalf("unxpected error closing message writer, %v", err) } @@ -449,7 +449,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } w, _ = wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") + _, _ = io.WriteString(w, "hello") // close w by getting next writer _, err := wc.NextWriter(BinaryMessage) @@ -474,13 +474,13 @@ func TestReadLimit(t *testing.T) { // Send message at the limit with interleaved pong. w, _ := wc.NextWriter(BinaryMessage) - w.Write(message[:readLimit-1]) - wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) - w.Write(message[:1]) + _, _ = w.Write(message[:readLimit-1]) + _ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + _, _ = w.Write(message[:1]) w.Close() // Send message larger than the limit. - wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + _ = wc.WriteMessage(BinaryMessage, message[:readLimit+1]) op, _, err := rc.NextReader() if op != BinaryMessage || err != nil { @@ -593,7 +593,7 @@ func TestBufioReadBytes(t *testing.T) { rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) w, _ := wc.NextWriter(BinaryMessage) - w.Write(m) + _, _ = w.Write(m) w.Close() op, r, err := rc.NextReader() @@ -667,7 +667,7 @@ func TestConcurrentWritePanic(t *testing.T) { w := blockingWriter{make(chan struct{}), make(chan struct{})} c := newTestConn(nil, w, false) go func() { - c.WriteMessage(TextMessage, []byte{}) + _ = c.WriteMessage(TextMessage, []byte{}) }() // wait for goroutine to block in write. @@ -680,7 +680,7 @@ func TestConcurrentWritePanic(t *testing.T) { } }() - c.WriteMessage(TextMessage, []byte{}) + _ = c.WriteMessage(TextMessage, []byte{}) t.Fatal("should not get here") } @@ -700,7 +700,7 @@ func TestFailedConnectionReadPanic(t *testing.T) { }() for i := 0; i < 20000; i++ { - c.ReadMessage() + _, _, _ = c.ReadMessage() } t.Fatal("should not get here") } diff --git a/join_test.go b/join_test.go index 961ac04..37bb30f 100644 --- a/join_test.go +++ b/join_test.go @@ -19,7 +19,7 @@ func TestJoinMessages(t *testing.T) { wc := newTestConn(nil, &connBuf, true) rc := newTestConn(&connBuf, nil, false) for _, m := range messages { - wc.WriteMessage(BinaryMessage, []byte(m)) + _ = wc.WriteMessage(BinaryMessage, []byte(m)) } var result bytes.Buffer diff --git a/prepared_test.go b/prepared_test.go index 2297802..8842a72 100644 --- a/prepared_test.go +++ b/prepared_test.go @@ -40,7 +40,9 @@ func TestPreparedMessage(t *testing.T) { if tt.enableWriteCompression { c.newCompressionWriter = compressNoContextTakeover } - c.SetCompressionLevel(tt.compressionLevel) + if err := c.SetCompressionLevel(tt.compressionLevel); err != nil { + t.Fatal(err) + } // Seed random number generator for consistent frame mask. rand.Seed(1234) diff --git a/server.go b/server.go index bb33597..a76b8a1 100644 --- a/server.go +++ b/server.go @@ -182,8 +182,19 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. + defer func() { + if netConn != nil { + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() + } + }() + if brw.Reader.Buffered() > 0 { - netConn.Close() return nil, errors.New("websocket: client sent data before handshake is complete") } @@ -247,20 +258,30 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } p = append(p, "\r\n"...) - // Clear deadlines set by HTTP server. - netConn.SetDeadline(time.Time{}) - if u.HandshakeTimeout > 0 { - netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) + if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil { + return nil, err + } + } else { + // Clear deadlines set by HTTP server. + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, err + } } + if _, err = netConn.Write(p); err != nil { - netConn.Close() return nil, err } if u.HandshakeTimeout > 0 { - netConn.SetWriteDeadline(time.Time{}) + if err := netConn.SetWriteDeadline(time.Time{}); err != nil { + return nil, err + } } + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + return c, nil } @@ -356,7 +377,7 @@ func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte { // bufio.Writer's underlying writer. var wh writeHook bw.Reset(&wh) - bw.WriteByte(0) + _ = bw.WriteByte(0) bw.Flush() bw.Reset(originalWriter)