Compare commits

...

3 Commits

Author SHA1 Message Date
Canelo Hill 371a9fdbb8
Merge d40797b837 into 1d5465562b 2024-06-19 20:52:16 -06:00
Canelo Hill 1d5465562b Unbundle x/net/proxy and update to recent version
Import golang.org/x/net/proxy instead of using the bundle in
x_net_proxy.go. There's no need to avoid the dependency on
golang.org/x/net/proxy now that Go's module system is in widespread use.

Change Dialer.DialContext to pass contexts as an argument to the dial
function instead of tunneling the context through closures. Tunneling is
no longer needed because the proxy package supports contexts. The
version of the proxy package in the bundle predates contexts!

Simplify the code for calculating the base dial function.

Prevent the HTTP proxy dialer from leaking out of the websocket package
by selecting the HTTP proxy dialer directly in the websocket package.
Previously, the HTTP dialer was registered with the proxy package.
2024-06-19 20:11:25 -04:00
Canelo Hill d40797b837 Handle errcheck warnings
The package ignored errors from net.Conn Set*Deadline in a few places.
Update the package to return these errors to the caller.

Ignore all other errors reported by errcheck. These errors are safe to
ignore because
- The function is making a best effort to cleanup while handling another
  error.
- The function call is guaranteed to succeed.
- The error is ignored in a test.
2024-06-18 15:58:40 -06:00
14 changed files with 133 additions and 571 deletions

View File

@ -52,7 +52,7 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
// NetDial is nil, net.Dialer DialContext is used.
NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
@ -244,46 +244,25 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
defer cancel()
}
// Get network dial function.
var netDial func(network, add string) (net.Conn, error)
switch u.Scheme {
case "http":
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
case "https":
if d.NetDialTLSContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialTLSContext(ctx, network, addr)
}
} else if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
var netDial netDialerFunc
switch {
case u.Scheme == "https" && d.NetDialTLSContext != nil:
netDial = d.NetDialTLSContext
case d.NetDialContext != nil:
netDial = d.NetDialContext
case d.NetDial != nil:
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
return d.NetDial(net, addr)
}
default:
return nil, nil, errMalformedURL
}
if netDial == nil {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
}
netDial = (&net.Dialer{}).DialContext
}
// If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok {
forwardDial := netDial
netDial = func(network, addr string) (net.Conn, error) {
c, err := forwardDial(network, addr)
netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := forwardDial(ctx, network, addr)
if err != nil {
return nil, err
}
@ -303,11 +282,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err
}
if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
netDial, err = proxyFromURL(proxyURL, netDial)
if err != nil {
return nil, nil, err
}
netDial = dialer.Dial
}
}
@ -317,7 +295,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
trace.GetConn(hostPort)
}
netConn, err := netDial("tcp", hostPort)
netConn, err := netDial(ctx, "tcp", hostPort)
if err != nil {
return nil, nil, err
}
@ -420,8 +398,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer.
if err := netConn.SetDeadline(time.Time{}); err != nil {
return nil, nil, err
}
// Set netConn to nil to avoid call to netConn.Close() in
// deferred function call.
netConn = nil
return conn, resp, nil
}

View File

@ -546,7 +546,7 @@ func TestRespOnBadHandshake(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(expectedStatus)
io.WriteString(w, expectedBody)
_, _ = io.WriteString(w, expectedBody)
}))
defer s.Close()
@ -796,7 +796,7 @@ func TestSocksProxyDial(t *testing.T) {
}
defer c1.Close()
c1.SetDeadline(time.Now().Add(30 * time.Second))
_ = c1.SetDeadline(time.Now().Add(30 * time.Second))
buf := make([]byte, 32)
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
@ -835,10 +835,10 @@ func TestSocksProxyDial(t *testing.T) {
defer c2.Close()
done := make(chan struct{})
go func() {
io.Copy(c1, c2)
_, _ = io.Copy(c1, c2)
close(done)
}()
io.Copy(c2, c1)
_, _ = io.Copy(c2, c1)
<-done
}()

View File

