This commit is contained in:
Canelo Hill 2024-06-20 03:35:45 +00:00 committed by GitHub
commit ec9a86625d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 100 additions and 53 deletions

View File

@ -305,9 +305,15 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}) })
} }
// Close the network connection when returning an error. The variable
// netConn is set to nil before the success return at the end of the
// function.
defer func() { defer func() {
if netConn != nil { if netConn != nil {
netConn.Close() // It's safe to ignore the error from Close() because this code is
// only executed when returning a more important error to the
// application.
_ = netConn.Close()
} }
}() }()
@ -398,8 +404,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
resp.Body = io.NopCloser(bytes.NewReader([]byte{})) resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{}) if err := netConn.SetDeadline(time.Time{}); err != nil {
netConn = nil // to avoid close in defer. return nil, resp, err
}
// Success! Set netConn to nil to stop the deferred function above from
// closing the network connection.
netConn = nil
return conn, resp, nil return conn, resp, nil
} }

View File

@ -546,7 +546,7 @@ func TestRespOnBadHandshake(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(expectedStatus) w.WriteHeader(expectedStatus)
io.WriteString(w, expectedBody) _, _ = io.WriteString(w, expectedBody)
})) }))
defer s.Close() defer s.Close()
@ -796,7 +796,7 @@ func TestSocksProxyDial(t *testing.T) {
} }
defer c1.Close() defer c1.Close()
c1.SetDeadline(time.Now().Add(30 * time.Second)) _ = c1.SetDeadline(time.Now().Add(30 * time.Second))
buf := make([]byte, 32) buf := make([]byte, 32)
if _, err := io.ReadFull(c1, buf[:3]); err != nil { if _, err := io.ReadFull(c1, buf[:3]); err != nil {
@ -835,10 +835,10 @@ func TestSocksProxyDial(t *testing.T) {
defer c2.Close() defer c2.Close()
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
io.Copy(c1, c2) _, _ = io.Copy(c1, c2)
close(done) close(done)
}() }()
io.Copy(c2, c1) _, _ = io.Copy(c2, c1)
<-done <-done
}() }()

View File

@ -33,7 +33,11 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
"\x01\x00\x00\xff\xff" "\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser) fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) mr := io.MultiReader(r, strings.NewReader(tail))
if err := fr.(flate.Resetter).Reset(mr, nil); err != nil {
// Reset never fails, but handle error in case that changes.
fr = flate.NewReader(mr)
}
return &flateReadWrapper{fr} return &flateReadWrapper{fr}
} }

View File

@ -22,7 +22,7 @@ func TestTruncWriter(t *testing.T) {
if m > n { if m > n {
m = n m = n
} }
w.Write(p[:m]) _, _ = w.Write(p[:m])
p = p[m:] p = p[m:]
} }
if b.String() != data[:len(data)-len(w.p)] { if b.String() != data[:len(data)-len(w.p)] {
@ -46,7 +46,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
messages := textMessages(100) messages := textMessages(100)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)]) _ = c.WriteMessage(TextMessage, messages[i%len(messages)])
} }
b.ReportAllocs() b.ReportAllocs()
} }
@ -59,7 +59,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
c.newCompressionWriter = compressNoContextTakeover c.newCompressionWriter = compressNoContextTakeover
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)]) _ = c.WriteMessage(TextMessage, messages[i%len(messages)])
} }
b.ReportAllocs() b.ReportAllocs()
} }

36
conn.go
View File

