diff --git a/command.go b/command.go index 7a17864..f606acc 100644 --- a/command.go +++ b/command.go @@ -74,13 +74,13 @@ func cmdFirstKeyPos(cmd Cmder, info *CommandInfo) int { } func cmdString(cmd Cmder, val interface{}) string { - b := make([]byte, 0, 32) + b := make([]byte, 0, 64) for i, arg := range cmd.Args() { if i > 0 { b = append(b, ' ') } - b = appendArg(b, arg) + b = internal.AppendArg(b, arg) } if err := cmd.Err(); err != nil { @@ -88,48 +88,10 @@ func cmdString(cmd Cmder, val interface{}) string { b = append(b, err.Error()...) } else if val != nil { b = append(b, ": "...) - - switch val := val.(type) { - case []byte: - b = append(b, val...) - default: - b = appendArg(b, val) - } + b = internal.AppendArg(b, val) } - return string(b) -} - -func appendArg(b []byte, v interface{}) []byte { - switch v := v.(type) { - case nil: - return append(b, ""...) - case string: - return append(b, v...) - case []byte: - return append(b, v...) - case int: - return strconv.AppendInt(b, int64(v), 10) - case int32: - return strconv.AppendInt(b, int64(v), 10) - case int64: - return strconv.AppendInt(b, v, 10) - case uint: - return strconv.AppendUint(b, uint64(v), 10) - case uint32: - return strconv.AppendUint(b, uint64(v), 10) - case uint64: - return strconv.AppendUint(b, v, 10) - case bool: - if v { - return append(b, "true"...) - } - return append(b, "false"...) - case time.Time: - return v.AppendFormat(b, time.RFC3339Nano) - default: - return append(b, fmt.Sprint(v)...) - } + return internal.String(b) } //------------------------------------------------------------------------------ diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 1f856f0..5bb34c3 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -1,6 +1,7 @@ package pool import ( + "bufio" "context" "net" "sync/atomic" @@ -13,15 +14,16 @@ import ( var noDeadline = time.Time{} type Conn struct { + usedAt int64 // atomic netConn net.Conn rd *proto.Reader + bw *bufio.Writer wr *proto.Writer Inited bool pooled bool createdAt time.Time - usedAt int64 // atomic } func NewConn(netConn net.Conn) *Conn { @@ -30,7 +32,8 @@ func NewConn(netConn net.Conn) *Conn { createdAt: time.Now(), } cn.rd = proto.NewReader(netConn) - cn.wr = proto.NewWriter(netConn) + cn.bw = bufio.NewWriter(netConn) + cn.wr = proto.NewWriter(cn.bw) cn.SetUsedAt(time.Now()) return cn } @@ -47,7 +50,7 @@ func (cn *Conn) SetUsedAt(tm time.Time) { func (cn *Conn) SetNetConn(netConn net.Conn) { cn.netConn = netConn cn.rd.Reset(netConn) - cn.wr.Reset(netConn) + cn.bw.Reset(netConn) } func (cn *Conn) Write(b []byte) (int, error) { @@ -77,8 +80,8 @@ func (cn *Conn) WithWriter( return err } - if cn.wr.Buffered() > 0 { - cn.wr.Reset(cn.netConn) + if cn.bw.Buffered() > 0 { + cn.bw.Reset(cn.netConn) } err = fn(cn.wr) @@ -86,7 +89,7 @@ func (cn *Conn) WithWriter( return err } - return cn.wr.Flush() + return cn.bw.Flush() }) } diff --git a/internal/proto/writer.go b/internal/proto/writer.go index ed969e3..aafe99a 100644 --- a/internal/proto/writer.go +++ b/internal/proto/writer.go @@ -1,7 +1,6 @@ package proto import ( - "bufio" "encoding" "fmt" "io" @@ -11,16 +10,22 @@ import ( "github.com/go-redis/redis/v8/internal/util" ) +type writer interface { + io.Writer + io.ByteWriter + io.StringWriter +} + type Writer struct { - wr *bufio.Writer + writer lenBuf []byte numBuf []byte } -func NewWriter(wr io.Writer) *Writer { +func NewWriter(wr writer) *Writer { return &Writer{ - wr: bufio.NewWriter(wr), + writer: wr, lenBuf: make([]byte, 64), numBuf: make([]byte, 64), @@ -28,19 +33,16 @@ func NewWriter(wr io.Writer) *Writer { } func (w *Writer) WriteArgs(args []interface{}) error { - err := w.wr.WriteByte(ArrayReply) - if err != nil { + if err := w.WriteByte(ArrayReply); err != nil { return err } - err = w.writeLen(len(args)) - if err != nil { + if err := w.writeLen(len(args)); err != nil { return err } for _, arg := range args { - err := w.writeArg(arg) - if err != nil { + if err := w.WriteArg(arg); err != nil { return err } } @@ -51,11 +53,11 @@ func (w *Writer) WriteArgs(args []interface{}) error { func (w *Writer) writeLen(n int) error { w.lenBuf = strconv.AppendUint(w.lenBuf[:0], uint64(n), 10) w.lenBuf = append(w.lenBuf, '\r', '\n') - _, err := w.wr.Write(w.lenBuf) + _, err := w.Write(w.lenBuf) return err } -func (w *Writer) writeArg(v interface{}) error { +func (w *Writer) WriteArg(v interface{}) error { switch v := v.(type) { case nil: return w.string("") @@ -108,18 +110,15 @@ func (w *Writer) writeArg(v interface{}) error { } func (w *Writer) bytes(b []byte) error { - err := w.wr.WriteByte(StringReply) - if err != nil { + if err := w.WriteByte(StringReply); err != nil { return err } - err = w.writeLen(len(b)) - if err != nil { + if err := w.writeLen(len(b)); err != nil { return err } - _, err = w.wr.Write(b) - if err != nil { + if _, err := w.Write(b); err != nil { return err } @@ -146,21 +145,8 @@ func (w *Writer) float(f float64) error { } func (w *Writer) crlf() error { - err := w.wr.WriteByte('\r') - if err != nil { + if err := w.WriteByte('\r'); err != nil { return err } - return w.wr.WriteByte('\n') -} - -func (w *Writer) Buffered() int { - return w.wr.Buffered() -} - -func (w *Writer) Reset(wr io.Writer) { - w.wr.Reset(wr) -} - -func (w *Writer) Flush() error { - return w.wr.Flush() + return w.WriteByte('\n') } diff --git a/internal/proto/write_buffer_test.go b/internal/proto/writer_test.go similarity index 83% rename from internal/proto/write_buffer_test.go rename to internal/proto/writer_test.go index abea7cd..c5df9a6 100644 --- a/internal/proto/write_buffer_test.go +++ b/internal/proto/writer_test.go @@ -3,7 +3,6 @@ package proto_test import ( "bytes" "encoding" - "io/ioutil" "testing" "time" @@ -41,9 +40,6 @@ var _ = Describe("WriteBuffer", func() { }) Expect(err).NotTo(HaveOccurred()) - err = wr.Flush() - Expect(err).NotTo(HaveOccurred()) - Expect(buf.Bytes()).To(Equal([]byte("*6\r\n" + "$6\r\nstring\r\n" + "$2\r\n12\r\n" + @@ -59,9 +55,6 @@ var _ = Describe("WriteBuffer", func() { err := wr.WriteArgs([]interface{}{tm}) Expect(err).NotTo(HaveOccurred()) - err = wr.Flush() - Expect(err).NotTo(HaveOccurred()) - Expect(buf.Len()).To(Equal(41)) }) @@ -69,26 +62,32 @@ var _ = Describe("WriteBuffer", func() { err := wr.WriteArgs([]interface{}{&MyType{}}) Expect(err).NotTo(HaveOccurred()) - err = wr.Flush() - Expect(err).NotTo(HaveOccurred()) - Expect(buf.Len()).To(Equal(15)) }) }) +type discard struct{} + +func (discard) Write(b []byte) (int, error) { + return len(b), nil +} + +func (discard) WriteString(s string) (int, error) { + return len(s), nil +} + +func (discard) WriteByte(c byte) error { + return nil +} + func BenchmarkWriteBuffer_Append(b *testing.B) { - buf := proto.NewWriter(ioutil.Discard) + buf := proto.NewWriter(discard{}) args := []interface{}{"hello", "world", "foo", "bar"} for i := 0; i < b.N; i++ { err := buf.WriteArgs(args) if err != nil { - panic(err) - } - - err = buf.Flush() - if err != nil { - panic(err) + b.Fatal(err) } } } diff --git a/internal/safe.go b/internal/safe.go new file mode 100644 index 0000000..862ff0e --- /dev/null +++ b/internal/safe.go @@ -0,0 +1,11 @@ +// +build appengine + +package internal + +func String(b []byte) string { + return string(b) +} + +func Bytes(s string) []byte { + return []byte(s) +} diff --git a/internal/unsafe.go b/internal/unsafe.go new file mode 100644 index 0000000..4bc7970 --- /dev/null +++ b/internal/unsafe.go @@ -0,0 +1,20 @@ +// +build !appengine + +package internal + +import "unsafe" + +// String converts byte slice to string. +func String(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// Bytes converts string to byte slice. +func Bytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/internal/util.go b/internal/util.go index f2ae9df..e8cda05 100644 --- a/internal/util.go +++ b/internal/util.go @@ -2,7 +2,10 @@ package internal import ( "context" + "fmt" + "strconv" "time" + "unicode/utf8" "github.com/go-redis/redis/v8/internal/util" "go.opentelemetry.io/otel/api/global" @@ -69,3 +72,74 @@ func WithSpan(ctx context.Context, name string, fn func(context.Context) error) return fn(ctx) } + +func AppendArg(b []byte, v interface{}) []byte { + switch v := v.(type) { + case nil: + return append(b, ""...) + case string: + return appendUTF8String(b, v) + case []byte: + return appendUTF8String(b, String(v)) + case int: + return strconv.AppendInt(b, int64(v), 10) + case int8: + return strconv.AppendInt(b, int64(v), 10) + case int16: + return strconv.AppendInt(b, int64(v), 10) + case int32: + return strconv.AppendInt(b, int64(v), 10) + case int64: + return strconv.AppendInt(b, v, 10) + case uint: + return strconv.AppendUint(b, uint64(v), 10) + case uint8: + return strconv.AppendUint(b, uint64(v), 10) + case uint16: + return strconv.AppendUint(b, uint64(v), 10) + case uint32: + return strconv.AppendUint(b, uint64(v), 10) + case uint64: + return strconv.AppendUint(b, v, 10) + case float32: + return strconv.AppendFloat(b, float64(v), 'f', -1, 64) + case float64: + return strconv.AppendFloat(b, v, 'f', -1, 64) + case bool: + if v { + return append(b, "true"...) + } + return append(b, "false"...) + case time.Time: + return v.AppendFormat(b, time.RFC3339Nano) + default: + return append(b, fmt.Sprint(v)...) + } +} + +func appendUTF8String(b []byte, s string) []byte { + for _, r := range s { + b = appendRune(b, r) + } + return b +} + +func appendRune(b []byte, r rune) []byte { + if r < utf8.RuneSelf { + switch c := byte(r); c { + case '\n': + return append(b, "\\n"...) + case '\r': + return append(b, "\\r"...) + default: + return append(b, c) + } + } + + l := len(b) + b = append(b, make([]byte, utf8.UTFMax)...) + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + + return b +} diff --git a/redisext/otel.go b/redisext/otel.go new file mode 100644 index 0000000..7e53f0b --- /dev/null +++ b/redisext/otel.go @@ -0,0 +1,112 @@ +package redisext + +import ( + "context" + "strings" + + "github.com/go-redis/redis/v8" + "github.com/go-redis/redis/v8/internal" + "go.opentelemetry.io/otel/api/global" + "go.opentelemetry.io/otel/api/kv" + "go.opentelemetry.io/otel/api/trace" +) + +type OpenTelemetryHook struct{} + +var _ redis.Hook = OpenTelemetryHook{} + +func (OpenTelemetryHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + if !trace.SpanFromContext(ctx).IsRecording() { + return ctx, nil + } + + b := make([]byte, 32) + b = appendCmd(b, cmd) + + tracer := global.Tracer("github.com/go-redis/redis") + ctx, span := tracer.Start(ctx, cmd.FullName()) + span.SetAttributes( + kv.String("db.system", "redis"), + kv.String("redis.cmd", internal.String(b)), + ) + + return ctx, nil +} + +func (OpenTelemetryHook) AfterProcess(ctx context.Context, cmd redis.Cmder) error { + trace.SpanFromContext(ctx).End() + return nil +} + +func (OpenTelemetryHook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + if !trace.SpanFromContext(ctx).IsRecording() { + return ctx, nil + } + + const numCmdLimit = 100 + const numNameLimit = 10 + + seen := make(map[string]struct{}, len(cmds)) + unqNames := make([]string, 0, len(cmds)) + + b := make([]byte, 0, 32*len(cmds)) + + for i, cmd := range cmds { + if i > numCmdLimit { + break + } + + if i > 0 { + b = append(b, '\n') + } + b = appendCmd(b, cmd) + + if len(unqNames) >= numNameLimit { + continue + } + + name := cmd.FullName() + if _, ok := seen[name]; !ok { + seen[name] = struct{}{} + unqNames = append(unqNames, name) + } + } + + tracer := global.Tracer("github.com/go-redis/redis") + ctx, span := tracer.Start(ctx, "pipeline "+strings.Join(unqNames, " ")) + span.SetAttributes( + kv.String("db.system", "redis"), + kv.Int("redis.num_cmd", len(cmds)), + kv.String("redis.cmds", internal.String(b)), + ) + + return ctx, nil +} + +func (OpenTelemetryHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error { + trace.SpanFromContext(ctx).End() + return nil +} + +func appendCmd(b []byte, cmd redis.Cmder) []byte { + const lenLimit = 64 + + for i, arg := range cmd.Args() { + if i > 0 { + b = append(b, ' ') + } + + start := len(b) + b = internal.AppendArg(b, arg) + if len(b)-start > lenLimit { + b = append(b[:start+lenLimit], "..."...) + } + } + + if err := cmd.Err(); err != nil { + b = append(b, ": "...) + b = append(b, err.Error()...) + } + + return b +}