@ -33,7 +33,11 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
"\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
mr := io.MultiReader(r, strings.NewReader(tail))
if err := fr.(flate.Resetter).Reset(mr, nil); err != nil {
// Reset never fails, but handle error in case that changes.
fr = flate.NewReader(mr)
}
return &flateReadWrapper{fr}
}

View File

@ -22,7 +22,7 @@ func TestTruncWriter(t *testing.T) {
if m > n {
m = n
}
w.Write(p[:m])
_, _ = w.Write(p[:m])
p = p[m:]
}
if b.String() != data[:len(data)-len(w.p)] {
@ -46,7 +46,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
messages := textMessages(100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)])
_ = c.WriteMessage(TextMessage, messages[i%len(messages)])
}
b.ReportAllocs()
}
@ -59,7 +59,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
c.newCompressionWriter = compressNoContextTakeover
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)])
_ = c.WriteMessage(TextMessage, messages[i%len(messages)])
}
b.ReportAllocs()
}

34
conn.go
View File

@ -371,7 +371,7 @@ func (c *Conn) read(n int) ([]byte, error) {
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
_, _ = c.br.Discard(len(p)) // guaranteed to succeed
return p, err
}
@ -386,7 +386,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return err
}
c.conn.SetWriteDeadline(deadline)
if err := c.conn.SetWriteDeadline(deadline); err != nil {
return c.writeFatal(err)
}
if len(buf1) == 0 {
_, err = c.conn.Write(buf0)
} else {
@ -396,7 +398,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return c.writeFatal(err)
}
if frameType == CloseMessage {
c.writeFatal(ErrCloseSent)
_ = c.writeFatal(ErrCloseSent)
}
return nil
}
@ -459,13 +461,14 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return err
}
c.conn.SetWriteDeadline(deadline)
_, err = c.conn.Write(buf)
if err != nil {
if err := c.conn.SetWriteDeadline(deadline); err != nil {
return c.writeFatal(err)
}
if _, err = c.conn.Write(buf); err != nil {
return c.writeFatal(err)
}
if messageType == CloseMessage {
c.writeFatal(ErrCloseSent)
_ = c.writeFatal(ErrCloseSent)
}
return err
}
@ -629,7 +632,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
}
if final {
w.endMessage(errWriteClosed)
_ = w.endMessage(errWriteClosed)
return nil
}
@ -816,7 +819,7 @@ func (c *Conn) advanceFrame() (int, error) {
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f))
_ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not overflow
c.readDecompress = false
if rsv1 {
@ -921,7 +924,8 @@ func (c *Conn) advanceFrame() (int, error) {
}
if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
// Make a best effort to send a close message describing the problem.
_ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit
}
@ -933,7 +937,7 @@ func (c *Conn) advanceFrame() (int, error) {
var payload []byte
if c.readRemaining > 0 {
payload, err = c.read(int(c.readRemaining))
c.setReadRemaining(0)
_ = c.setReadRemaining(0) // will not overflow
if err != nil {
return noFrame, err
}
@ -980,7 +984,8 @@ func (c *Conn) handleProtocolError(message string) error {
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
// Make a best effor to send a close message describing the problem.
_ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}
@ -1053,7 +1058,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
}
rem := c.readRemaining
rem -= int64(n)
c.setReadRemaining(rem)
_ = c.setReadRemaining(rem) // will not overflow
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
}
@ -1135,7 +1140,8 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
if h == nil {
h = func(code int, text string) error {
message := FormatCloseMessage(code, "")
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
// Make a best effor to send the close message.
_ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
return nil
}
}

View File

