forked from mirror/redis
170 lines
4.0 KiB
Go
170 lines
4.0 KiB
Go
package redisotel
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"net"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
	"go.opentelemetry.io/otel/trace"

	"github.com/go-redis/redis/extra/rediscmd/v9"
	"github.com/go-redis/redis/v9"
)
|
|
|
|
// instrumName is the instrumentation scope name handed to the tracer
// provider when the caller did not supply a tracer of its own.
const instrumName = "github.com/go-redis/redis/extra/redisotel"
|
|
|
|
func InstrumentTracing(rdb redis.UniversalClient, opts ...TracingOption) error {
|
|
switch rdb := rdb.(type) {
|
|
case *redis.Client:
|
|
opt := rdb.Options()
|
|
connString := formatDBConnString(opt.Network, opt.Addr)
|
|
rdb.AddHook(newTracingHook(connString, opts...))
|
|
return nil
|
|
case *redis.ClusterClient:
|
|
rdb.AddHook(newTracingHook("", opts...))
|
|
|
|
rdb.OnNewNode(func(rdb *redis.Client) {
|
|
opt := rdb.Options()
|
|
connString := formatDBConnString(opt.Network, opt.Addr)
|
|
rdb.AddHook(newTracingHook(connString, opts...))
|
|
})
|
|
return nil
|
|
case *redis.Ring:
|
|
rdb.AddHook(newTracingHook("", opts...))
|
|
|
|
rdb.OnNewNode(func(rdb *redis.Client) {
|
|
opt := rdb.Options()
|
|
connString := formatDBConnString(opt.Network, opt.Addr)
|
|
rdb.AddHook(newTracingHook(connString, opts...))
|
|
})
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("redisotel: %T not supported", rdb)
|
|
}
|
|
}
|
|
|
|
type tracingHook struct {
|
|
conf *config
|
|
|
|
spanOpts []trace.SpanStartOption
|
|
}
|
|
|
|
var _ redis.Hook = (*tracingHook)(nil)
|
|
|
|
func newTracingHook(connString string, opts ...TracingOption) *tracingHook {
|
|
baseOpts := make([]baseOption, len(opts))
|
|
for i, opt := range opts {
|
|
baseOpts[i] = opt
|
|
}
|
|
conf := newConfig(baseOpts...)
|
|
|
|
if conf.tracer == nil {
|
|
conf.tracer = conf.tp.Tracer(
|
|
instrumName,
|
|
trace.WithInstrumentationVersion("semver:"+redis.Version()),
|
|
)
|
|
}
|
|
if connString != "" {
|
|
conf.attrs = append(conf.attrs, semconv.DBConnectionStringKey.String(connString))
|
|
}
|
|
|
|
return &tracingHook{
|
|
conf: conf,
|
|
|
|
spanOpts: []trace.SpanStartOption{
|
|
trace.WithSpanKind(trace.SpanKindClient),
|
|
trace.WithAttributes(conf.attrs...),
|
|
},
|
|
}
|
|
}
|
|
|
|
func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook {
|
|
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
if !trace.SpanFromContext(ctx).IsRecording() {
|
|
return hook(ctx, network, addr)
|
|
}
|
|
|
|
ctx, span := th.conf.tracer.Start(ctx, "redis.dial", th.spanOpts...)
|
|
defer span.End()
|
|
|
|
conn, err := hook(ctx, network, addr)
|
|
if err != nil {
|
|
recordError(ctx, span, err)
|
|
return nil, err
|
|
}
|
|
return conn, nil
|
|
}
|
|
}
|
|
|
|
func (th *tracingHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
|
|
return func(ctx context.Context, cmd redis.Cmder) error {
|
|
if !trace.SpanFromContext(ctx).IsRecording() {
|
|
return hook(ctx, cmd)
|
|
}
|
|
|
|
opts := th.spanOpts
|
|
if th.conf.dbStmtEnabled {
|
|
opts = append(opts, trace.WithAttributes(
|
|
semconv.DBStatementKey.String(rediscmd.CmdString(cmd))),
|
|
)
|
|
}
|
|
|
|
ctx, span := th.conf.tracer.Start(ctx, cmd.FullName(), opts...)
|
|
defer span.End()
|
|
|
|
if err := hook(ctx, cmd); err != nil {
|
|
recordError(ctx, span, err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (th *tracingHook) ProcessPipelineHook(
|
|
hook redis.ProcessPipelineHook,
|
|
) redis.ProcessPipelineHook {
|
|
return func(ctx context.Context, cmds []redis.Cmder) error {
|
|
if !trace.SpanFromContext(ctx).IsRecording() {
|
|
return hook(ctx, cmds)
|
|
}
|
|
|
|
opts := th.spanOpts
|
|
opts = append(opts, trace.WithAttributes(
|
|
attribute.Int("db.redis.num_cmd", len(cmds)),
|
|
))
|
|
|
|
summary, cmdsString := rediscmd.CmdsString(cmds)
|
|
if th.conf.dbStmtEnabled {
|
|
opts = append(opts, trace.WithAttributes(semconv.DBStatementKey.String(cmdsString)))
|
|
}
|
|
|
|
ctx, span := th.conf.tracer.Start(ctx, "redis.pipeline "+summary, opts...)
|
|
defer span.End()
|
|
|
|
if err := hook(ctx, cmds); err != nil {
|
|
recordError(ctx, span, err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func recordError(ctx context.Context, span trace.Span, err error) {
|
|
if err != redis.Nil {
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}
|
|
}
|
|
|
|
// formatDBConnString renders network and addr as a URL-style connection
// string, used as the db.connection_string span attribute.
//
// Every TCP flavour accepted by net.Dial ("tcp", "tcp4", "tcp6") is mapped to
// the "redis" scheme, so the result reads like a redis:// URL. Any other
// network (for example "unix") is used as the scheme verbatim, giving e.g.
// "unix:///tmp/redis.sock".
func formatDBConnString(network, addr string) string {
	switch network {
	case "tcp", "tcp4", "tcp6":
		network = "redis"
	}
	return network + "://" + addr
}
|