@ -371,7 +371,9 @@ func (c *Conn) read(n int) ([]byte, error) {
if err == io.EOF { if err == io.EOF {
err = errUnexpectedEOF err = errUnexpectedEOF
} }
c.br.Discard(len(p)) // Discard is guaranteed to succeed because the number of bytes to discard
// is less than or equal to the number of bytes buffered.
_, _ = c.br.Discard(len(p))
return p, err return p, err
} }
@ -386,7 +388,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return err return err
} }
c.conn.SetWriteDeadline(deadline) if err := c.conn.SetWriteDeadline(deadline); err != nil {
return c.writeFatal(err)
}
if len(buf1) == 0 { if len(buf1) == 0 {
_, err = c.conn.Write(buf0) _, err = c.conn.Write(buf0)
} else { } else {
@ -396,7 +400,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return c.writeFatal(err) return c.writeFatal(err)
} }
if frameType == CloseMessage { if frameType == CloseMessage {
c.writeFatal(ErrCloseSent) _ = c.writeFatal(ErrCloseSent)
} }
return nil return nil
} }
@ -459,13 +463,14 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return err return err
} }
c.conn.SetWriteDeadline(deadline) if err := c.conn.SetWriteDeadline(deadline); err != nil {
_, err = c.conn.Write(buf) return c.writeFatal(err)
if err != nil { }
if _, err = c.conn.Write(buf); err != nil {
return c.writeFatal(err) return c.writeFatal(err)
} }
if messageType == CloseMessage { if messageType == CloseMessage {
c.writeFatal(ErrCloseSent) _ = c.writeFatal(ErrCloseSent)
} }
return err return err
} }
@ -629,7 +634,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
} }
if final { if final {
w.endMessage(errWriteClosed) _ = w.endMessage(errWriteClosed)
return nil return nil
} }
@ -816,7 +821,7 @@ func (c *Conn) advanceFrame() (int, error) {
rsv2 := p[0]&rsv2Bit != 0 rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0 rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0 mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f)) _ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0
c.readDecompress = false c.readDecompress = false
if rsv1 { if rsv1 {
@ -921,7 +926,8 @@ func (c *Conn) advanceFrame() (int, error) {
} }
if c.readLimit > 0 && c.readLength > c.readLimit { if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) // Make a best effort to send a close message describing the problem.
_ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit return noFrame, ErrReadLimit
} }
@ -933,7 +939,7 @@ func (c *Conn) advanceFrame() (int, error) {
var payload []byte var payload []byte
if c.readRemaining > 0 { if c.readRemaining > 0 {
payload, err = c.read(int(c.readRemaining)) payload, err = c.read(int(c.readRemaining))
c.setReadRemaining(0) _ = c.setReadRemaining(0) // will not fail because argument is >= 0
if err != nil { if err != nil {
return noFrame, err return noFrame, err
} }
@ -980,7 +986,8 @@ func (c *Conn) handleProtocolError(message string) error {
if len(data) > maxControlFramePayloadSize { if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize] data = data[:maxControlFramePayloadSize]
} }
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) // Make a best effor to send a close message describing the problem.
_ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message) return errors.New("websocket: " + message)
} }
@ -1053,7 +1060,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
} }
rem := c.readRemaining rem := c.readRemaining
rem -= int64(n) rem -= int64(n)
c.setReadRemaining(rem) _ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0
if c.readRemaining > 0 && c.readErr == io.EOF { if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF c.readErr = errUnexpectedEOF
} }
@ -1135,7 +1142,8 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
if h == nil { if h == nil {
h = func(code int, text string) error { h = func(code int, text string) error {
message := FormatCloseMessage(code, "") message := FormatCloseMessage(code, "")
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) // Make a best effor to send the close message.
_ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
return nil return nil
} }
} }

View File

@ -69,9 +69,9 @@ func (b *broadcastBench) makeConns(numConns int) {
select { select {
case msg := <-c.msgCh: case msg := <-c.msgCh:
if msg.prepared != nil { if msg.prepared != nil {
c.conn.WritePreparedMessage(msg.prepared) _ = c.conn.WritePreparedMessage(msg.prepared)
} else { } else {
c.conn.WriteMessage(TextMessage, msg.payload) _ = c.conn.WriteMessage(TextMessage, msg.payload)
} }
val := atomic.AddInt32(&b.count, 1) val := atomic.AddInt32(&b.count, 1)
if val%int32(numConns) == 0 { if val%int32(numConns) == 0 {

View File

@ -157,7 +157,7 @@ func TestControl(t *testing.T) {
wc := newTestConn(nil, &connBuf, isServer) wc := newTestConn(nil, &connBuf, isServer)
rc := newTestConn(&connBuf, nil, !isServer) rc := newTestConn(&connBuf, nil, !isServer)
if isWriteControl { if isWriteControl {
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
} else { } else {
w, err := wc.NextWriter(PongMessage) w, err := wc.NextWriter(PongMessage)
if err != nil { if err != nil {
@ -174,7 +174,7 @@ func TestControl(t *testing.T) {
} }
var actualMessage string var actualMessage string
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
rc.NextReader() _, _, _ = rc.NextReader()
if actualMessage != message { if actualMessage != message {
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
continue continue
@ -358,8 +358,8 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
rc := newTestConn(&b1, &b2, true) rc := newTestConn(&b1, &b2, true)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2)) _, _ = w.Write(make([]byte, bufSize+bufSize/2))
wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) _ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
w.Close() w.Close()
op, r, err := rc.NextReader() op, r, err := rc.NextReader()
@ -385,7 +385,7 @@ func TestEOFWithinFrame(t *testing.T) {
rc := newTestConn(&b, nil, true) rc := newTestConn(&b, nil, true)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize)) _, _ = w.Write(make([]byte, bufSize))
w.Close() w.Close()
if n >= b.Len() { if n >= b.Len() {
@ -419,7 +419,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
rc := newTestConn(&b1, &b2, true) rc := newTestConn(&b1, &b2, true)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2)) _, _ = w.Write(make([]byte, bufSize+bufSize/2))
op, r, err := rc.NextReader() op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
@ -438,7 +438,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
func TestWriteAfterMessageWriterClose(t *testing.T) { func TestWriteAfterMessageWriterClose(t *testing.T) {
wc := newTestConn(nil, &bytes.Buffer{}, false) wc := newTestConn(nil, &bytes.Buffer{}, false)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello") _, _ = io.WriteString(w, "hello")
if err := w.Close(); err != nil { if err := w.Close(); err != nil {
t.Fatalf("unxpected error closing message writer, %v", err) t.Fatalf("unxpected error closing message writer, %v", err)
} }
@ -448,7 +448,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
} }
w, _ = wc.NextWriter(BinaryMessage) w, _ = wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello") _, _ = io.WriteString(w, "hello")
// close w by getting next writer // close w by getting next writer
_, err := wc.NextWriter(BinaryMessage) _, err := wc.NextWriter(BinaryMessage)
@ -473,13 +473,13 @@ func TestReadLimit(t *testing.T) {
// Send message at the limit with interleaved pong. // Send message at the limit with interleaved pong.
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(message[:readLimit-1]) _, _ = w.Write(message[:readLimit-1])
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) _ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
w.Write(message[:1]) _, _ = w.Write(message[:1])
w.Close() w.Close()
// Send message larger than the limit. // Send message larger than the limit.
wc.WriteMessage(BinaryMessage, message[:readLimit+1]) _ = wc.WriteMessage(BinaryMessage, message[:readLimit+1])
op, _, err := rc.NextReader() op, _, err := rc.NextReader()
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
@ -592,7 +592,7 @@ func TestBufioReadBytes(t *testing.T) {
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(m) _, _ = w.Write(m)
w.Close() w.Close()
op, r, err := rc.NextReader() op, r, err := rc.NextReader()
@ -666,7 +666,7 @@ func TestConcurrentWritePanic(t *testing.T) {
w := blockingWriter{make(chan struct{}), make(chan struct{})} w := blockingWriter{make(chan struct{}), make(chan struct{})}
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
go func() { go func() {
c.WriteMessage(TextMessage, []byte{}) _ = c.WriteMessage(TextMessage, []byte{})
}() }()
// wait for goroutine to block in write. // wait for goroutine to block in write.
@ -679,7 +679,7 @@ func TestConcurrentWritePanic(t *testing.T) {
} }
}() }()
c.WriteMessage(TextMessage, []byte{}) _ = c.WriteMessage(TextMessage, []byte{})
t.Fatal("should not get here") t.Fatal("should not get here")
} }
@ -699,7 +699,7 @@ func TestFailedConnectionReadPanic(t *testing.T) {
}() }()
for i := 0; i < 20000; i++ { for i := 0; i < 20000; i++ {
c.ReadMessage() _, _, _ = c.ReadMessage()
} }
t.Fatal("should not get here") t.Fatal("should not get here")
} }

View File

@ -19,7 +19,7 @@ func TestJoinMessages(t *testing.T) {
wc := newTestConn(nil, &connBuf, true) wc := newTestConn(nil, &connBuf, true)
rc := newTestConn(&connBuf, nil, false) rc := newTestConn(&connBuf, nil, false)
for _, m := range messages { for _, m := range messages {
wc.WriteMessage(BinaryMessage, []byte(m)) _ = wc.WriteMessage(BinaryMessage, []byte(m))
} }
var result bytes.Buffer var result bytes.Buffer

View File

@ -40,7 +40,9 @@ func TestPreparedMessage(t *testing.T) {
if tt.enableWriteCompression { if tt.enableWriteCompression {
c.newCompressionWriter = compressNoContextTakeover c.newCompressionWriter = compressNoContextTakeover
} }
c.SetCompressionLevel(tt.compressionLevel) if err := c.SetCompressionLevel(tt.compressionLevel); err != nil {
t.Fatal(err)
}
// Seed random number generator for consistent frame mask. // Seed random number generator for consistent frame mask.
rand.Seed(1234) rand.Seed(1234)

View File

@ -178,8 +178,19 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
"websocket: hijack: "+err.Error()) "websocket: hijack: "+err.Error())
} }
// Close the network connection when returning an error. The variable
// netConn is set to nil before the success return at the end of the
// function.
defer func() {
if netConn != nil {
// It's safe to ignore the error from Close() because this code is
// only executed when returning a more important error to the
// application.
_ = netConn.Close()
}
}()
if brw.Reader.Buffered() > 0 { if brw.Reader.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete") return nil, errors.New("websocket: client sent data before handshake is complete")
} }
@ -243,20 +254,30 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 { if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
return nil, err
}
} else {
// Clear deadlines set by HTTP server.
if err := netConn.SetDeadline(time.Time{}); err != nil {
return nil, err
}
} }
if _, err = netConn.Write(p); err != nil { if _, err = netConn.Write(p); err != nil {
netConn.Close()
return nil, err return nil, err
} }
if u.HandshakeTimeout > 0 { if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{}) if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
return nil, err
}
} }
// Success! Set netConn to nil to stop the deferred function above from
// closing the network connection.
netConn = nil
return c, nil return c, nil
} }
@ -352,7 +373,7 @@ func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// bufio.Writer's underlying writer. // bufio.Writer's underlying writer.
var wh writeHook var wh writeHook
bw.Reset(&wh) bw.Reset(&wh)
bw.WriteByte(0) _ = bw.WriteByte(0)
bw.Flush() bw.Flush()
bw.Reset(originalWriter) bw.Reset(originalWriter)