@ -69,9 +69,9 @@ func (b *broadcastBench) makeConns(numConns int) {
select {
case msg := <-c.msgCh:
if msg.prepared != nil {
c.conn.WritePreparedMessage(msg.prepared)
_ = c.conn.WritePreparedMessage(msg.prepared)
} else {
c.conn.WriteMessage(TextMessage, msg.payload)
_ = c.conn.WriteMessage(TextMessage, msg.payload)
}
val := atomic.AddInt32(&b.count, 1)
if val%int32(numConns) == 0 {

View File

@ -157,7 +157,7 @@ func TestControl(t *testing.T) {
wc := newTestConn(nil, &connBuf, isServer)
rc := newTestConn(&connBuf, nil, !isServer)
if isWriteControl {
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
_ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
} else {
w, err := wc.NextWriter(PongMessage)
if err != nil {
@ -174,7 +174,7 @@ func TestControl(t *testing.T) {
}
var actualMessage string
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
rc.NextReader()
_, _, _ = rc.NextReader()
if actualMessage != message {
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
continue
@ -358,8 +358,8 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
rc := newTestConn(&b1, &b2, true)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2))
wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
_, _ = w.Write(make([]byte, bufSize+bufSize/2))
_ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
w.Close()
op, r, err := rc.NextReader()
@ -385,7 +385,7 @@ func TestEOFWithinFrame(t *testing.T) {
rc := newTestConn(&b, nil, true)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize))
_, _ = w.Write(make([]byte, bufSize))
w.Close()
if n >= b.Len() {
@ -419,7 +419,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
rc := newTestConn(&b1, &b2, true)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2))
_, _ = w.Write(make([]byte, bufSize+bufSize/2))
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
@ -438,7 +438,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
func TestWriteAfterMessageWriterClose(t *testing.T) {
wc := newTestConn(nil, &bytes.Buffer{}, false)
w, _ := wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello")
_, _ = io.WriteString(w, "hello")
if err := w.Close(); err != nil {
t.Fatalf("unxpected error closing message writer, %v", err)
}
@ -448,7 +448,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
}
w, _ = wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello")
_, _ = io.WriteString(w, "hello")
// close w by getting next writer
_, err := wc.NextWriter(BinaryMessage)
@ -473,13 +473,13 @@ func TestReadLimit(t *testing.T) {
// Send message at the limit with interleaved pong.
w, _ := wc.NextWriter(BinaryMessage)
w.Write(message[:readLimit-1])
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
w.Write(message[:1])
_, _ = w.Write(message[:readLimit-1])
_ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
_, _ = w.Write(message[:1])
w.Close()
// Send message larger than the limit.
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
_ = wc.WriteMessage(BinaryMessage, message[:readLimit+1])
op, _, err := rc.NextReader()
if op != BinaryMessage || err != nil {
@ -592,7 +592,7 @@ func TestBufioReadBytes(t *testing.T) {
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(m)
_, _ = w.Write(m)
w.Close()
op, r, err := rc.NextReader()
@ -666,7 +666,7 @@ func TestConcurrentWritePanic(t *testing.T) {
w := blockingWriter{make(chan struct{}), make(chan struct{})}
c := newTestConn(nil, w, false)
go func() {
c.WriteMessage(TextMessage, []byte{})
_ = c.WriteMessage(TextMessage, []byte{})
}()
// wait for goroutine to block in write.
@ -679,7 +679,7 @@ func TestConcurrentWritePanic(t *testing.T) {
}
}()
c.WriteMessage(TextMessage, []byte{})
_ = c.WriteMessage(TextMessage, []byte{})
t.Fatal("should not get here")
}
@ -699,7 +699,7 @@ func TestFailedConnectionReadPanic(t *testing.T) {
}()
for i := 0; i < 20000; i++ {
c.ReadMessage()
_, _, _ = c.ReadMessage()
}
t.Fatal("should not get here")
}

2
go.mod
View File

@ -5,3 +5,5 @@ go 1.20
retract (
v1.5.2 // tag accidentally overwritten
)
require golang.org/x/net v0.26.0

2
go.sum
View File

@ -0,0 +1,2 @@
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=

View File

@ -19,7 +19,7 @@ func TestJoinMessages(t *testing.T) {
wc := newTestConn(nil, &connBuf, true)
rc := newTestConn(&connBuf, nil, false)
for _, m := range messages {
wc.WriteMessage(BinaryMessage, []byte(m))
_ = wc.WriteMessage(BinaryMessage, []byte(m))
}
var result bytes.Buffer

View File

@ -40,7 +40,9 @@ func TestPreparedMessage(t *testing.T) {
if tt.enableWriteCompression {
c.newCompressionWriter = compressNoContextTakeover
}
c.SetCompressionLevel(tt.compressionLevel)
if err := c.SetCompressionLevel(tt.compressionLevel); err != nil {
t.Fatal(err)
}
// Seed random number generator for consistent frame mask.
rand.Seed(1234)

View File

@ -7,34 +7,51 @@ package websocket
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"errors"
"net"
"net/http"
"net/url"
"strings"
"golang.org/x/net/proxy"
)
type netDialerFunc func(network, addr string) (net.Conn, error)
type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error)
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)
return fn(context.Background(), network, addr)
}
func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
})
func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
return fn(ctx, network, addr)
}
func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) {
if proxyURL.Scheme == "http" {
return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil
}
dialer, err := proxy.FromURL(proxyURL, forwardDial)
if err != nil {
return nil, err
}
if d, ok := dialer.(proxy.ContextDialer); ok {
return d.DialContext, nil
}
return func(ctx context.Context, net, addr string) (net.Conn, error) {
return dialer.Dial(net, addr)
}, nil
}
type httpProxyDialer struct {
proxyURL *url.URL
forwardDial func(network, addr string) (net.Conn, error)
forwardDial netDialerFunc
}
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
hostPort, _ := hostPortNoPort(hpd.proxyURL)
conn, err := hpd.forwardDial(network, hostPort)
conn, err := hpd.forwardDial(ctx, network, hostPort)
if err != nil {
return nil, err
}

