diff --git a/internal/pool/conn.go b/internal/pool/conn.go index b7845304..ee064c9f 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -9,7 +9,6 @@ import ( "github.com/go-redis/redis/v8/internal" "github.com/go-redis/redis/v8/internal/proto" - "go.opentelemetry.io/otel/trace" ) var noDeadline = time.Time{} @@ -66,41 +65,43 @@ func (cn *Conn) RemoteAddr() net.Addr { } func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error { - return internal.WithSpan(ctx, "redis.with_reader", func(ctx context.Context, span trace.Span) error { - if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { - return internal.RecordError(ctx, span, err) - } - if err := fn(cn.rd); err != nil { - return internal.RecordError(ctx, span, err) - } - return nil - }) + ctx, span := internal.StartSpan(ctx, "redis.with_reader") + defer span.End() + + if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { + return internal.RecordError(ctx, span, err) + } + if err := fn(cn.rd); err != nil { + return internal.RecordError(ctx, span, err) + } + return nil } func (cn *Conn) WithWriter( ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error, ) error { - return internal.WithSpan(ctx, "redis.with_writer", func(ctx context.Context, span trace.Span) error { - if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil { - return internal.RecordError(ctx, span, err) - } + ctx, span := internal.StartSpan(ctx, "redis.with_writer") + defer span.End() - if cn.bw.Buffered() > 0 { - cn.bw.Reset(cn.netConn) - } + if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil { + return internal.RecordError(ctx, span, err) + } - if err := fn(cn.wr); err != nil { - return internal.RecordError(ctx, span, err) - } + if cn.bw.Buffered() > 0 { + cn.bw.Reset(cn.netConn) + } - if err := cn.bw.Flush(); err != nil { - return internal.RecordError(ctx, span, err) - } + if err := fn(cn.wr); err != nil { + return internal.RecordError(ctx, span, err) + } - internal.WritesCounter.Add(ctx, 1) + if err := cn.bw.Flush(); err != nil { + return internal.RecordError(ctx, span, err) + } - return nil - }) + internal.WritesCounter.Add(ctx, 1) + + return nil } func (cn *Conn) Close() error { diff --git a/internal/util.go b/internal/util.go index 4d7921bf..1a648fe6 100644 --- a/internal/util.go +++ b/internal/util.go @@ -11,17 +11,18 @@ import ( ) func Sleep(ctx context.Context, dur time.Duration) error { - return WithSpan(ctx, "time.Sleep", func(ctx context.Context, span trace.Span) error { - t := time.NewTimer(dur) - defer t.Stop() + _, span := StartSpan(ctx, "time.Sleep") + defer span.End() - select { - case <-t.C: - return nil - case <-ctx.Done(): - return ctx.Err() - } - }) + t := time.NewTimer(dur) + defer t.Stop() + + select { + case <-t.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } } func ToLower(s string) string { @@ -54,15 +55,11 @@ func isLower(s string) bool { var tracer = otel.Tracer("github.com/go-redis/redis") -func WithSpan(ctx context.Context, name string, fn func(context.Context, trace.Span) error) error { +func StartSpan(ctx context.Context, name string) (context.Context, trace.Span) { if span := trace.SpanFromContext(ctx); !span.IsRecording() { - return fn(ctx, span) + return ctx, span } - - ctx, span := tracer.Start(ctx, name) - defer span.End() - - return fn(ctx, span) + return tracer.Start(ctx, name) } func RecordError(ctx context.Context, span trace.Span, err error) error { diff --git a/options.go b/options.go index 5d24d531..0fd8e880 100644 --- a/options.go +++ b/options.go @@ -15,7 +15,6 @@ import ( "github.com/go-redis/redis/v8/internal" "github.com/go-redis/redis/v8/internal/pool" "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" ) // Limiter is the interface of a rate limiter or a circuit breaker. @@ -292,20 +291,21 @@ func getUserPassword(u *url.URL) (string, string) { func newConnPool(opt *Options) *pool.ConnPool { return pool.NewConnPool(&pool.Options{ Dialer: func(ctx context.Context) (net.Conn, error) { - var conn net.Conn - err := internal.WithSpan(ctx, "redis.dial", func(ctx context.Context, span trace.Span) error { + ctx, span := internal.StartSpan(ctx, "redis.dial") + defer span.End() + + if span.IsRecording() { span.SetAttributes( attribute.String("db.connection_string", opt.Addr), ) + } - var err error - conn, err = opt.Dialer(ctx, opt.Network, opt.Addr) - if err != nil { - _ = internal.RecordError(ctx, span, err) - } - return err - }) - return conn, err + cn, err := opt.Dialer(ctx, opt.Network, opt.Addr) + if err != nil { + return nil, internal.RecordError(ctx, span, err) + } + + return cn, nil }, PoolSize: opt.PoolSize, MinIdleConns: opt.MinIdleConns, diff --git a/redis.go b/redis.go index 3dfdb82a..c3882e01 100644 --- a/redis.go +++ b/redis.go @@ -11,7 +11,6 @@ import ( "github.com/go-redis/redis/v8/internal/pool" "github.com/go-redis/redis/v8/internal/proto" "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" ) // Nil reply returned by Redis when key does not exist. @@ -214,10 +213,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } - err = internal.WithSpan(ctx, "redis.init_conn", func(ctx context.Context, span trace.Span) error { - return c.initConn(ctx, cn) - }) - if err != nil { + if err := c.initConn(ctx, cn); err != nil { c.connPool.Remove(ctx, cn, err) if err := errors.Unwrap(err); err != nil { return nil, err @@ -241,6 +237,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return nil } + ctx, span := internal.StartSpan(ctx, "redis.init_conn") + defer span.End() + connPool := pool.NewSingleConnPool(c.connPool, cn) conn := newConn(ctx, c.opt, connPool) @@ -288,43 +287,44 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) func (c *baseClient) withConn( ctx context.Context, fn func(context.Context, *pool.Conn) error, ) error { - return internal.WithSpan(ctx, "redis.with_conn", func(ctx context.Context, span trace.Span) error { - cn, err := c.getConn(ctx) - if err != nil { - return err + ctx, span := internal.StartSpan(ctx, "redis.with_conn") + defer span.End() + + cn, err := c.getConn(ctx) + if err != nil { + return err + } + + if span.IsRecording() { + if remoteAddr := cn.RemoteAddr(); remoteAddr != nil { + span.SetAttributes(attribute.String("net.peer.ip", remoteAddr.String())) } + } - if span.IsRecording() { - if remoteAddr := cn.RemoteAddr(); remoteAddr != nil { - span.SetAttributes(attribute.String("net.peer.ip", remoteAddr.String())) - } - } + defer func() { + c.releaseConn(ctx, cn, err) + }() - defer func() { - c.releaseConn(ctx, cn, err) - }() + done := ctx.Done() + if done == nil { + err = fn(ctx, cn) + return err + } - done := ctx.Done() - if done == nil { - err = fn(ctx, cn) - return err - } + errc := make(chan error, 1) + go func() { errc <- fn(ctx, cn) }() - errc := make(chan error, 1) - go func() { errc <- fn(ctx, cn) }() + select { + case <-done: + _ = cn.Close() + // Wait for the goroutine to finish and send something. + <-errc - select { - case <-done: - _ = cn.Close() - // Wait for the goroutine to finish and send something. - <-errc - - err = ctx.Err() - return err - case err = <-errc: - return err - } - }) + err = ctx.Err() + return err + case err = <-errc: + return err + } } func (c *baseClient) process(ctx context.Context, cmd Cmder) error { @@ -332,47 +332,50 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { attempt := attempt - var retry bool - err := internal.WithSpan(ctx, "redis.process", func(ctx context.Context, span trace.Span) error { - if attempt > 0 { - if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { - return err - } - } - - retryTimeout := uint32(1) - err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmd(wr, cmd) - }) - if err != nil { - return err - } - - err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) - if err != nil { - if cmd.readTimeout() == nil { - atomic.StoreUint32(&retryTimeout, 1) - } - return err - } - - return nil - }) - if err == nil { - return nil - } - retry = shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) - return err - }) + retry, err := c._process(ctx, cmd, attempt) if err == nil || !retry { return err } + lastErr = err } return lastErr } +func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) { + if attempt > 0 { + if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { + return false, err + } + } + + retryTimeout := uint32(1) + err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmd) + }) + if err != nil { + return err + } + + err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) + if err != nil { + if cmd.readTimeout() == nil { + atomic.StoreUint32(&retryTimeout, 1) + } + return err + } + + return nil + }) + if err == nil { + return false, nil + } + + retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) + return retry, err +} + func (c *baseClient) retryBackoff(attempt int) time.Duration { return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) }