mirror of https://github.com/gorilla/websocket.git
Compare commits
3 Commits
5b1a3e6315
...
371a9fdbb8
Author | SHA1 | Date |
---|---|---|
![]() |
371a9fdbb8 | |
![]() |
1d5465562b | |
![]() |
d40797b837 |
62
client.go
62
client.go
|
@ -52,7 +52,7 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
|
||||||
// It is safe to call Dialer's methods concurrently.
|
// It is safe to call Dialer's methods concurrently.
|
||||||
type Dialer struct {
|
type Dialer struct {
|
||||||
// NetDial specifies the dial function for creating TCP connections. If
|
// NetDial specifies the dial function for creating TCP connections. If
|
||||||
// NetDial is nil, net.Dial is used.
|
// NetDial is nil, net.Dialer DialContext is used.
|
||||||
NetDial func(network, addr string) (net.Conn, error)
|
NetDial func(network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
// NetDialContext specifies the dial function for creating TCP connections. If
|
// NetDialContext specifies the dial function for creating TCP connections. If
|
||||||
|
@ -244,46 +244,25 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get network dial function.
|
var netDial netDialerFunc
|
||||||
var netDial func(network, add string) (net.Conn, error)
|
switch {
|
||||||
|
case u.Scheme == "https" && d.NetDialTLSContext != nil:
|
||||||
switch u.Scheme {
|
netDial = d.NetDialTLSContext
|
||||||
case "http":
|
case d.NetDialContext != nil:
|
||||||
if d.NetDialContext != nil {
|
netDial = d.NetDialContext
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
case d.NetDial != nil:
|
||||||
return d.NetDialContext(ctx, network, addr)
|
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
|
||||||
}
|
return d.NetDial(net, addr)
|
||||||
} else if d.NetDial != nil {
|
|
||||||
netDial = d.NetDial
|
|
||||||
}
|
|
||||||
case "https":
|
|
||||||
if d.NetDialTLSContext != nil {
|
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
|
||||||
return d.NetDialTLSContext(ctx, network, addr)
|
|
||||||
}
|
|
||||||
} else if d.NetDialContext != nil {
|
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
|
||||||
return d.NetDialContext(ctx, network, addr)
|
|
||||||
}
|
|
||||||
} else if d.NetDial != nil {
|
|
||||||
netDial = d.NetDial
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, nil, errMalformedURL
|
netDial = (&net.Dialer{}).DialContext
|
||||||
}
|
|
||||||
|
|
||||||
if netDial == nil {
|
|
||||||
netDialer := &net.Dialer{}
|
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
|
||||||
return netDialer.DialContext(ctx, network, addr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If needed, wrap the dial function to set the connection deadline.
|
// If needed, wrap the dial function to set the connection deadline.
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
forwardDial := netDial
|
forwardDial := netDial
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
c, err := forwardDial(network, addr)
|
c, err := forwardDial(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -303,11 +282,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if proxyURL != nil {
|
if proxyURL != nil {
|
||||||
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
|
netDial, err = proxyFromURL(proxyURL, netDial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
netDial = dialer.Dial
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -317,7 +295,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
trace.GetConn(hostPort)
|
trace.GetConn(hostPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
netConn, err := netDial("tcp", hostPort)
|
netConn, err := netDial(ctx, "tcp", hostPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -420,8 +398,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||||
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||||
|
|
||||||
netConn.SetDeadline(time.Time{})
|
if err := netConn.SetDeadline(time.Time{}); err != nil {
|
||||||
netConn = nil // to avoid close in defer.
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set netConn to nil to avoid call to netConn.Close() in
|
||||||
|
// deferred function call.
|
||||||
|
netConn = nil
|
||||||
|
|
||||||
return conn, resp, nil
|
return conn, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -546,7 +546,7 @@ func TestRespOnBadHandshake(t *testing.T) {
|
||||||
|
|
||||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(expectedStatus)
|
w.WriteHeader(expectedStatus)
|
||||||
io.WriteString(w, expectedBody)
|
_, _ = io.WriteString(w, expectedBody)
|
||||||
}))
|
}))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -796,7 +796,7 @@ func TestSocksProxyDial(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer c1.Close()
|
defer c1.Close()
|
||||||
|
|
||||||
c1.SetDeadline(time.Now().Add(30 * time.Second))
|
_ = c1.SetDeadline(time.Now().Add(30 * time.Second))
|
||||||
|
|
||||||
buf := make([]byte, 32)
|
buf := make([]byte, 32)
|
||||||
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
|
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
|
||||||
|
@ -835,10 +835,10 @@ func TestSocksProxyDial(t *testing.T) {
|
||||||
defer c2.Close()
|
defer c2.Close()
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
io.Copy(c1, c2)
|
_, _ = io.Copy(c1, c2)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
io.Copy(c2, c1)
|
_, _ = io.Copy(c2, c1)
|
||||||
<-done
|
<-done
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,11 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
|
||||||
"\x01\x00\x00\xff\xff"
|
"\x01\x00\x00\xff\xff"
|
||||||
|
|
||||||
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
||||||
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
|
mr := io.MultiReader(r, strings.NewReader(tail))
|
||||||
|
if err := fr.(flate.Resetter).Reset(mr, nil); err != nil {
|
||||||
|
// Reset never fails, but handle error in case that changes.
|
||||||
|
fr = flate.NewReader(mr)
|
||||||
|
}
|
||||||
return &flateReadWrapper{fr}
|
return &flateReadWrapper{fr}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ func TestTruncWriter(t *testing.T) {
|
||||||
if m > n {
|
if m > n {
|
||||||
m = n
|
m = n
|
||||||
}
|
}
|
||||||
w.Write(p[:m])
|
_, _ = w.Write(p[:m])
|
||||||
p = p[m:]
|
p = p[m:]
|
||||||
}
|
}
|
||||||
if b.String() != data[:len(data)-len(w.p)] {
|
if b.String() != data[:len(data)-len(w.p)] {
|
||||||
|
@ -46,7 +46,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
|
||||||
messages := textMessages(100)
|
messages := textMessages(100)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
c.WriteMessage(TextMessage, messages[i%len(messages)])
|
_ = c.WriteMessage(TextMessage, messages[i%len(messages)])
|
||||||
}
|
}
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
|
||||||
c.newCompressionWriter = compressNoContextTakeover
|
c.newCompressionWriter = compressNoContextTakeover
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
c.WriteMessage(TextMessage, messages[i%len(messages)])
|
_ = c.WriteMessage(TextMessage, messages[i%len(messages)])
|
||||||
}
|
}
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
}
|
}
|
||||||
|
|
34
conn.go
34
conn.go
|
@ -371,7 +371,7 @@ func (c *Conn) read(n int) ([]byte, error) {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
err = errUnexpectedEOF
|
err = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
c.br.Discard(len(p))
|
_, _ = c.br.Discard(len(p)) // guaranteed to succeed
|
||||||
return p, err
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -386,7 +386,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn.SetWriteDeadline(deadline)
|
if err := c.conn.SetWriteDeadline(deadline); err != nil {
|
||||||
|
return c.writeFatal(err)
|
||||||
|
}
|
||||||
if len(buf1) == 0 {
|
if len(buf1) == 0 {
|
||||||
_, err = c.conn.Write(buf0)
|
_, err = c.conn.Write(buf0)
|
||||||
} else {
|
} else {
|
||||||
|
@ -396,7 +398,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
|
||||||
return c.writeFatal(err)
|
return c.writeFatal(err)
|
||||||
}
|
}
|
||||||
if frameType == CloseMessage {
|
if frameType == CloseMessage {
|
||||||
c.writeFatal(ErrCloseSent)
|
_ = c.writeFatal(ErrCloseSent)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -459,13 +461,14 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn.SetWriteDeadline(deadline)
|
if err := c.conn.SetWriteDeadline(deadline); err != nil {
|
||||||
_, err = c.conn.Write(buf)
|
return c.writeFatal(err)
|
||||||
if err != nil {
|
}
|
||||||
|
if _, err = c.conn.Write(buf); err != nil {
|
||||||
return c.writeFatal(err)
|
return c.writeFatal(err)
|
||||||
}
|
}
|
||||||
if messageType == CloseMessage {
|
if messageType == CloseMessage {
|
||||||
c.writeFatal(ErrCloseSent)
|
_ = c.writeFatal(ErrCloseSent)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -629,7 +632,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if final {
|
if final {
|
||||||
w.endMessage(errWriteClosed)
|
_ = w.endMessage(errWriteClosed)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -816,7 +819,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
rsv2 := p[0]&rsv2Bit != 0
|
rsv2 := p[0]&rsv2Bit != 0
|
||||||
rsv3 := p[0]&rsv3Bit != 0
|
rsv3 := p[0]&rsv3Bit != 0
|
||||||
mask := p[1]&maskBit != 0
|
mask := p[1]&maskBit != 0
|
||||||
c.setReadRemaining(int64(p[1] & 0x7f))
|
_ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not overflow
|
||||||
|
|
||||||
c.readDecompress = false
|
c.readDecompress = false
|
||||||
if rsv1 {
|
if rsv1 {
|
||||||
|
@ -921,7 +924,8 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.readLimit > 0 && c.readLength > c.readLimit {
|
if c.readLimit > 0 && c.readLength > c.readLimit {
|
||||||
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
|
// Make a best effort to send a close message describing the problem.
|
||||||
|
_ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
|
||||||
return noFrame, ErrReadLimit
|
return noFrame, ErrReadLimit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -933,7 +937,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
var payload []byte
|
var payload []byte
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
payload, err = c.read(int(c.readRemaining))
|
payload, err = c.read(int(c.readRemaining))
|
||||||
c.setReadRemaining(0)
|
_ = c.setReadRemaining(0) // will not overflow
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
|
@ -980,7 +984,8 @@ func (c *Conn) handleProtocolError(message string) error {
|
||||||
if len(data) > maxControlFramePayloadSize {
|
if len(data) > maxControlFramePayloadSize {
|
||||||
data = data[:maxControlFramePayloadSize]
|
data = data[:maxControlFramePayloadSize]
|
||||||
}
|
}
|
||||||
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
|
// Make a best effor to send a close message describing the problem.
|
||||||
|
_ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
|
||||||
return errors.New("websocket: " + message)
|
return errors.New("websocket: " + message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1053,7 +1058,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
}
|
}
|
||||||
rem := c.readRemaining
|
rem := c.readRemaining
|
||||||
rem -= int64(n)
|
rem -= int64(n)
|
||||||
c.setReadRemaining(rem)
|
_ = c.setReadRemaining(rem) // will not overflow
|
||||||
if c.readRemaining > 0 && c.readErr == io.EOF {
|
if c.readRemaining > 0 && c.readErr == io.EOF {
|
||||||
c.readErr = errUnexpectedEOF
|
c.readErr = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
@ -1135,7 +1140,8 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
|
||||||
if h == nil {
|
if h == nil {
|
||||||
h = func(code int, text string) error {
|
h = func(code int, text string) error {
|
||||||
message := FormatCloseMessage(code, "")
|
message := FormatCloseMessage(code, "")
|
||||||
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
|
// Make a best effor to send the close message.
|
||||||
|
_ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,9 +69,9 @@ func (b *broadcastBench) makeConns(numConns int) {
|
||||||
select {
|
select {
|
||||||
case msg := <-c.msgCh:
|
case msg := <-c.msgCh:
|
||||||
if msg.prepared != nil {
|
if msg.prepared != nil {
|
||||||
c.conn.WritePreparedMessage(msg.prepared)
|
_ = c.conn.WritePreparedMessage(msg.prepared)
|
||||||
} else {
|
} else {
|
||||||
c.conn.WriteMessage(TextMessage, msg.payload)
|
_ = c.conn.WriteMessage(TextMessage, msg.payload)
|
||||||
}
|
}
|
||||||
val := atomic.AddInt32(&b.count, 1)
|
val := atomic.AddInt32(&b.count, 1)
|
||||||
if val%int32(numConns) == 0 {
|
if val%int32(numConns) == 0 {
|
||||||
|
|
32
conn_test.go
32
conn_test.go
|
@ -157,7 +157,7 @@ func TestControl(t *testing.T) {
|
||||||
wc := newTestConn(nil, &connBuf, isServer)
|
wc := newTestConn(nil, &connBuf, isServer)
|
||||||
rc := newTestConn(&connBuf, nil, !isServer)
|
rc := newTestConn(&connBuf, nil, !isServer)
|
||||||
if isWriteControl {
|
if isWriteControl {
|
||||||
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
|
_ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
|
||||||
} else {
|
} else {
|
||||||
w, err := wc.NextWriter(PongMessage)
|
w, err := wc.NextWriter(PongMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -174,7 +174,7 @@ func TestControl(t *testing.T) {
|
||||||
}
|
}
|
||||||
var actualMessage string
|
var actualMessage string
|
||||||
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
|
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
|
||||||
rc.NextReader()
|
_, _, _ = rc.NextReader()
|
||||||
if actualMessage != message {
|
if actualMessage != message {
|
||||||
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
|
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
|
||||||
continue
|
continue
|
||||||
|
@ -358,8 +358,8 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||||
rc := newTestConn(&b1, &b2, true)
|
rc := newTestConn(&b1, &b2, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize+bufSize/2))
|
_, _ = w.Write(make([]byte, bufSize+bufSize/2))
|
||||||
wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
|
_ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
op, r, err := rc.NextReader()
|
op, r, err := rc.NextReader()
|
||||||
|
@ -385,7 +385,7 @@ func TestEOFWithinFrame(t *testing.T) {
|
||||||
rc := newTestConn(&b, nil, true)
|
rc := newTestConn(&b, nil, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize))
|
_, _ = w.Write(make([]byte, bufSize))
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
if n >= b.Len() {
|
if n >= b.Len() {
|
||||||
|
@ -419,7 +419,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
rc := newTestConn(&b1, &b2, true)
|
rc := newTestConn(&b1, &b2, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize+bufSize/2))
|
_, _ = w.Write(make([]byte, bufSize+bufSize/2))
|
||||||
|
|
||||||
op, r, err := rc.NextReader()
|
op, r, err := rc.NextReader()
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
|
@ -438,7 +438,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
io.WriteString(w, "hello")
|
_, _ = io.WriteString(w, "hello")
|
||||||
if err := w.Close(); err != nil {
|
if err := w.Close(); err != nil {
|
||||||
t.Fatalf("unxpected error closing message writer, %v", err)
|
t.Fatalf("unxpected error closing message writer, %v", err)
|
||||||
}
|
}
|
||||||
|
@ -448,7 +448,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
w, _ = wc.NextWriter(BinaryMessage)
|
w, _ = wc.NextWriter(BinaryMessage)
|
||||||
io.WriteString(w, "hello")
|
_, _ = io.WriteString(w, "hello")
|
||||||
|
|
||||||
// close w by getting next writer
|
// close w by getting next writer
|
||||||
_, err := wc.NextWriter(BinaryMessage)
|
_, err := wc.NextWriter(BinaryMessage)
|
||||||
|
@ -473,13 +473,13 @@ func TestReadLimit(t *testing.T) {
|
||||||
|
|
||||||
// Send message at the limit with interleaved pong.
|
// Send message at the limit with interleaved pong.
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(message[:readLimit-1])
|
_, _ = w.Write(message[:readLimit-1])
|
||||||
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
|
_ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
|
||||||
w.Write(message[:1])
|
_, _ = w.Write(message[:1])
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
// Send message larger than the limit.
|
// Send message larger than the limit.
|
||||||
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
|
_ = wc.WriteMessage(BinaryMessage, message[:readLimit+1])
|
||||||
|
|
||||||
op, _, err := rc.NextReader()
|
op, _, err := rc.NextReader()
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
|
@ -592,7 +592,7 @@ func TestBufioReadBytes(t *testing.T) {
|
||||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
|
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(m)
|
_, _ = w.Write(m)
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
op, r, err := rc.NextReader()
|
op, r, err := rc.NextReader()
|
||||||
|
@ -666,7 +666,7 @@ func TestConcurrentWritePanic(t *testing.T) {
|
||||||
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
||||||
c := newTestConn(nil, w, false)
|
c := newTestConn(nil, w, false)
|
||||||
go func() {
|
go func() {
|
||||||
c.WriteMessage(TextMessage, []byte{})
|
_ = c.WriteMessage(TextMessage, []byte{})
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// wait for goroutine to block in write.
|
// wait for goroutine to block in write.
|
||||||
|
@ -679,7 +679,7 @@ func TestConcurrentWritePanic(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.WriteMessage(TextMessage, []byte{})
|
_ = c.WriteMessage(TextMessage, []byte{})
|
||||||
t.Fatal("should not get here")
|
t.Fatal("should not get here")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -699,7 +699,7 @@ func TestFailedConnectionReadPanic(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for i := 0; i < 20000; i++ {
|
for i := 0; i < 20000; i++ {
|
||||||
c.ReadMessage()
|
_, _, _ = c.ReadMessage()
|
||||||
}
|
}
|
||||||
t.Fatal("should not get here")
|
t.Fatal("should not get here")
|
||||||
}
|
}
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -5,3 +5,5 @@ go 1.20
|
||||||
retract (
|
retract (
|
||||||
v1.5.2 // tag accidentally overwritten
|
v1.5.2 // tag accidentally overwritten
|
||||||
)
|
)
|
||||||
|
|
||||||
|
require golang.org/x/net v0.26.0
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -0,0 +1,2 @@
|
||||||
|
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||||
|
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
|
@ -19,7 +19,7 @@ func TestJoinMessages(t *testing.T) {
|
||||||
wc := newTestConn(nil, &connBuf, true)
|
wc := newTestConn(nil, &connBuf, true)
|
||||||
rc := newTestConn(&connBuf, nil, false)
|
rc := newTestConn(&connBuf, nil, false)
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
wc.WriteMessage(BinaryMessage, []byte(m))
|
_ = wc.WriteMessage(BinaryMessage, []byte(m))
|
||||||
}
|
}
|
||||||
|
|
||||||
var result bytes.Buffer
|
var result bytes.Buffer
|
||||||
|
|
|
@ -40,7 +40,9 @@ func TestPreparedMessage(t *testing.T) {
|
||||||
if tt.enableWriteCompression {
|
if tt.enableWriteCompression {
|
||||||
c.newCompressionWriter = compressNoContextTakeover
|
c.newCompressionWriter = compressNoContextTakeover
|
||||||
}
|
}
|
||||||
c.SetCompressionLevel(tt.compressionLevel)
|
if err := c.SetCompressionLevel(tt.compressionLevel); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
// Seed random number generator for consistent frame mask.
|
// Seed random number generator for consistent frame mask.
|
||||||
rand.Seed(1234)
|
rand.Seed(1234)
|
||||||
|
|
35
proxy.go
35
proxy.go
|
@ -7,34 +7,51 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type netDialerFunc func(network, addr string) (net.Conn, error)
|
type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
||||||
return fn(network, addr)
|
return fn(context.Background(), network, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
return fn(ctx, network, addr)
|
||||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
}
|
||||||
})
|
|
||||||
|
func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) {
|
||||||
|
if proxyURL.Scheme == "http" {
|
||||||
|
return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil
|
||||||
|
}
|
||||||
|
dialer, err := proxy.FromURL(proxyURL, forwardDial)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if d, ok := dialer.(proxy.ContextDialer); ok {
|
||||||
|
return d.DialContext, nil
|
||||||
|
}
|
||||||
|
return func(ctx context.Context, net, addr string) (net.Conn, error) {
|
||||||
|
return dialer.Dial(net, addr)
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpProxyDialer struct {
|
type httpProxyDialer struct {
|
||||||
proxyURL *url.URL
|
proxyURL *url.URL
|
||||||
forwardDial func(network, addr string) (net.Conn, error)
|
forwardDial netDialerFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
|
func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||||
hostPort, _ := hostPortNoPort(hpd.proxyURL)
|
hostPort, _ := hostPortNoPort(hpd.proxyURL)
|
||||||
conn, err := hpd.forwardDial(network, hostPort)
|
conn, err := hpd.forwardDial(ctx, network, hostPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
34
server.go
34
server.go
|
@ -178,8 +178,16 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
"websocket: hijack: "+err.Error())
|
"websocket: hijack: "+err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if netConn != nil {
|
||||||
|
// It's safe to ignore the error from Close() because this code is
|
||||||
|
// only executed when returning a more important to the
|
||||||
|
// application.
|
||||||
|
_ = netConn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if brw.Reader.Buffered() > 0 {
|
if brw.Reader.Buffered() > 0 {
|
||||||
netConn.Close()
|
|
||||||
return nil, errors.New("websocket: client sent data before handshake is complete")
|
return nil, errors.New("websocket: client sent data before handshake is complete")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -243,20 +251,30 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
}
|
}
|
||||||
p = append(p, "\r\n"...)
|
p = append(p, "\r\n"...)
|
||||||
|
|
||||||
// Clear deadlines set by HTTP server.
|
|
||||||
netConn.SetDeadline(time.Time{})
|
|
||||||
|
|
||||||
if u.HandshakeTimeout > 0 {
|
if u.HandshakeTimeout > 0 {
|
||||||
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
|
if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Clear deadlines set by HTTP server.
|
||||||
|
if err := netConn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = netConn.Write(p); err != nil {
|
if _, err = netConn.Write(p); err != nil {
|
||||||
netConn.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if u.HandshakeTimeout > 0 {
|
if u.HandshakeTimeout > 0 {
|
||||||
netConn.SetWriteDeadline(time.Time{})
|
if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set netConn to nil to avoid call to netConn.Close() in
|
||||||
|
// deferred function call.
|
||||||
|
netConn = nil
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -352,7 +370,7 @@ func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
|
||||||
// bufio.Writer's underlying writer.
|
// bufio.Writer's underlying writer.
|
||||||
var wh writeHook
|
var wh writeHook
|
||||||
bw.Reset(&wh)
|
bw.Reset(&wh)
|
||||||
bw.WriteByte(0)
|
_ = bw.WriteByte(0)
|
||||||
bw.Flush()
|
bw.Flush()
|
||||||
|
|
||||||
bw.Reset(originalWriter)
|
bw.Reset(originalWriter)
|
||||||
|
|
473
x_net_proxy.go
473
x_net_proxy.go
|
@ -1,473 +0,0 @@
|
||||||
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
|
|
||||||
//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy
|
|
||||||
|
|
||||||
// Package proxy provides support for a variety of protocols to proxy network
|
|
||||||
// data.
|
|
||||||
//
|
|
||||||
|
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type proxy_direct struct{}
|
|
||||||
|
|
||||||
// Direct is a direct proxy: one that makes network connections directly.
|
|
||||||
var proxy_Direct = proxy_direct{}
|
|
||||||
|
|
||||||
func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
|
|
||||||
return net.Dial(network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// A PerHost directs connections to a default Dialer unless the host name
|
|
||||||
// requested matches one of a number of exceptions.
|
|
||||||
type proxy_PerHost struct {
|
|
||||||
def, bypass proxy_Dialer
|
|
||||||
|
|
||||||
bypassNetworks []*net.IPNet
|
|
||||||
bypassIPs []net.IP
|
|
||||||
bypassZones []string
|
|
||||||
bypassHosts []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPerHost returns a PerHost Dialer that directs connections to either
|
|
||||||
// defaultDialer or bypass, depending on whether the connection matches one of
|
|
||||||
// the configured rules.
|
|
||||||
func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
|
|
||||||
return &proxy_PerHost{
|
|
||||||
def: defaultDialer,
|
|
||||||
bypass: bypass,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to the address addr on the given network through either
|
|
||||||
// defaultDialer or bypass.
|
|
||||||
func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
|
|
||||||
host, _, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.dialerForRequest(host).Dial(network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
for _, net := range p.bypassNetworks {
|
|
||||||
if net.Contains(ip) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, bypassIP := range p.bypassIPs {
|
|
||||||
if bypassIP.Equal(ip) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.def
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, zone := range p.bypassZones {
|
|
||||||
if strings.HasSuffix(host, zone) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
if host == zone[1:] {
|
|
||||||
// For a zone ".example.com", we match "example.com"
|
|
||||||
// too.
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, bypassHost := range p.bypassHosts {
|
|
||||||
if bypassHost == host {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.def
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFromString parses a string that contains comma-separated values
|
|
||||||
// specifying hosts that should use the bypass proxy. Each value is either an
|
|
||||||
// IP address, a CIDR range, a zone (*.example.com) or a host name
|
|
||||||
// (localhost). A best effort is made to parse the string and errors are
|
|
||||||
// ignored.
|
|
||||||
func (p *proxy_PerHost) AddFromString(s string) {
|
|
||||||
hosts := strings.Split(s, ",")
|
|
||||||
for _, host := range hosts {
|
|
||||||
host = strings.TrimSpace(host)
|
|
||||||
if len(host) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.Contains(host, "/") {
|
|
||||||
// We assume that it's a CIDR address like 127.0.0.0/8
|
|
||||||
if _, net, err := net.ParseCIDR(host); err == nil {
|
|
||||||
p.AddNetwork(net)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
p.AddIP(ip)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(host, "*.") {
|
|
||||||
p.AddZone(host[1:])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.AddHost(host)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddIP specifies an IP address that will use the bypass proxy. Note that
|
|
||||||
// this will only take effect if a literal IP address is dialed. A connection
|
|
||||||
// to a named host will never match an IP.
|
|
||||||
func (p *proxy_PerHost) AddIP(ip net.IP) {
|
|
||||||
p.bypassIPs = append(p.bypassIPs, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
|
|
||||||
// this will only take effect if a literal IP address is dialed. A connection
|
|
||||||
// to a named host will never match.
|
|
||||||
func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
|
|
||||||
p.bypassNetworks = append(p.bypassNetworks, net)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
|
|
||||||
// "example.com" matches "example.com" and all of its subdomains.
|
|
||||||
func (p *proxy_PerHost) AddZone(zone string) {
|
|
||||||
if strings.HasSuffix(zone, ".") {
|
|
||||||
zone = zone[:len(zone)-1]
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(zone, ".") {
|
|
||||||
zone = "." + zone
|
|
||||||
}
|
|
||||||
p.bypassZones = append(p.bypassZones, zone)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddHost specifies a host name that will use the bypass proxy.
|
|
||||||
func (p *proxy_PerHost) AddHost(host string) {
|
|
||||||
if strings.HasSuffix(host, ".") {
|
|
||||||
host = host[:len(host)-1]
|
|
||||||
}
|
|
||||||
p.bypassHosts = append(p.bypassHosts, host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// A Dialer is a means to establish a connection.
|
|
||||||
type proxy_Dialer interface {
|
|
||||||
// Dial connects to the given address via the proxy.
|
|
||||||
Dial(network, addr string) (c net.Conn, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auth contains authentication parameters that specific Dialers may require.
|
|
||||||
type proxy_Auth struct {
|
|
||||||
User, Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromEnvironment returns the dialer specified by the proxy related variables in
|
|
||||||
// the environment.
|
|
||||||
func proxy_FromEnvironment() proxy_Dialer {
|
|
||||||
allProxy := proxy_allProxyEnv.Get()
|
|
||||||
if len(allProxy) == 0 {
|
|
||||||
return proxy_Direct
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyURL, err := url.Parse(allProxy)
|
|
||||||
if err != nil {
|
|
||||||
return proxy_Direct
|
|
||||||
}
|
|
||||||
proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
|
|
||||||
if err != nil {
|
|
||||||
return proxy_Direct
|
|
||||||
}
|
|
||||||
|
|
||||||
noProxy := proxy_noProxyEnv.Get()
|
|
||||||
if len(noProxy) == 0 {
|
|
||||||
return proxy
|
|
||||||
}
|
|
||||||
|
|
||||||
perHost := proxy_NewPerHost(proxy, proxy_Direct)
|
|
||||||
perHost.AddFromString(noProxy)
|
|
||||||
return perHost
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxySchemes is a map from URL schemes to a function that creates a Dialer
|
|
||||||
// from a URL with such a scheme.
|
|
||||||
var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)
|
|
||||||
|
|
||||||
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
|
|
||||||
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
|
|
||||||
// by FromURL.
|
|
||||||
func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
|
|
||||||
if proxy_proxySchemes == nil {
|
|
||||||
proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
|
|
||||||
}
|
|
||||||
proxy_proxySchemes[scheme] = f
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromURL returns a Dialer given a URL specification and an underlying
|
|
||||||
// Dialer for it to make network requests.
|
|
||||||
func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
|
|
||||||
var auth *proxy_Auth
|
|
||||||
if u.User != nil {
|
|
||||||
auth = new(proxy_Auth)
|
|
||||||
auth.User = u.User.Username()
|
|
||||||
if p, ok := u.User.Password(); ok {
|
|
||||||
auth.Password = p
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch u.Scheme {
|
|
||||||
case "socks5":
|
|
||||||
return proxy_SOCKS5("tcp", u.Host, auth, forward)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the scheme doesn't match any of the built-in schemes, see if it
|
|
||||||
// was registered by another package.
|
|
||||||
if proxy_proxySchemes != nil {
|
|
||||||
if f, ok := proxy_proxySchemes[u.Scheme]; ok {
|
|
||||||
return f(u, forward)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
proxy_allProxyEnv = &proxy_envOnce{
|
|
||||||
names: []string{"ALL_PROXY", "all_proxy"},
|
|
||||||
}
|
|
||||||
proxy_noProxyEnv = &proxy_envOnce{
|
|
||||||
names: []string{"NO_PROXY", "no_proxy"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// envOnce looks up an environment variable (optionally by multiple
|
|
||||||
// names) once. It mitigates expensive lookups on some platforms
|
|
||||||
// (e.g. Windows).
|
|
||||||
// (Borrowed from net/http/transport.go)
|
|
||||||
type proxy_envOnce struct {
|
|
||||||
names []string
|
|
||||||
once sync.Once
|
|
||||||
val string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *proxy_envOnce) Get() string {
|
|
||||||
e.once.Do(e.init)
|
|
||||||
return e.val
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *proxy_envOnce) init() {
|
|
||||||
for _, n := range e.names {
|
|
||||||
e.val = os.Getenv(n)
|
|
||||||
if e.val != "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
|
|
||||||
// with an optional username and password. See RFC 1928 and RFC 1929.
|
|
||||||
func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
|
|
||||||
s := &proxy_socks5{
|
|
||||||
network: network,
|
|
||||||
addr: addr,
|
|
||||||
forward: forward,
|
|
||||||
}
|
|
||||||
if auth != nil {
|
|
||||||
s.user = auth.User
|
|
||||||
s.password = auth.Password
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type proxy_socks5 struct {
|
|
||||||
user, password string
|
|
||||||
network, addr string
|
|
||||||
forward proxy_Dialer
|
|
||||||
}
|
|
||||||
|
|
||||||
const proxy_socks5Version = 5
|
|
||||||
|
|
||||||
const (
|
|
||||||
proxy_socks5AuthNone = 0
|
|
||||||
proxy_socks5AuthPassword = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
const proxy_socks5Connect = 1
|
|
||||||
|
|
||||||
const (
|
|
||||||
proxy_socks5IP4 = 1
|
|
||||||
proxy_socks5Domain = 3
|
|
||||||
proxy_socks5IP6 = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
var proxy_socks5Errors = []string{
|
|
||||||
"",
|
|
||||||
"general failure",
|
|
||||||
"connection forbidden",
|
|
||||||
"network unreachable",
|
|
||||||
"host unreachable",
|
|
||||||
"connection refused",
|
|
||||||
"TTL expired",
|
|
||||||
"command not supported",
|
|
||||||
"address type not supported",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to the address addr on the given network via the SOCKS5 proxy.
|
|
||||||
func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
|
|
||||||
switch network {
|
|
||||||
case "tcp", "tcp6", "tcp4":
|
|
||||||
default:
|
|
||||||
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := s.forward.Dial(s.network, s.addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := s.connect(conn, addr); err != nil {
|
|
||||||
conn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect takes an existing connection to a socks5 proxy server,
|
|
||||||
// and commands the server to extend that connection to target,
|
|
||||||
// which must be a canonical address with a host and port.
|
|
||||||
func (s *proxy_socks5) connect(conn net.Conn, target string) error {
|
|
||||||
host, portStr, err := net.SplitHostPort(target)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.New("proxy: failed to parse port number: " + portStr)
|
|
||||||
}
|
|
||||||
if port < 1 || port > 0xffff {
|
|
||||||
return errors.New("proxy: port number out of range: " + portStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// the size here is just an estimate
|
|
||||||
buf := make([]byte, 0, 6+len(host))
|
|
||||||
|
|
||||||
buf = append(buf, proxy_socks5Version)
|
|
||||||
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
|
|
||||||
buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
|
|
||||||
} else {
|
|
||||||
buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := conn.Write(buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
if buf[0] != 5 {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
|
|
||||||
}
|
|
||||||
if buf[1] == 0xff {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 1929
|
|
||||||
if buf[1] == proxy_socks5AuthPassword {
|
|
||||||
buf = buf[:0]
|
|
||||||
buf = append(buf, 1 /* password protocol version */)
|
|
||||||
buf = append(buf, uint8(len(s.user)))
|
|
||||||
buf = append(buf, s.user...)
|
|
||||||
buf = append(buf, uint8(len(s.password)))
|
|
||||||
buf = append(buf, s.password...)
|
|
||||||
|
|
||||||
if _, err := conn.Write(buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if buf[1] != 0 {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
buf = buf[:0]
|
|
||||||
buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)
|
|
||||||
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
buf = append(buf, proxy_socks5IP4)
|
|
||||||
ip = ip4
|
|
||||||
} else {
|
|
||||||
buf = append(buf, proxy_socks5IP6)
|
|
||||||
}
|
|
||||||
buf = append(buf, ip...)
|
|
||||||
} else {
|
|
||||||
if len(host) > 255 {
|
|
||||||
return errors.New("proxy: destination host name too long: " + host)
|
|
||||||
}
|
|
||||||
buf = append(buf, proxy_socks5Domain)
|
|
||||||
buf = append(buf, byte(len(host)))
|
|
||||||
buf = append(buf, host...)
|
|
||||||
}
|
|
||||||
buf = append(buf, byte(port>>8), byte(port))
|
|
||||||
|
|
||||||
if _, err := conn.Write(buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
failure := "unknown error"
|
|
||||||
if int(buf[1]) < len(proxy_socks5Errors) {
|
|
||||||
failure = proxy_socks5Errors[buf[1]]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(failure) > 0 {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
|
|
||||||
}
|
|
||||||
|
|
||||||
bytesToDiscard := 0
|
|
||||||
switch buf[3] {
|
|
||||||
case proxy_socks5IP4:
|
|
||||||
bytesToDiscard = net.IPv4len
|
|
||||||
case proxy_socks5IP6:
|
|
||||||
bytesToDiscard = net.IPv6len
|
|
||||||
case proxy_socks5Domain:
|
|
||||||
_, err := io.ReadFull(conn, buf[:1])
|
|
||||||
if err != nil {
|
|
||||||
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
bytesToDiscard = int(buf[0])
|
|
||||||
default:
|
|
||||||
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cap(buf) < bytesToDiscard {
|
|
||||||
buf = make([]byte, bytesToDiscard)
|
|
||||||
} else {
|
|
||||||
buf = buf[:bytesToDiscard]
|
|
||||||
}
|
|
||||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Also need to discard the port number
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
Loading…
Reference in New Issue