View File

@ -178,8 +178,16 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
"websocket: hijack: "+err.Error())
}
defer func() {
if netConn != nil {
// It's safe to ignore the error from Close() because this code is
// only executed when returning a more important to the
// application.
_ = netConn.Close()
}
}()
if brw.Reader.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
@ -243,20 +251,30 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}
p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
return nil, err
}
} else {
// Clear deadlines set by HTTP server.
if err := netConn.SetDeadline(time.Time{}); err != nil {
return nil, err
}
}
if _, err = netConn.Write(p); err != nil {
netConn.Close()
return nil, err
}
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
return nil, err
}
}
// Set netConn to nil to avoid call to netConn.Close() in
// deferred function call.
netConn = nil
return c, nil
}
@ -352,7 +370,7 @@ func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// bufio.Writer's underlying writer.
var wh writeHook
bw.Reset(&wh)
bw.WriteByte(0)
_ = bw.WriteByte(0)
bw.Flush()
bw.Reset(originalWriter)

View File

@ -1,473 +0,0 @@
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy
// Package proxy provides support for a variety of protocols to proxy network
// data.
//
package websocket
import (
"errors"
"io"
"net"
"net/url"
"os"
"strconv"
"strings"
"sync"
)
type proxy_direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var proxy_Direct = proxy_direct{}
func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}
// A PerHost directs connections to a default Dialer unless the host name
// requested matches one of a number of exceptions.
type proxy_PerHost struct {
def, bypass proxy_Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
return &proxy_PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone ".example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifying hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a host name
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *proxy_PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *proxy_PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match.
func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *proxy_PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a host name that will use the bypass proxy.
func (p *proxy_PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}
// A Dialer is a means to establish a connection.
type proxy_Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type proxy_Auth struct {
User, Password string
}
// FromEnvironment returns the dialer specified by the proxy related variables in
// the environment.
func proxy_FromEnvironment() proxy_Dialer {
allProxy := proxy_allProxyEnv.Get()
if len(allProxy) == 0 {
return proxy_Direct
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return proxy_Direct
}
proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
if err != nil {
return proxy_Direct
}
noProxy := proxy_noProxyEnv.Get()
if len(noProxy) == 0 {
return proxy
}
perHost := proxy_NewPerHost(proxy, proxy_Direct)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
if proxy_proxySchemes == nil {
proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
}
proxy_proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
var auth *proxy_Auth
if u.User != nil {
auth = new(proxy_Auth)
auth.User = u.User.Username()
if p, ok := u.User.Password(); ok {
auth.Password = p
}
}
switch u.Scheme {
case "socks5":
return proxy_SOCKS5("tcp", u.Host, auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxy_proxySchemes != nil {
if f, ok := proxy_proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
var (
proxy_allProxyEnv = &proxy_envOnce{
names: []string{"ALL_PROXY", "all_proxy"},
}
proxy_noProxyEnv = &proxy_envOnce{
names: []string{"NO_PROXY", "no_proxy"},
}
)
// envOnce looks up an environment variable (optionally by multiple
// names) once. It mitigates expensive lookups on some platforms
// (e.g. Windows).
// (Borrowed from net/http/transport.go)
type proxy_envOnce struct {
names []string
once sync.Once
val string
}
func (e *proxy_envOnce) Get() string {
e.once.Do(e.init)
return e.val
}
func (e *proxy_envOnce) init() {
for _, n := range e.names {
e.val = os.Getenv(n)
if e.val != "" {
return
}
}
}
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
// with an optional username and password. See RFC 1928 and RFC 1929.
func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
s := &proxy_socks5{
network: network,
addr: addr,
forward: forward,
}
if auth != nil {
s.user = auth.User
s.password = auth.Password
}
return s, nil
}
type proxy_socks5 struct {
user, password string
network, addr string
forward proxy_Dialer
}
const proxy_socks5Version = 5
const (
proxy_socks5AuthNone = 0
proxy_socks5AuthPassword = 2
)
const proxy_socks5Connect = 1
const (
proxy_socks5IP4 = 1
proxy_socks5Domain = 3
proxy_socks5IP6 = 4
)
var proxy_socks5Errors = []string{
"",
"general failure",
"connection forbidden",
"network unreachable",
"host unreachable",
"connection refused",
"TTL expired",
"command not supported",
"address type not supported",
}
// Dial connects to the address addr on the given network via the SOCKS5 proxy.
func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
switch network {
case "tcp", "tcp6", "tcp4":
default:
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
}
conn, err := s.forward.Dial(s.network, s.addr)
if err != nil {
return nil, err
}
if err := s.connect(conn, addr); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
// connect takes an existing connection to a socks5 proxy server,
// and commands the server to extend that connection to target,
// which must be a canonical address with a host and port.
func (s *proxy_socks5) connect(conn net.Conn, target string) error {
host, portStr, err := net.SplitHostPort(target)
if err != nil {
return err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
buf = append(buf, proxy_socks5Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
} else {
buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
}
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[0] != 5 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
}
if buf[1] == 0xff {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
}
// See RFC 1929
if buf[1] == proxy_socks5AuthPassword {
buf = buf[:0]
buf = append(buf, 1 /* password protocol version */)
buf = append(buf, uint8(len(s.user)))
buf = append(buf, s.user...)
buf = append(buf, uint8(len(s.password)))
buf = append(buf, s.password...)
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[1] != 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
}
}
buf = buf[:0]
buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, proxy_socks5IP4)
ip = ip4
} else {
buf = append(buf, proxy_socks5IP6)
}
buf = append(buf, ip...)
} else {
if len(host) > 255 {
return errors.New("proxy: destination host name too long: " + host)
}
buf = append(buf, proxy_socks5Domain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
failure := "unknown error"
if int(buf[1]) < len(proxy_socks5Errors) {
failure = proxy_socks5Errors[buf[1]]
}
if len(failure) > 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
}
bytesToDiscard := 0
switch buf[3] {
case proxy_socks5IP4:
bytesToDiscard = net.IPv4len
case proxy_socks5IP6:
bytesToDiscard = net.IPv6len
case proxy_socks5Domain:
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
bytesToDiscard = int(buf[0])
default:
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
}
if cap(buf) < bytesToDiscard {
buf = make([]byte, bytesToDiscard)
} else {
buf = buf[:bytesToDiscard]
}
if _, err := io.ReadFull(conn, buf); err != nil {
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Also need to discard the port number
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
return nil
}