mirror of https://github.com/gorilla/websocket.git
Handle errcheck warnings
The package ignored errors from net.Conn Set*Deadline in a few places. Update the package to return these errors to the caller. Ignore all other errors reported by errcheck. These errors are safe to ignore because - The function is making a best effort to cleanup while handling another error. - The function call is guaranteed to succeed. - The error is ignored in a test.
This commit is contained in:
parent
a62d9d2a84
commit
3810b2346f
18
client.go
18
client.go
|
@ -305,9 +305,15 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close the network connection when returning an error. The variable
|
||||||
|
// netConn is set to nil before the success return at the end of the
|
||||||
|
// function.
|
||||||
defer func() {
|
defer func() {
|
||||||
if netConn != nil {
|
if netConn != nil {
|
||||||
netConn.Close()
|
// It's safe to ignore the error from Close() because this code is
|
||||||
|
// only executed when returning a more important error to the
|
||||||
|
// application.
|
||||||
|
_ = netConn.Close()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -398,8 +404,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||||
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||||
|
|
||||||
netConn.SetDeadline(time.Time{})
|
if err := netConn.SetDeadline(time.Time{}); err != nil {
|
||||||
netConn = nil // to avoid close in defer.
|
return nil, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success! Set netConn to nil to stop the deferred function above from
|
||||||
|
// closing the network connection.
|
||||||
|
netConn = nil
|
||||||
|
|
||||||
return conn, resp, nil
|
return conn, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -578,7 +578,7 @@ func TestRespOnBadHandshake(t *testing.T) {
|
||||||
|
|
||||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(expectedStatus)
|
w.WriteHeader(expectedStatus)
|
||||||
io.WriteString(w, expectedBody)
|
_, _ = io.WriteString(w, expectedBody)
|
||||||
}))
|
}))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -828,7 +828,7 @@ func TestSocksProxyDial(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer c1.Close()
|
defer c1.Close()
|
||||||
|
|
||||||
c1.SetDeadline(time.Now().Add(30 * time.Second))
|
_ = c1.SetDeadline(time.Now().Add(30 * time.Second))
|
||||||
|
|
||||||
buf := make([]byte, 32)
|
buf := make([]byte, 32)
|
||||||
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
|
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
|
||||||
|
@ -867,10 +867,10 @@ func TestSocksProxyDial(t *testing.T) {
|
||||||
defer c2.Close()
|
defer c2.Close()
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
io.Copy(c1, c2)
|
_, _ = io.Copy(c1, c2)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
io.Copy(c2, c1)
|
_, _ = io.Copy(c2, c1)
|
||||||
<-done
|
<-done
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,11 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
|
||||||
"\x01\x00\x00\xff\xff"
|
"\x01\x00\x00\xff\xff"
|
||||||
|
|
||||||
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
||||||
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
|
mr := io.MultiReader(r, strings.NewReader(tail))
|
||||||
|
if err := fr.(flate.Resetter).Reset(mr, nil); err != nil {
|
||||||
|
// Reset never fails, but handle error in case that changes.
|
||||||
|
fr = flate.NewReader(mr)
|
||||||
|
}
|
||||||
return &flateReadWrapper{fr}
|
return &flateReadWrapper{fr}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ func TestTruncWriter(t *testing.T) {
|
||||||
if m > n {
|
if m > n {
|
||||||
m = n
|
m = n
|
||||||
}
|
}
|
||||||
w.Write(p[:m])
|
_, _ = w.Write(p[:m])
|
||||||
p = p[m:]
|
p = p[m:]
|
||||||
}
|
}
|
||||||
if b.String() != data[:len(data)-len(w.p)] {
|
if b.String() != data[:len(data)-len(w.p)] {
|
||||||
|
@ -46,7 +46,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
|
||||||
messages := textMessages(100)
|
messages := textMessages(100)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
c.WriteMessage(TextMessage, messages[i%len(messages)])
|
_ = c.WriteMessage(TextMessage, messages[i%len(messages)])
|
||||||
}
|
}
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
|
||||||
c.newCompressionWriter = compressNoContextTakeover
|
c.newCompressionWriter = compressNoContextTakeover
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
c.WriteMessage(TextMessage, messages[i%len(messages)])
|
_ = c.WriteMessage(TextMessage, messages[i%len(messages)])
|
||||||
}
|
}
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
}
|
}
|
||||||
|
|
36
conn.go
36
conn.go
|
@ -370,7 +370,9 @@ func (c *Conn) read(n int) ([]byte, error) {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
err = errUnexpectedEOF
|
err = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
c.br.Discard(len(p))
|
// Discard is guaranteed to succeed because the number of bytes to discard
|
||||||
|
// is less than or equal to the number of bytes buffered.
|
||||||
|
_, _ = c.br.Discard(len(p))
|
||||||
return p, err
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -385,7 +387,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn.SetWriteDeadline(deadline)
|
if err := c.conn.SetWriteDeadline(deadline); err != nil {
|
||||||
|
return c.writeFatal(err)
|
||||||
|
}
|
||||||
if len(buf1) == 0 {
|
if len(buf1) == 0 {
|
||||||
_, err = c.conn.Write(buf0)
|
_, err = c.conn.Write(buf0)
|
||||||
} else {
|
} else {
|
||||||
|
@ -395,7 +399,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
|
||||||
return c.writeFatal(err)
|
return c.writeFatal(err)
|
||||||
}
|
}
|
||||||
if frameType == CloseMessage {
|
if frameType == CloseMessage {
|
||||||
c.writeFatal(ErrCloseSent)
|
_ = c.writeFatal(ErrCloseSent)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -458,13 +462,14 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn.SetWriteDeadline(deadline)
|
if err := c.conn.SetWriteDeadline(deadline); err != nil {
|
||||||
_, err = c.conn.Write(buf)
|
return c.writeFatal(err)
|
||||||
if err != nil {
|
}
|
||||||
|
if _, err = c.conn.Write(buf); err != nil {
|
||||||
return c.writeFatal(err)
|
return c.writeFatal(err)
|
||||||
}
|
}
|
||||||
if messageType == CloseMessage {
|
if messageType == CloseMessage {
|
||||||
c.writeFatal(ErrCloseSent)
|
_ = c.writeFatal(ErrCloseSent)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -628,7 +633,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if final {
|
if final {
|
||||||
w.endMessage(errWriteClosed)
|
_ = w.endMessage(errWriteClosed)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -815,7 +820,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
rsv2 := p[0]&rsv2Bit != 0
|
rsv2 := p[0]&rsv2Bit != 0
|
||||||
rsv3 := p[0]&rsv3Bit != 0
|
rsv3 := p[0]&rsv3Bit != 0
|
||||||
mask := p[1]&maskBit != 0
|
mask := p[1]&maskBit != 0
|
||||||
c.setReadRemaining(int64(p[1] & 0x7f))
|
_ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0
|
||||||
|
|
||||||
c.readDecompress = false
|
c.readDecompress = false
|
||||||
if rsv1 {
|
if rsv1 {
|
||||||
|
@ -920,7 +925,8 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.readLimit > 0 && c.readLength > c.readLimit {
|
if c.readLimit > 0 && c.readLength > c.readLimit {
|
||||||
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
|
// Make a best effort to send a close message describing the problem.
|
||||||
|
_ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
|
||||||
return noFrame, ErrReadLimit
|
return noFrame, ErrReadLimit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -932,7 +938,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
var payload []byte
|
var payload []byte
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
payload, err = c.read(int(c.readRemaining))
|
payload, err = c.read(int(c.readRemaining))
|
||||||
c.setReadRemaining(0)
|
_ = c.setReadRemaining(0) // will not fail because argument is >= 0
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
|
@ -979,7 +985,8 @@ func (c *Conn) handleProtocolError(message string) error {
|
||||||
if len(data) > maxControlFramePayloadSize {
|
if len(data) > maxControlFramePayloadSize {
|
||||||
data = data[:maxControlFramePayloadSize]
|
data = data[:maxControlFramePayloadSize]
|
||||||
}
|
}
|
||||||
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
|
// Make a best effor to send a close message describing the problem.
|
||||||
|
_ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
|
||||||
return errors.New("websocket: " + message)
|
return errors.New("websocket: " + message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1052,7 +1059,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
}
|
}
|
||||||
rem := c.readRemaining
|
rem := c.readRemaining
|
||||||
rem -= int64(n)
|
rem -= int64(n)
|
||||||
c.setReadRemaining(rem)
|
_ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0
|
||||||
if c.readRemaining > 0 && c.readErr == io.EOF {
|
if c.readRemaining > 0 && c.readErr == io.EOF {
|
||||||
c.readErr = errUnexpectedEOF
|
c.readErr = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
@ -1134,7 +1141,8 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
|
||||||
if h == nil {
|
if h == nil {
|
||||||
h = func(code int, text string) error {
|
h = func(code int, text string) error {
|
||||||
message := FormatCloseMessage(code, "")
|
message := FormatCloseMessage(code, "")
|
||||||
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
|
// Make a best effor to send the close message.
|
||||||
|
_ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,9 +69,9 @@ func (b *broadcastBench) makeConns(numConns int) {
|
||||||
select {
|
select {
|
||||||
case msg := <-c.msgCh:
|
case msg := <-c.msgCh:
|
||||||
if msg.prepared != nil {
|
if msg.prepared != nil {
|
||||||
c.conn.WritePreparedMessage(msg.prepared)
|
_ = c.conn.WritePreparedMessage(msg.prepared)
|
||||||
} else {
|
} else {
|
||||||
c.conn.WriteMessage(TextMessage, msg.payload)
|
_ = c.conn.WriteMessage(TextMessage, msg.payload)
|
||||||
}
|
}
|
||||||
val := atomic.AddInt32(&b.count, 1)
|
val := atomic.AddInt32(&b.count, 1)
|
||||||
if val%int32(numConns) == 0 {
|
if val%int32(numConns) == 0 {
|
||||||
|
|
32
conn_test.go
32
conn_test.go
|
@ -157,7 +157,7 @@ func TestControl(t *testing.T) {
|
||||||
wc := newTestConn(nil, &connBuf, isServer)
|
wc := newTestConn(nil, &connBuf, isServer)
|
||||||
rc := newTestConn(&connBuf, nil, !isServer)
|
rc := newTestConn(&connBuf, nil, !isServer)
|
||||||
if isWriteControl {
|
if isWriteControl {
|
||||||
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
|
_ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
|
||||||
} else {
|
} else {
|
||||||
w, err := wc.NextWriter(PongMessage)
|
w, err := wc.NextWriter(PongMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -174,7 +174,7 @@ func TestControl(t *testing.T) {
|
||||||
}
|
}
|
||||||
var actualMessage string
|
var actualMessage string
|
||||||
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
|
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
|
||||||
rc.NextReader()
|
_, _, _ = rc.NextReader()
|
||||||
if actualMessage != message {
|
if actualMessage != message {
|
||||||
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
|
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
|
||||||
continue
|
continue
|
||||||
|
@ -358,8 +358,8 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||||
rc := newTestConn(&b1, &b2, true)
|
rc := newTestConn(&b1, &b2, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize+bufSize/2))
|
_, _ = w.Write(make([]byte, bufSize+bufSize/2))
|
||||||
wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
|
_ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
op, r, err := rc.NextReader()
|
op, r, err := rc.NextReader()
|
||||||
|
@ -385,7 +385,7 @@ func TestEOFWithinFrame(t *testing.T) {
|
||||||
rc := newTestConn(&b, nil, true)
|
rc := newTestConn(&b, nil, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize))
|
_, _ = w.Write(make([]byte, bufSize))
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
if n >= b.Len() {
|
if n >= b.Len() {
|
||||||
|
@ -419,7 +419,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
rc := newTestConn(&b1, &b2, true)
|
rc := newTestConn(&b1, &b2, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize+bufSize/2))
|
_, _ = w.Write(make([]byte, bufSize+bufSize/2))
|
||||||
|
|
||||||
op, r, err := rc.NextReader()
|
op, r, err := rc.NextReader()
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
|
@ -438,7 +438,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
io.WriteString(w, "hello")
|
_, _ = io.WriteString(w, "hello")
|
||||||
if err := w.Close(); err != nil {
|
if err := w.Close(); err != nil {
|
||||||
t.Fatalf("unxpected error closing message writer, %v", err)
|
t.Fatalf("unxpected error closing message writer, %v", err)
|
||||||
}
|
}
|
||||||
|
@ -448,7 +448,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
w, _ = wc.NextWriter(BinaryMessage)
|
w, _ = wc.NextWriter(BinaryMessage)
|
||||||
io.WriteString(w, "hello")
|
_, _ = io.WriteString(w, "hello")
|
||||||
|
|
||||||
// close w by getting next writer
|
// close w by getting next writer
|
||||||
_, err := wc.NextWriter(BinaryMessage)
|
_, err := wc.NextWriter(BinaryMessage)
|
||||||
|
@ -473,13 +473,13 @@ func TestReadLimit(t *testing.T) {
|
||||||
|
|
||||||
// Send message at the limit with interleaved pong.
|
// Send message at the limit with interleaved pong.
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(message[:readLimit-1])
|
_, _ = w.Write(message[:readLimit-1])
|
||||||
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
|
_ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
|
||||||
w.Write(message[:1])
|
_, _ = w.Write(message[:1])
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
// Send message larger than the limit.
|
// Send message larger than the limit.
|
||||||
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
|
_ = wc.WriteMessage(BinaryMessage, message[:readLimit+1])
|
||||||
|
|
||||||
op, _, err := rc.NextReader()
|
op, _, err := rc.NextReader()
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
|
@ -592,7 +592,7 @@ func TestBufioReadBytes(t *testing.T) {
|
||||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
|
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(m)
|
_, _ = w.Write(m)
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
op, r, err := rc.NextReader()
|
op, r, err := rc.NextReader()
|
||||||
|
@ -666,7 +666,7 @@ func TestConcurrentWritePanic(t *testing.T) {
|
||||||
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
||||||
c := newTestConn(nil, w, false)
|
c := newTestConn(nil, w, false)
|
||||||
go func() {
|
go func() {
|
||||||
c.WriteMessage(TextMessage, []byte{})
|
_ = c.WriteMessage(TextMessage, []byte{})
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// wait for goroutine to block in write.
|
// wait for goroutine to block in write.
|
||||||
|
@ -679,7 +679,7 @@ func TestConcurrentWritePanic(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.WriteMessage(TextMessage, []byte{})
|
_ = c.WriteMessage(TextMessage, []byte{})
|
||||||
t.Fatal("should not get here")
|
t.Fatal("should not get here")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -699,7 +699,7 @@ func TestFailedConnectionReadPanic(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for i := 0; i < 20000; i++ {
|
for i := 0; i < 20000; i++ {
|
||||||
c.ReadMessage()
|
_, _, _ = c.ReadMessage()
|
||||||
}
|
}
|
||||||
t.Fatal("should not get here")
|
t.Fatal("should not get here")
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,7 @@ func TestJoinMessages(t *testing.T) {
|
||||||
wc := newTestConn(nil, &connBuf, true)
|
wc := newTestConn(nil, &connBuf, true)
|
||||||
rc := newTestConn(&connBuf, nil, false)
|
rc := newTestConn(&connBuf, nil, false)
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
wc.WriteMessage(BinaryMessage, []byte(m))
|
_ = wc.WriteMessage(BinaryMessage, []byte(m))
|
||||||
}
|
}
|
||||||
|
|
||||||
var result bytes.Buffer
|
var result bytes.Buffer
|
||||||
|
|
|
@ -45,7 +45,9 @@ func TestPreparedMessage(t *testing.T) {
|
||||||
if tt.enableWriteCompression {
|
if tt.enableWriteCompression {
|
||||||
c.newCompressionWriter = compressNoContextTakeover
|
c.newCompressionWriter = compressNoContextTakeover
|
||||||
}
|
}
|
||||||
c.SetCompressionLevel(tt.compressionLevel)
|
if err := c.SetCompressionLevel(tt.compressionLevel); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
// Seed random number generator for consistent frame mask.
|
// Seed random number generator for consistent frame mask.
|
||||||
testRand.Seed(1234)
|
testRand.Seed(1234)
|
||||||
|
|
34
server.go
34
server.go
|
@ -178,6 +178,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
"websocket: hijack: "+err.Error())
|
"websocket: hijack: "+err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close the network connection when returning an error. The variable
|
||||||
|
// netConn is set to nil before the success return at the end of the
|
||||||
|
// function.
|
||||||
|
defer func() {
|
||||||
|
if netConn != nil {
|
||||||
|
// It's safe to ignore the error from Close() because this code is
|
||||||
|
// only executed when returning a more important error to the
|
||||||
|
// application.
|
||||||
|
_ = netConn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
var br *bufio.Reader
|
var br *bufio.Reader
|
||||||
if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 {
|
if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 {
|
||||||
// Use hijacked buffered reader as the connection reader.
|
// Use hijacked buffered reader as the connection reader.
|
||||||
|
@ -244,20 +256,30 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
}
|
}
|
||||||
p = append(p, "\r\n"...)
|
p = append(p, "\r\n"...)
|
||||||
|
|
||||||
// Clear deadlines set by HTTP server.
|
|
||||||
netConn.SetDeadline(time.Time{})
|
|
||||||
|
|
||||||
if u.HandshakeTimeout > 0 {
|
if u.HandshakeTimeout > 0 {
|
||||||
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
|
if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Clear deadlines set by HTTP server.
|
||||||
|
if err := netConn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = netConn.Write(p); err != nil {
|
if _, err = netConn.Write(p); err != nil {
|
||||||
netConn.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if u.HandshakeTimeout > 0 {
|
if u.HandshakeTimeout > 0 {
|
||||||
netConn.SetWriteDeadline(time.Time{})
|
if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Success! Set netConn to nil to stop the deferred function above from
|
||||||
|
// closing the network connection.
|
||||||
|
netConn = nil
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue