diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c18618..302f2ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,70 +1,34 @@ -# [9.0.0-beta.2](https://github.com/go-redis/redis/compare/v9.0.0-beta.1...v9.0.0-beta.2) (2022-07-28) - - -### Bug Fixes - -* [#2114](https://github.com/go-redis/redis/issues/2114) for redis-server not support Hello ([b6d2a92](https://github.com/go-redis/redis/commit/b6d2a925297e3e516eb5c76c114c1c9fcd5b68c5)) -* additional node failures in clustered pipelined reads ([03376a5](https://github.com/go-redis/redis/commit/03376a5d9c7dfd7197b14ce13b24a0431a07a663)) -* disregard failed pings in updateLatency() for cluster nodes ([64f972f](https://github.com/go-redis/redis/commit/64f972fbeae401e52a2c066a0e1c922af617e15c)) -* don't panic when test cannot start ([9e16c79](https://github.com/go-redis/redis/commit/9e16c79951e7769621b7320f1ecdf04baf539b82)) -* handle panic in ringShards Hash function when Ring got closed ([a80b84f](https://github.com/go-redis/redis/commit/a80b84f01f9fc0d3e6f08445ba21f7e07880775e)), closes [#2126](https://github.com/go-redis/redis/issues/2126) -* ignore Nil error when reading EntriesRead ([89d6dfe](https://github.com/go-redis/redis/commit/89d6dfe09a88321d445858c1c5b24d2757b95a3e)) -* log errors from cmdsInfoCache ([fa4d1ea](https://github.com/go-redis/redis/commit/fa4d1ea8398cd729ad5cbaaff88e4b8805393945)) -* provide a signal channel to end heartbeat goroutine ([f032c12](https://github.com/go-redis/redis/commit/f032c126db3e2c1a239ce1790b0ab81994df75cf)) -* remove conn reaper from the pool and uptrace option names ([f6a8adc](https://github.com/go-redis/redis/commit/f6a8adc50cdaec30527f50d06468f9176ee674fe)) -* replace heartbeat signal channel with context.WithCancel ([20d0ca2](https://github.com/go-redis/redis/commit/20d0ca235efff48ad48cc05b98790b825d4ba979)) - - - -# [9.0.0-beta.1](https://github.com/go-redis/redis/compare/v8.11.5...v9.0.0-beta.1) (2022-06-04) - -### Bug Fixes - -- **#1943:** xInfoConsumer.Idle should be time.Duration instead of int64 - ([#2052](https://github.com/go-redis/redis/issues/2052)) - ([997ab5e](https://github.com/go-redis/redis/commit/997ab5e7e3ddf53837917013a4babbded73e944f)), - closes [#1943](https://github.com/go-redis/redis/issues/1943) -- add XInfoConsumers test - ([6f1a1ac](https://github.com/go-redis/redis/commit/6f1a1ac284ea3f683eeb3b06a59969e8424b6376)) -- fix tests - ([3a722be](https://github.com/go-redis/redis/commit/3a722be81180e4d2a9cf0a29dc9a1ee1421f5859)) -- remove test(XInfoConsumer.idle), not a stable return value when tested. - ([f5fbb36](https://github.com/go-redis/redis/commit/f5fbb367e7d9dfd7f391fc535a7387002232fa8a)) -- update ChannelWithSubscriptions to accept options - ([c98c5f0](https://github.com/go-redis/redis/commit/c98c5f0eebf8d254307183c2ce702a48256b718d)) -- update COMMAND parser for Redis 7 - ([b0bb514](https://github.com/go-redis/redis/commit/b0bb514059249e01ed7328c9094e5b8a439dfb12)) -- use redis over ssh channel([#2057](https://github.com/go-redis/redis/issues/2057)) - ([#2060](https://github.com/go-redis/redis/issues/2060)) - ([3961b95](https://github.com/go-redis/redis/commit/3961b9577f622a3079fe74f8fc8da12ba67a77ff)) - -### Features - -- add ClientUnpause - ([91171f5](https://github.com/go-redis/redis/commit/91171f5e19a261dc4cfbf8706626d461b6ba03e4)) -- add NewXPendingResult for unit testing XPending - ([#2066](https://github.com/go-redis/redis/issues/2066)) - ([b7fd09e](https://github.com/go-redis/redis/commit/b7fd09e59479bc6ed5b3b13c4645a3620fd448a3)) -- add WriteArg and Scan net.IP([#2062](https://github.com/go-redis/redis/issues/2062)) - ([7d5167e](https://github.com/go-redis/redis/commit/7d5167e8624ac1515e146ed183becb97dadb3d1a)) -- **pool:** add check for badConnection - ([a8a7665](https://github.com/go-redis/redis/commit/a8a7665ddf8cc657c5226b1826a8ee83dab4b8c1)), - closes [#2053](https://github.com/go-redis/redis/issues/2053) -- provide a username and password callback method, so that the plaintext username and password will - not be stored in the memory, and the username and password will only be generated once when the - CredentialsProvider is called. After the method is executed, the username and password strings on - the stack will be released. ([#2097](https://github.com/go-redis/redis/issues/2097)) - ([56a3dbc](https://github.com/go-redis/redis/commit/56a3dbc7b656525eb88e0735e239d56e04a23bee)) -- upgrade to Redis 7 - ([d09c27e](https://github.com/go-redis/redis/commit/d09c27e6046129fd27b1d275e5a13a477bd7f778)) - ## v9 UNRELEASED +### Added + - Added support for [RESP3](https://github.com/antirez/RESP3/blob/master/spec.md) protocol. -- Removed `Pipeline.Close` since there is no real need to explicitly manage pipeline resources. - `Pipeline.Discard` is still available if you want to reset commands for some reason. -- Replaced `*redis.Z` with `redis.Z` since it is small enough to be passed as value. -- Renamed `MaxConnAge` to `ConnMaxLifetime`. -- Renamed `IdleTimeout` to `ConnMaxIdleTime`. + Contributed by @monkey92t who has done a lot of work recently. +- Added `ContextTimeoutEnabled` option that controls whether the client respects context timeouts + and deadlines. See + [Redis Timeouts](https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts) for details. +- Added `ParseClusterURL` to parse URLs into `ClusterOptions`, for example, + `redis://user:password@localhost:6789?dial_timeout=3&read_timeout=6s&addr=localhost:6790&addr=localhost:6791`. +- Added metrics instrumentation using `redisotel.IstrumentMetrics`. See + [documentation](https://redis.uptrace.dev/guide/go-redis-monitoring.html) + +### Changed + +- Reworked hook interface and added `DialHook`. +- Replaced `redisotel.NewTracingHook` with `redisotel.InstrumentTracing`. See + [example](example/otel) and + [documentation](https://redis.uptrace.dev/guide/go-redis-monitoring.html). +- Replaced `*redis.Z` with `redis.Z` since it is small enough to be passed as value without making + an allocation. +- Renamed the option `MaxConnAge` to `ConnMaxLifetime`. +- Renamed the option `IdleTimeout` to `ConnMaxIdleTime`. - Removed connection reaper in favor of `MaxIdleConns`. -- Removed `WithContext`. +- Removed `WithContext` since `context.Context` can be passed directly as an arg. +- Removed `Pipeline.Close` since there is no real need to explicitly manage pipeline resources and + it can be safely reused via `sync.Pool` etc. `Pipeline.Discard` is still available if you want to + reset commands for some reason. + +### Fixed + +- Improved and fixed pipeline retries. +- As usual, added more commands and fixed some bugs. diff --git a/cluster.go b/cluster.go index c81bad5..88a0b0e 100644 --- a/cluster.go +++ b/cluster.go @@ -71,11 +71,8 @@ type ClusterOptions struct { WriteTimeout time.Duration ContextTimeoutEnabled bool - // PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO). - PoolFIFO bool - - // PoolSize applies per cluster node and not for the whole cluster. - PoolSize int + PoolFIFO bool + PoolSize int // applies per cluster node and not for the whole cluster PoolTimeout time.Duration MinIdleConns int MaxIdleConns int @@ -391,6 +388,7 @@ type clusterNodes struct { nodes map[string]*clusterNode activeAddrs []string closed bool + onNewNode []func(rdb *Client) _generation uint32 // atomic } @@ -426,6 +424,12 @@ func (c *clusterNodes) Close() error { return firstErr } +func (c *clusterNodes) OnNewNode(fn func(rdb *Client)) { + c.mu.Lock() + c.onNewNode = append(c.onNewNode, fn) + c.mu.Unlock() +} + func (c *clusterNodes) Addrs() ([]string, error) { var addrs []string @@ -503,6 +507,9 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { } node = newClusterNode(c.opt, addr) + for _, fn := range c.onNewNode { + fn(node.Client) + } c.addrs = appendIfNotExists(c.addrs, addr) c.nodes[addr] = node @@ -812,18 +819,14 @@ func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, er //------------------------------------------------------------------------------ -type clusterClient struct { - opt *ClusterOptions - nodes *clusterNodes - state *clusterStateHolder //nolint:structcheck - cmdsInfoCache *cmdsInfoCache //nolint:structcheck -} - // ClusterClient is a Redis Cluster client representing a pool of zero // or more underlying connections. It's safe for concurrent use by // multiple goroutines. type ClusterClient struct { - *clusterClient + opt *ClusterOptions + nodes *clusterNodes + state *clusterStateHolder + cmdsInfoCache *cmdsInfoCache cmdable hooks } @@ -834,15 +837,18 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { opt.init() c := &ClusterClient{ - clusterClient: &clusterClient{ - opt: opt, - nodes: newClusterNodes(opt), - }, + opt: opt, + nodes: newClusterNodes(opt), } + c.state = newClusterStateHolder(c.loadState) c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.cmdable = c.Process + c.hooks.process = c.process + c.hooks.processPipeline = c._processPipeline + c.hooks.processTxPipeline = c._processTxPipeline + return c } @@ -873,13 +879,14 @@ func (c *ClusterClient) Do(ctx context.Context, args ...interface{}) *Cmd { } func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error { - return c.hooks.process(ctx, cmd, c.process) + err := c.hooks.process(ctx, cmd) + cmd.SetErr(err) + return err } func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { cmdInfo := c.cmdInfo(ctx, cmd.Name()) slot := c.cmdSlot(ctx, cmd) - var node *clusterNode var ask bool var lastErr error @@ -899,11 +906,12 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { } if ask { + ask = false + pipe := node.Client.Pipeline() _ = pipe.Process(ctx, NewCmd(ctx, "asking")) _ = pipe.Process(ctx, cmd) _, lastErr = pipe.Exec(ctx) - ask = false } else { lastErr = node.Client.Process(ctx, cmd) } @@ -958,6 +966,10 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { return lastErr } +func (c *ClusterClient) OnNewNode(fn func(rdb *Client)) { + c.nodes.OnNewNode(fn) +} + // ForEachMaster concurrently calls the fn on each master node in the cluster. // It returns the first error if any. func (c *ClusterClient) ForEachMaster( @@ -1165,7 +1177,7 @@ func (c *ClusterClient) loadState(ctx context.Context) (*clusterState, error) { func (c *ClusterClient) Pipeline() Pipeliner { pipe := Pipeline{ - exec: c.processPipeline, + exec: pipelineExecer(c.hooks.processPipeline), } pipe.init() return &pipe @@ -1175,10 +1187,6 @@ func (c *ClusterClient) Pipelined(ctx context.Context, fn func(Pipeliner) error) return c.Pipeline().Pipelined(ctx, fn) } -func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c._processPipeline) -} - func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error { cmdsMap := newCmdsMap() @@ -1258,7 +1266,7 @@ func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool func (c *ClusterClient) _processPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { - _ = node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + _ = node.Client.hooks.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) @@ -1344,7 +1352,10 @@ func (c *ClusterClient) checkMovedErr( // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. func (c *ClusterClient) TxPipeline() Pipeliner { pipe := Pipeline{ - exec: c.processTxPipeline, + exec: func(ctx context.Context, cmds []Cmder) error { + cmds = wrapMultiExec(ctx, cmds) + return c.hooks.processTxPipeline(ctx, cmds) + }, } pipe.init() return &pipe @@ -1354,10 +1365,6 @@ func (c *ClusterClient) TxPipelined(ctx context.Context, fn func(Pipeliner) erro return c.TxPipeline().Pipelined(ctx, fn) } -func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processTxPipeline(ctx, cmds, c._processTxPipeline) -} - func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error { // Trim multi .. exec. cmds = cmds[1 : len(cmds)-1] @@ -1419,38 +1426,38 @@ func (c *ClusterClient) mapCmdsBySlot(ctx context.Context, cmds []Cmder) map[int func (c *ClusterClient) _processTxPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { - _ = node.Client.hooks.processTxPipeline( - ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmds(wr, cmds) - }); err != nil { + cmds = wrapMultiExec(ctx, cmds) + _ = node.Client.hooks.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmds(wr, cmds) + }); err != nil { + setCmdsErr(cmds, err) + return err + } + + return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + statusCmd := cmds[0].(*StatusCmd) + // Trim multi and exec. + trimmedCmds := cmds[1 : len(cmds)-1] + + if err := c.txPipelineReadQueued( + ctx, rd, statusCmd, trimmedCmds, failedCmds, + ); err != nil { setCmdsErr(cmds, err) + + moved, ask, addr := isMovedError(err) + if moved || ask { + return c.cmdsMoved(ctx, trimmedCmds, moved, ask, addr, failedCmds) + } + return err } - return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { - statusCmd := cmds[0].(*StatusCmd) - // Trim multi and exec. - trimmedCmds := cmds[1 : len(cmds)-1] - - if err := c.txPipelineReadQueued( - ctx, rd, statusCmd, trimmedCmds, failedCmds, - ); err != nil { - setCmdsErr(cmds, err) - - moved, ask, addr := isMovedError(err) - if moved || ask { - return c.cmdsMoved(ctx, trimmedCmds, moved, ask, addr, failedCmds) - } - - return err - } - - return pipelineReadCmds(rd, trimmedCmds) - }) + return pipelineReadCmds(rd, trimmedCmds) }) }) + }) } func (c *ClusterClient) txPipelineReadQueued( @@ -1742,7 +1749,7 @@ func (c *ClusterClient) cmdNode( return state.slotMasterNode(slot) } -func (c *clusterClient) slotReadOnlyNode(state *clusterState, slot int) (*clusterNode, error) { +func (c *ClusterClient) slotReadOnlyNode(state *clusterState, slot int) (*clusterNode, error) { if c.opt.RouteByLatency { return state.slotClosestNode(slot) } diff --git a/cluster_commands.go b/cluster_commands.go index 085bce8..fc0a9cd 100644 --- a/cluster_commands.go +++ b/cluster_commands.go @@ -8,7 +8,7 @@ import ( func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd { cmd := NewIntCmd(ctx, "dbsize") - _ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { + _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { var size int64 err := c.ForEachMaster(ctx, func(ctx context.Context, master *Client) error { n, err := master.DBSize(ctx).Result() @@ -30,7 +30,7 @@ func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd { func (c *ClusterClient) ScriptLoad(ctx context.Context, script string) *StringCmd { cmd := NewStringCmd(ctx, "script", "load", script) - _ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { + _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { mu := &sync.Mutex{} err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { val, err := shard.ScriptLoad(ctx, script).Result() @@ -56,7 +56,7 @@ func (c *ClusterClient) ScriptLoad(ctx context.Context, script string) *StringCm func (c *ClusterClient) ScriptFlush(ctx context.Context) *StatusCmd { cmd := NewStatusCmd(ctx, "script", "flush") - _ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { + _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { return shard.ScriptFlush(ctx).Err() }) @@ -82,8 +82,8 @@ func (c *ClusterClient) ScriptExists(ctx context.Context, hashes ...string) *Boo result[i] = true } - _ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { - mu := &sync.Mutex{} + _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { + var mu sync.Mutex err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { val, err := shard.ScriptExists(ctx, hashes...).Result() if err != nil { diff --git a/cluster_test.go b/cluster_test.go index 92844eb..15c6a94 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -755,6 +755,9 @@ var _ = Describe("ClusterClient", func() { }) It("supports Process hook", func() { + testCtx, cancel := context.WithCancel(ctx) + defer cancel() + err := client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) @@ -766,29 +769,47 @@ var _ = Describe("ClusterClient", func() { var stack []string clusterHook := &hook{ - beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { - Expect(cmd.String()).To(Equal("ping: ")) - stack = append(stack, "cluster.BeforeProcess") - return ctx, nil - }, - afterProcess: func(ctx context.Context, cmd redis.Cmder) error { - Expect(cmd.String()).To(Equal("ping: PONG")) - stack = append(stack, "cluster.AfterProcess") - return nil + processHook: func(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + select { + case <-testCtx.Done(): + return hook(ctx, cmd) + default: + } + + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcess") + + err := hook(ctx, cmd) + + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcess") + + return err + } }, } client.AddHook(clusterHook) nodeHook := &hook{ - beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { - Expect(cmd.String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcess") - return ctx, nil - }, - afterProcess: func(ctx context.Context, cmd redis.Cmder) error { - Expect(cmd.String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcess") - return nil + processHook: func(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + select { + case <-testCtx.Done(): + return hook(ctx, cmd) + default: + } + + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcess") + + err := hook(ctx, cmd) + + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcess") + + return err + } }, } @@ -805,11 +826,6 @@ var _ = Describe("ClusterClient", func() { "shard.AfterProcess", "cluster.AfterProcess", })) - - clusterHook.beforeProcess = nil - clusterHook.afterProcess = nil - nodeHook.beforeProcess = nil - nodeHook.afterProcess = nil }) It("supports Pipeline hook", func() { @@ -824,33 +840,39 @@ var _ = Describe("ClusterClient", func() { var stack []string client.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "cluster.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = append(stack, "cluster.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcessPipeline") + + return err + } }, }) _ = client.ForEachShard(ctx, func(ctx context.Context, node *redis.Client) error { node.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + + return err + } }, }) return nil @@ -881,33 +903,39 @@ var _ = Describe("ClusterClient", func() { var stack []string client.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: ")) - stack = append(stack, "cluster.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: PONG")) - stack = append(stack, "cluster.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcessPipeline") + + return err + } }, }) _ = client.ForEachShard(ctx, func(ctx context.Context, node *redis.Client) error { node.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + + return err + } }, }) return nil diff --git a/example/otel/README.md b/example/otel/README.md index e0451ec..4c68009 100644 --- a/example/otel/README.md +++ b/example/otel/README.md @@ -1,7 +1,7 @@ # Example for go-redis OpenTelemetry instrumentation See -[Redis Monitoring Performance and Errors](https://redis.uptrace.dev/guide/redis-performance-monitoring.html) +[Monitoring Go Redis Performance and Errors](https://redis.uptrace.dev/guide/go-redis-monitoring.html) for details. This example requires Redis Server on port `:6379`. You can start Redis Server using Docker: @@ -18,19 +18,23 @@ You can run this example with different OpenTelemetry exporters by providing env go run . ``` +[Uptrace](https://github.com/uptrace/uptrace) exporter: + +```shell +UPTRACE_DSN="https://@uptrace.dev/" go run . +``` + **Jaeger** exporter: ```shell OTEL_EXPORTER_JAEGER_ENDPOINT=http://localhost:14268/api/traces go run . ``` -**Uptrace** exporter: - -```shell -UPTRACE_DSN="https://@uptrace.dev/" go run . -``` +To instrument Redis Cluster client, see +[go-redis-cluster](https://github.com/uptrace/opentelemetry-go-extra/tree/main/example/go-redis-cluster) +example. ## Links -- [Find instrumentations](https://opentelemetry.uptrace.dev/instrumentations/?lang=go) -- [OpenTelemetry Tracing API](https://opentelemetry.uptrace.dev/guide/go-tracing.html) +- [OpenTelemetry Go instrumentations](https://uptrace.dev/opentelemetry/instrumentations/?lang=go) +- [OpenTelemetry Go Tracing API](https://uptrace.dev/opentelemetry/go-tracing.html) diff --git a/example/otel/go.mod b/example/otel/go.mod index 9d7af37..1499bfb 100644 --- a/example/otel/go.mod +++ b/example/otel/go.mod @@ -12,7 +12,7 @@ require ( github.com/go-redis/redis/extra/redisotel/v9 v9.0.0-beta.2 github.com/go-redis/redis/v9 v9.0.0-beta.2 github.com/uptrace/opentelemetry-go-extra/otelplay v0.1.15 - go.opentelemetry.io/otel v1.8.0 + go.opentelemetry.io/otel v1.10.0 golang.org/x/net v0.0.0-20220728030405-41545e8bf201 // indirect google.golang.org/genproto v0.0.0-20220725144611-272f38e5d71b // indirect ) diff --git a/example/otel/go.sum b/example/otel/go.sum index 1466ca3..e70b848 100644 --- a/example/otel/go.sum +++ b/example/otel/go.sum @@ -218,14 +218,17 @@ github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vv github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= -github.com/onsi/ginkgo/v2 v2.1.4 h1:GNapqRSid3zijZ9H77KrgVG4/8KqiyRsxcSxe+7ApXY= github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU= +github.com/onsi/ginkgo/v2 v2.1.6 h1:Fx2POJZfKRQcM1pH49qSZiYeu319wji004qX+GDovrU= +github.com/onsi/ginkgo/v2 v2.1.6/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro= -github.com/onsi/gomega v1.20.0 h1:8W0cWlwFkflGPLltQvLRB7ZVD5HuP6ng320w2IS245Q= github.com/onsi/gomega v1.20.0/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo= +github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo= +github.com/onsi/gomega v1.21.1 h1:OB/euWYIExnPBohllTicTHmGTrMaqJ67nIu80j0/uEM= +github.com/onsi/gomega v1.21.1/go.mod h1:iYAIXgPSaDHak0LCMA+AWBpIKBr8WZicMxnE8luStNc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -252,6 +255,7 @@ github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= @@ -261,9 +265,10 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opentelemetry.io/contrib/instrumentation/runtime v0.33.0 h1:YhjJQGhEoxXXKwH16MQEW37KgdZSACk+HyHnFtv/NcI= go.opentelemetry.io/contrib/instrumentation/runtime v0.33.0/go.mod h1:cu2qiP1YaeuJtMVDUuMJZPfCBqJr4GB9kkwXbg5knMA= -go.opentelemetry.io/otel v1.4.1/go.mod h1:StM6F/0fSwpd8dKWDCdRr7uRvEPYdW0hBSlbdTiUde4= -go.opentelemetry.io/otel v1.8.0 h1:zcvBFizPbpa1q7FehvFiHbQwGzmPILebO0tyqIR5Djg= go.opentelemetry.io/otel v1.8.0/go.mod h1:2pkj+iMj0o03Y+cW6/m8Y4WkRdYN3AvCXCnzRMp9yvM= +go.opentelemetry.io/otel v1.9.0/go.mod h1:np4EoPGzoPs3O67xUVNoPPcmSvsfOxNlNA4F4AC+0Eo= +go.opentelemetry.io/otel v1.10.0 h1:Y7DTJMR6zs1xkS/upamJYk0SxxN4C9AqRd77jmZnyY4= +go.opentelemetry.io/otel v1.10.0/go.mod h1:NbvWjCthWHKBEUMpf0/v8ZRZlni86PpGFEMA9pnQSnQ= go.opentelemetry.io/otel/exporters/jaeger v1.8.0 h1:TLLqD6kDhLPziEC7pgPrMvP9lAqdk3n1gf8DiFSnfW8= go.opentelemetry.io/otel/exporters/jaeger v1.8.0/go.mod h1:GbWg+ng88rDtx+id26C34QLqw2erqJeAjsCx9AFeHfE= go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.8.0 h1:ao8CJIShCaIbaMsGxy+jp2YHSudketpDgDRcbirov78= @@ -278,16 +283,18 @@ go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.8.0 h1:00hCSG go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.8.0/go.mod h1:twhIvtDQW2sWP1O2cT1N8nkSBgKCRZv2z6COTTBrf8Q= go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.8.0 h1:FVy7BZCjoA2Nk+fHqIdoTmm554J9wTX+YcrDp+mc368= go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.8.0/go.mod h1:ztncjvKpotSUQq7rlgPibGt8kZfSI3/jI8EO7JjuY2c= -go.opentelemetry.io/otel/metric v0.31.0 h1:6SiklT+gfWAwWUR0meEMxQBtihpiEs4c+vL9spDTqUs= go.opentelemetry.io/otel/metric v0.31.0/go.mod h1:ohmwj9KTSIeBnDBm/ZwH2PSZxZzoOaG2xZeekTRzL5A= -go.opentelemetry.io/otel/sdk v1.4.1/go.mod h1:NBwHDgDIBYjwK2WNu1OPgsIc2IJzmBXNnvIJxJc8BpE= -go.opentelemetry.io/otel/sdk v1.8.0 h1:xwu69/fNuwbSHWe/0PGS888RmjWY181OmcXDQKu7ZQk= +go.opentelemetry.io/otel/metric v0.32.1 h1:ftff5LSBCIDwL0UkhBuDg8j9NNxx2IusvJ18q9h6RC4= +go.opentelemetry.io/otel/metric v0.32.1/go.mod h1:iLPP7FaKMAD5BIxJ2VX7f2KTuz//0QK2hEUyti5psqQ= go.opentelemetry.io/otel/sdk v1.8.0/go.mod h1:uPSfc+yfDH2StDM/Rm35WE8gXSNdvCg023J6HeGNO0c= +go.opentelemetry.io/otel/sdk v1.9.0 h1:LNXp1vrr83fNXTHgU8eO89mhzxb/bbWAsHG6fNf3qWo= +go.opentelemetry.io/otel/sdk v1.9.0/go.mod h1:AEZc8nt5bd2F7BC24J5R0mrjYnpEgYHyTcM/vrSple4= go.opentelemetry.io/otel/sdk/metric v0.31.0 h1:2sZx4R43ZMhJdteKAlKoHvRgrMp53V1aRxvEf5lCq8Q= go.opentelemetry.io/otel/sdk/metric v0.31.0/go.mod h1:fl0SmNnX9mN9xgU6OLYLMBMrNAsaZQi7qBwprwO3abk= -go.opentelemetry.io/otel/trace v1.4.1/go.mod h1:iYEVbroFCNut9QkwEczV9vMRPHNKSSwYZjulEtsmhFc= -go.opentelemetry.io/otel/trace v1.8.0 h1:cSy0DF9eGI5WIfNwZ1q2iUyGj00tGzP24dE1lOlHrfY= go.opentelemetry.io/otel/trace v1.8.0/go.mod h1:0Bt3PXY8w+3pheS3hQUt+wow8b1ojPaTBoTCh2zIFI4= +go.opentelemetry.io/otel/trace v1.9.0/go.mod h1:2737Q0MuG8q1uILYm2YYVkAyLtOofiTNGg6VODnOiPo= +go.opentelemetry.io/otel/trace v1.10.0 h1:npQMbR8o7mum8uF95yFbOEJffhs1sbCOfDh8zAJiH5E= +go.opentelemetry.io/otel/trace v1.10.0/go.mod h1:Sij3YYczqAdz+EhmGhE6TpTxUO5/F/AzrK+kxfGqySM= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.18.0 h1:W5hyXNComRa23tGpKwG+FRAc4rfF6ZUg1JReK+QHS80= go.opentelemetry.io/proto/otlp v0.18.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U= @@ -335,8 +342,9 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -419,6 +427,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -556,15 +565,15 @@ golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= -golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f h1:uF6paiQQebLeSXkrTqHqz0MXhXXS1KgF41eUdBNvxK0= golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= diff --git a/example/otel/main.go b/example/otel/main.go index 75033eb..15de7aa 100644 --- a/example/otel/main.go +++ b/example/otel/main.go @@ -8,7 +8,6 @@ import ( "github.com/uptrace/opentelemetry-go-extra/otelplay" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/codes" - semconv "go.opentelemetry.io/otel/semconv/v1.7.0" "github.com/go-redis/redis/extra/redisotel/v9" "github.com/go-redis/redis/v9" @@ -25,7 +24,9 @@ func main() { rdb := redis.NewClient(&redis.Options{ Addr: ":6379", }) - rdb.AddHook(redisotel.NewTracingHook(redisotel.WithAttributes(semconv.NetPeerNameKey.String("localhost"), semconv.NetPeerPortKey.String("6379")))) + if err := redisotel.InstrumentTracing(rdb); err != nil { + panic(err) + } ctx, span := tracer.Start(ctx, "handleRequest") defer span.End() diff --git a/example_instrumentation_test.go b/example_instrumentation_test.go index 639f9f0..9e8179a 100644 --- a/example_instrumentation_test.go +++ b/example_instrumentation_test.go @@ -3,6 +3,7 @@ package redis_test import ( "context" "fmt" + "net" "github.com/go-redis/redis/v9" ) @@ -11,24 +12,28 @@ type redisHook struct{} var _ redis.Hook = redisHook{} -func (redisHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { - fmt.Printf("starting processing: <%s>\n", cmd) - return ctx, nil +func (redisHook) DialHook(hook redis.DialHook) redis.DialHook { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + return hook(ctx, network, addr) + } } -func (redisHook) AfterProcess(ctx context.Context, cmd redis.Cmder) error { - fmt.Printf("finished processing: <%s>\n", cmd) - return nil +func (redisHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + fmt.Printf("starting processing: <%s>\n", cmd) + err := hook(ctx, cmd) + fmt.Printf("finished processing: <%s>\n", cmd) + return err + } } -func (redisHook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - fmt.Printf("pipeline starting processing: %v\n", cmds) - return ctx, nil -} - -func (redisHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error { - fmt.Printf("pipeline finished processing: %v\n", cmds) - return nil +func (redisHook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + fmt.Printf("pipeline starting processing: %v\n", cmds) + err := hook(ctx, cmds) + fmt.Printf("pipeline finished processing: %v\n", cmds) + return err + } } func Example_instrumentation() { diff --git a/example_test.go b/example_test.go index f633397..1974c55 100644 --- a/example_test.go +++ b/example_test.go @@ -433,7 +433,7 @@ func ExampleClient_TxPipeline() { } func ExampleClient_Watch() { - const maxRetries = 1000 + const maxRetries = 10000 // Increment transactionally increments key using GET and SET commands. increment := func(key string) error { diff --git a/extra/redisotel/config.go b/extra/redisotel/config.go new file mode 100644 index 0000000..7c29e0c --- /dev/null +++ b/extra/redisotel/config.go @@ -0,0 +1,128 @@ +package redisotel + +import ( + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/metric/global" + semconv "go.opentelemetry.io/otel/semconv/v1.4.0" + "go.opentelemetry.io/otel/trace" +) + +type config struct { + // Common options. + + attrs []attribute.KeyValue + + // Tracing options. + + tp trace.TracerProvider + tracer trace.Tracer + + dbStmtEnabled bool + + // Metrics options. + + mp metric.MeterProvider + meter metric.Meter + + poolName string +} + +type baseOption interface { + apply(conf *config) +} + +type Option interface { + baseOption + tracing() + metrics() +} + +type option func(conf *config) + +func (fn option) apply(conf *config) { + fn(conf) +} + +func (fn option) tracing() {} + +func (fn option) metrics() {} + +func newConfig(opts ...baseOption) *config { + conf := &config{ + tp: otel.GetTracerProvider(), + mp: global.MeterProvider(), + attrs: []attribute.KeyValue{ + semconv.DBSystemRedis, + }, + dbStmtEnabled: true, + } + for _, opt := range opts { + opt.apply(conf) + } + return conf +} + +// WithAttributes specifies additional attributes to be added to the span. +func WithAttributes(attrs ...attribute.KeyValue) Option { + return option(func(conf *config) { + conf.attrs = append(conf.attrs, attrs...) + }) +} + +//------------------------------------------------------------------------------ + +type TracingOption interface { + baseOption + tracing() +} + +type tracingOption func(conf *config) + +var _ TracingOption = (*tracingOption)(nil) + +func (fn tracingOption) apply(conf *config) { + fn(conf) +} + +func (fn tracingOption) tracing() {} + +// WithTracerProvider specifies a tracer provider to use for creating a tracer. +// If none is specified, the global provider is used. +func WithTracerProvider(provider trace.TracerProvider) TracingOption { + return tracingOption(func(conf *config) { + conf.tp = provider + }) +} + +// WithDBStatement tells the tracing hook not to log raw redis commands. +func WithDBStatement(on bool) TracingOption { + return tracingOption(func(conf *config) { + conf.dbStmtEnabled = on + }) +} + +//------------------------------------------------------------------------------ + +type MetricsOption interface { + baseOption + metrics() +} + +type metricsOption func(conf *config) + +var _ MetricsOption = (*metricsOption)(nil) + +func (fn metricsOption) apply(conf *config) { + fn(conf) +} + +func (fn metricsOption) metrics() {} + +// WithMeterProvider configures a metric.Meter used to create instruments. +func WithMeterProvider(mp metric.MeterProvider) MetricsOption { + return metricsOption(func(conf *config) { + conf.mp = mp + }) +} diff --git a/extra/redisotel/go.mod b/extra/redisotel/go.mod index c3cbdc9..78caa7c 100644 --- a/extra/redisotel/go.mod +++ b/extra/redisotel/go.mod @@ -9,7 +9,8 @@ replace github.com/go-redis/redis/extra/rediscmd/v9 => ../rediscmd require ( github.com/go-redis/redis/extra/rediscmd/v9 v9.0.0-beta.2 github.com/go-redis/redis/v9 v9.0.0-beta.2 - go.opentelemetry.io/otel v1.9.0 + go.opentelemetry.io/otel v1.10.0 + go.opentelemetry.io/otel/metric v0.32.1 go.opentelemetry.io/otel/sdk v1.9.0 - go.opentelemetry.io/otel/trace v1.9.0 + go.opentelemetry.io/otel/trace v1.10.0 ) diff --git a/extra/redisotel/go.sum b/extra/redisotel/go.sum index 6e7704c..68a1f25 100644 --- a/extra/redisotel/go.sum +++ b/extra/redisotel/go.sum @@ -67,12 +67,16 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.opentelemetry.io/otel v1.9.0 h1:8WZNQFIB2a71LnANS9JeyidJKKGOOremcUtb/OtHISw= go.opentelemetry.io/otel v1.9.0/go.mod h1:np4EoPGzoPs3O67xUVNoPPcmSvsfOxNlNA4F4AC+0Eo= +go.opentelemetry.io/otel v1.10.0 h1:Y7DTJMR6zs1xkS/upamJYk0SxxN4C9AqRd77jmZnyY4= +go.opentelemetry.io/otel v1.10.0/go.mod h1:NbvWjCthWHKBEUMpf0/v8ZRZlni86PpGFEMA9pnQSnQ= +go.opentelemetry.io/otel/metric v0.32.1 h1:ftff5LSBCIDwL0UkhBuDg8j9NNxx2IusvJ18q9h6RC4= +go.opentelemetry.io/otel/metric v0.32.1/go.mod h1:iLPP7FaKMAD5BIxJ2VX7f2KTuz//0QK2hEUyti5psqQ= go.opentelemetry.io/otel/sdk v1.9.0 h1:LNXp1vrr83fNXTHgU8eO89mhzxb/bbWAsHG6fNf3qWo= go.opentelemetry.io/otel/sdk v1.9.0/go.mod h1:AEZc8nt5bd2F7BC24J5R0mrjYnpEgYHyTcM/vrSple4= -go.opentelemetry.io/otel/trace v1.9.0 h1:oZaCNJUjWcg60VXWee8lJKlqhPbXAPB51URuR47pQYc= go.opentelemetry.io/otel/trace v1.9.0/go.mod h1:2737Q0MuG8q1uILYm2YYVkAyLtOofiTNGg6VODnOiPo= +go.opentelemetry.io/otel/trace v1.10.0 h1:npQMbR8o7mum8uF95yFbOEJffhs1sbCOfDh8zAJiH5E= +go.opentelemetry.io/otel/trace v1.10.0/go.mod h1:Sij3YYczqAdz+EhmGhE6TpTxUO5/F/AzrK+kxfGqySM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/extra/redisotel/metrics.go b/extra/redisotel/metrics.go new file mode 100644 index 0000000..6d41d32 --- /dev/null +++ b/extra/redisotel/metrics.go @@ -0,0 +1,194 @@ +package redisotel + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/go-redis/redis/v9" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/metric/instrument" + "go.opentelemetry.io/otel/metric/instrument/syncint64" +) + +// InstrumentMetrics starts reporting OpenTelemetry Metrics. +// +// Based on https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/metrics/semantic_conventions/database-metrics.md +func InstrumentMetrics(rdb redis.UniversalClient, opts ...MetricsOption) error { + baseOpts := make([]baseOption, len(opts)) + for i, opt := range opts { + baseOpts[i] = opt + } + conf := newConfig(baseOpts...) + + if conf.meter == nil { + conf.meter = conf.mp.Meter( + instrumName, + metric.WithInstrumentationVersion("semver:"+redis.Version()), + ) + } + + switch rdb := rdb.(type) { + case *redis.Client: + if conf.poolName == "" { + opt := rdb.Options() + conf.poolName = opt.Addr + } + conf.attrs = append(conf.attrs, attribute.String("pool.name", conf.poolName)) + + if err := reportPoolStats(rdb, conf); err != nil { + return err + } + if err := addMetricsHook(rdb, conf); err != nil { + return err + } + return nil + case *redis.ClusterClient: + rdb.OnNewNode(func(rdb *redis.Client) { + if conf.poolName == "" { + opt := rdb.Options() + conf.poolName = opt.Addr + } + conf.attrs = append(conf.attrs, attribute.String("pool.name", conf.poolName)) + + if err := reportPoolStats(rdb, conf); err != nil { + otel.Handle(err) + } + if err := addMetricsHook(rdb, conf); err != nil { + otel.Handle(err) + } + }) + return nil + case *redis.Ring: + rdb.OnNewNode(func(rdb *redis.Client) { + if err := reportPoolStats(rdb, conf); err != nil { + otel.Handle(err) + } + if err := addMetricsHook(rdb, conf); err != nil { + otel.Handle(err) + } + }) + return nil + default: + return fmt.Errorf("redisotel: %T not supported", rdb) + } +} + +func reportPoolStats(rdb *redis.Client, conf *config) error { + labels := conf.attrs + idleAttrs := append(labels, attribute.String("state", "idle")) + usedAttrs := append(labels, attribute.String("state", "used")) + + usage, err := conf.meter.AsyncInt64().UpDownCounter( + "db.client.connections.usage", + instrument.WithDescription("The number of connections that are currently in state described by the state attribute"), + ) + if err != nil { + return err + } + + timeouts, err := conf.meter.AsyncInt64().UpDownCounter( + "db.client.connections.timeouts", + instrument.WithDescription("The number of connection timeouts that have occurred trying to obtain a connection from the pool"), + ) + if err != nil { + return err + } + + return conf.meter.RegisterCallback( + []instrument.Asynchronous{ + usage, + timeouts, + }, + func(ctx context.Context) { + stats := rdb.PoolStats() + + usage.Observe(ctx, int64(stats.IdleConns), idleAttrs...) + usage.Observe(ctx, int64(stats.TotalConns-stats.IdleConns), usedAttrs...) + + timeouts.Observe(ctx, int64(stats.Timeouts), labels...) + }, + ) +} + +func addMetricsHook(rdb *redis.Client, conf *config) error { + createTime, err := conf.meter.SyncInt64().Histogram( + "db.client.connections.create_time", + instrument.WithDescription("The time it took to create a new connection."), + instrument.WithUnit("ms"), + ) + if err != nil { + return err + } + + useTime, err := conf.meter.SyncInt64().Histogram( + "db.client.connections.use_time", + instrument.WithDescription("The time between borrowing a connection and returning it to the pool."), + instrument.WithUnit("ms"), + ) + if err != nil { + return err + } + + rdb.AddHook(&metricsHook{ + createTime: createTime, + useTime: useTime, + }) + return nil +} + +type metricsHook struct { + createTime syncint64.Histogram + useTime syncint64.Histogram +} + +var _ redis.Hook = (*metricsHook)(nil) + +func (mh *metricsHook) DialHook(hook redis.DialHook) redis.DialHook { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + start := time.Now() + + conn, err := hook(ctx, network, addr) + + mh.createTime.Record(ctx, time.Since(start).Milliseconds()) + return conn, err + } +} + +func (mh *metricsHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + start := time.Now() + + err := hook(ctx, cmd) + + dur := time.Since(start).Milliseconds() + mh.useTime.Record(ctx, dur, attribute.String("type", "command"), statusAttr(err)) + + return err + } +} + +func (mh *metricsHook) ProcessPipelineHook( + hook redis.ProcessPipelineHook, +) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + start := time.Now() + + err := hook(ctx, cmds) + + dur := time.Since(start).Milliseconds() + mh.useTime.Record(ctx, dur, attribute.String("type", "pipeline"), statusAttr(err)) + + return err + } +} + +func statusAttr(err error) attribute.KeyValue { + if err != nil { + return attribute.String("status", "error") + } + return attribute.String("status", "ok") +} diff --git a/extra/redisotel/redisotel.go b/extra/redisotel/redisotel.go deleted file mode 100644 index 26ab654..0000000 --- a/extra/redisotel/redisotel.go +++ /dev/null @@ -1,153 +0,0 @@ -package redisotel - -import ( - "context" - - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - semconv "go.opentelemetry.io/otel/semconv/v1.10.0" - "go.opentelemetry.io/otel/trace" - - "github.com/go-redis/redis/extra/rediscmd/v9" - "github.com/go-redis/redis/v9" -) - -const ( - defaultTracerName = "github.com/go-redis/redis/extra/redisotel" -) - -type TracingHook struct { - tracer trace.Tracer - attrs []attribute.KeyValue - dbStmtEnabled bool -} - -func NewTracingHook(opts ...Option) *TracingHook { - cfg := &config{ - tp: otel.GetTracerProvider(), - attrs: []attribute.KeyValue{ - semconv.DBSystemRedis, - }, - dbStmtEnabled: true, - } - for _, opt := range opts { - opt.apply(cfg) - } - - tracer := cfg.tp.Tracer( - defaultTracerName, - trace.WithInstrumentationVersion("semver:"+redis.Version()), - ) - return &TracingHook{tracer: tracer, attrs: cfg.attrs, dbStmtEnabled: cfg.dbStmtEnabled} -} - -func (th *TracingHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { - if !trace.SpanFromContext(ctx).IsRecording() { - return ctx, nil - } - - opts := []trace.SpanStartOption{ - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes(th.attrs...), - } - - if th.dbStmtEnabled { - opts = append(opts, trace.WithAttributes(semconv.DBStatementKey.String(rediscmd.CmdString(cmd)))) - } - - ctx, _ = th.tracer.Start(ctx, cmd.FullName(), opts...) - - return ctx, nil -} - -func (th *TracingHook) AfterProcess(ctx context.Context, cmd redis.Cmder) error { - span := trace.SpanFromContext(ctx) - if err := cmd.Err(); err != nil { - recordError(ctx, span, err) - } - span.End() - return nil -} - -func (th *TracingHook) BeforeProcessPipeline( - ctx context.Context, cmds []redis.Cmder, -) (context.Context, error) { - if !trace.SpanFromContext(ctx).IsRecording() { - return ctx, nil - } - - opts := []trace.SpanStartOption{ - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes(th.attrs...), - trace.WithAttributes( - attribute.Int("db.redis.num_cmd", len(cmds)), - ), - } - - summary, cmdsString := rediscmd.CmdsString(cmds) - if th.dbStmtEnabled { - opts = append(opts, trace.WithAttributes(semconv.DBStatementKey.String(cmdsString))) - } - - ctx, _ = th.tracer.Start(ctx, "pipeline "+summary, opts...) - - return ctx, nil -} - -func (th *TracingHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error { - span := trace.SpanFromContext(ctx) - if err := cmds[0].Err(); err != nil { - recordError(ctx, span, err) - } - span.End() - return nil -} - -func recordError(ctx context.Context, span trace.Span, err error) { - if err != redis.Nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - } -} - -type config struct { - tp trace.TracerProvider - attrs []attribute.KeyValue - dbStmtEnabled bool -} - -// Option specifies instrumentation configuration options. -type Option interface { - apply(*config) -} - -type optionFunc func(*config) - -func (o optionFunc) apply(c *config) { - o(c) -} - -// WithTracerProvider specifies a tracer provider to use for creating a tracer. -// If none is specified, the global provider is used. -func WithTracerProvider(provider trace.TracerProvider) Option { - return optionFunc(func(cfg *config) { - if provider != nil { - cfg.tp = provider - } - }) -} - -// WithAttributes specifies additional attributes to be added to the span. -func WithAttributes(attrs ...attribute.KeyValue) Option { - return optionFunc(func(cfg *config) { - cfg.attrs = append(cfg.attrs, attrs...) - }) -} - -// WithDBStatement tells the tracing hook not to log raw redis commands. -func WithDBStatement(on bool) Option { - return optionFunc(func(cfg *config) { - cfg.dbStmtEnabled = on - }) -} diff --git a/extra/redisotel/redisotel_test.go b/extra/redisotel/redisotel_test.go index 8c0e487..de6e9bf 100644 --- a/extra/redisotel/redisotel_test.go +++ b/extra/redisotel/redisotel_test.go @@ -1,4 +1,4 @@ -package redisotel_test +package redisotel import ( "context" @@ -10,16 +10,9 @@ import ( sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/trace" - "github.com/go-redis/redis/extra/redisotel/v9" "github.com/go-redis/redis/v9" ) -func TestNew(t *testing.T) { - // this also functions as a compile-time test that the - // TracingHook conforms to the Hook interface - var _ redis.Hook = redisotel.NewTracingHook() -} - type providerFunc func(name string, opts ...trace.TracerOption) trace.Tracer func (fn providerFunc) Tracer(name string, opts ...trace.TracerOption) trace.Tracer { @@ -34,7 +27,7 @@ func TestNewWithTracerProvider(t *testing.T) { return otel.GetTracerProvider().Tracer(name, opts...) }) - _ = redisotel.NewTracingHook(redisotel.WithTracerProvider(tp)) + _ = newTracingHook("", WithTracerProvider(tp)) if !invoked { t.Fatalf("did not call custom TraceProvider") @@ -43,55 +36,52 @@ func TestNewWithTracerProvider(t *testing.T) { func TestNewWithAttributes(t *testing.T) { provider := sdktrace.NewTracerProvider() - hook := redisotel.NewTracingHook(redisotel.WithTracerProvider(provider), redisotel.WithAttributes(semconv.NetPeerNameKey.String("localhost"))) + hook := newTracingHook("", WithTracerProvider(provider), WithAttributes(semconv.NetPeerNameKey.String("localhost"))) ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") cmd := redis.NewCmd(ctx, "ping") defer span.End() - ctx, err := hook.BeforeProcess(ctx, cmd) + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + attrs := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan).Attributes() + if !(attrs[0] == semconv.DBSystemRedis) { + t.Fatalf("expected attrs[0] to be semconv.DBSystemRedis, got: %v", attrs[0]) + } + if !(attrs[1] == semconv.NetPeerNameKey.String("localhost")) { + t.Fatalf("expected attrs[1] to be semconv.NetPeerNameKey.String(\"localhost\"), got: %v", attrs[1]) + } + if !(attrs[2] == semconv.DBStatementKey.String("ping")) { + t.Fatalf("expected attrs[2] to be semconv.DBStatementKey.String(\"ping\"), got: %v", attrs[2]) + } + return nil + }) + err := processHook(ctx, cmd) if err != nil { t.Fatal(err) } - err = hook.AfterProcess(ctx, cmd) - if err != nil { - t.Fatal(err) - } - - attrs := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan).Attributes() - if !(attrs[0] == semconv.DBSystemRedis) { - t.Fatalf("expected attrs[0] to be semconv.DBSystemRedis, got: %v", attrs[0]) - } - if !(attrs[1] == semconv.NetPeerNameKey.String("localhost")) { - t.Fatalf("expected attrs[1] to be semconv.NetPeerNameKey.String(\"localhost\"), got: %v", attrs[1]) - } - if !(attrs[2] == semconv.DBStatementKey.String("ping")) { - t.Fatalf("expected attrs[2] to be semconv.DBStatementKey.String(\"ping\"), got: %v", attrs[2]) - } } func TestWithDBStatement(t *testing.T) { provider := sdktrace.NewTracerProvider() - hook := redisotel.NewTracingHook( - redisotel.WithTracerProvider(provider), - redisotel.WithDBStatement(false), + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithDBStatement(false), ) ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") cmd := redis.NewCmd(ctx, "ping") defer span.End() - ctx, err := hook.BeforeProcess(ctx, cmd) - if err != nil { - t.Fatal(err) - } - err = hook.AfterProcess(ctx, cmd) - if err != nil { - t.Fatal(err) - } - - attrs := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan).Attributes() - for _, attr := range attrs { - if attr.Key == semconv.DBStatementKey { - t.Fatal("Attribute with db statement should not exist") + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + attrs := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan).Attributes() + for _, attr := range attrs { + if attr.Key == semconv.DBStatementKey { + t.Fatal("Attribute with db statement should not exist") + } } + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) } } diff --git a/extra/redisotel/tracing.go b/extra/redisotel/tracing.go new file mode 100644 index 0000000..db74c72 --- /dev/null +++ b/extra/redisotel/tracing.go @@ -0,0 +1,175 @@ +package redisotel + +import ( + "context" + "fmt" + "net" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + semconv "go.opentelemetry.io/otel/semconv/v1.10.0" + "go.opentelemetry.io/otel/trace" + + "github.com/go-redis/redis/extra/rediscmd/v9" + "github.com/go-redis/redis/v9" +) + +const ( + instrumName = "github.com/go-redis/redis/extra/redisotel" +) + +func InstrumentTracing(rdb redis.UniversalClient, opts ...TracingOption) error { + switch rdb := rdb.(type) { + case *redis.Client: + opt := rdb.Options() + connString := formatDBConnString(opt.Network, opt.Addr) + rdb.AddHook(newTracingHook(connString, opts...)) + return nil + case *redis.ClusterClient: + rdb.AddHook(newTracingHook("", opts...)) + + rdb.OnNewNode(func(rdb *redis.Client) { + opt := rdb.Options() + connString := formatDBConnString(opt.Network, opt.Addr) + rdb.AddHook(newTracingHook(connString, opts...)) + }) + return nil + case *redis.Ring: + rdb.AddHook(newTracingHook("", opts...)) + + rdb.OnNewNode(func(rdb *redis.Client) { + opt := rdb.Options() + connString := formatDBConnString(opt.Network, opt.Addr) + rdb.AddHook(newTracingHook(connString, opts...)) + }) + return nil + default: + return fmt.Errorf("redisotel: %T not supported", rdb) + } +} + +type tracingHook struct { + conf *config + + spanOpts []trace.SpanStartOption +} + +var _ redis.Hook = (*tracingHook)(nil) + +func newTracingHook(connString string, opts ...TracingOption) *tracingHook { + baseOpts := make([]baseOption, len(opts)) + for i, opt := range opts { + baseOpts[i] = opt + } + conf := newConfig(baseOpts...) + + if conf.tracer == nil { + conf.tracer = conf.tp.Tracer( + instrumName, + trace.WithInstrumentationVersion("semver:"+redis.Version()), + ) + } + if connString != "" { + conf.attrs = append(conf.attrs, semconv.DBConnectionStringKey.String(connString)) + } + + return &tracingHook{ + conf: conf, + + spanOpts: []trace.SpanStartOption{ + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes(conf.attrs...), + }, + } +} + +func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + if !trace.SpanFromContext(ctx).IsRecording() { + return hook(ctx, network, addr) + } + + spanOpts := th.spanOpts + spanOpts = append(spanOpts, trace.WithAttributes( + attribute.String("network", network), + attribute.String("addr", addr), + )) + + ctx, span := th.conf.tracer.Start(ctx, "redis.dial", spanOpts...) + defer span.End() + + conn, err := hook(ctx, network, addr) + if err != nil { + recordError(ctx, span, err) + return nil, err + } + return conn, nil + } +} + +func (th *tracingHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + if !trace.SpanFromContext(ctx).IsRecording() { + return hook(ctx, cmd) + } + + opts := th.spanOpts + if th.conf.dbStmtEnabled { + opts = append(opts, trace.WithAttributes( + semconv.DBStatementKey.String(rediscmd.CmdString(cmd))), + ) + } + + ctx, span := th.conf.tracer.Start(ctx, cmd.FullName(), opts...) + defer span.End() + + if err := hook(ctx, cmd); err != nil { + recordError(ctx, span, err) + return err + } + return nil + } +} + +func (th *tracingHook) ProcessPipelineHook( + hook redis.ProcessPipelineHook, +) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + if !trace.SpanFromContext(ctx).IsRecording() { + return hook(ctx, cmds) + } + + opts := th.spanOpts + opts = append(opts, trace.WithAttributes( + attribute.Int("db.redis.num_cmd", len(cmds)), + )) + + summary, cmdsString := rediscmd.CmdsString(cmds) + if th.conf.dbStmtEnabled { + opts = append(opts, trace.WithAttributes(semconv.DBStatementKey.String(cmdsString))) + } + + ctx, span := th.conf.tracer.Start(ctx, "redis.pipeline "+summary, opts...) + defer span.End() + + if err := hook(ctx, cmds); err != nil { + recordError(ctx, span, err) + return err + } + return nil + } +} + +func recordError(ctx context.Context, span trace.Span, err error) { + if err != redis.Nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } +} + +func formatDBConnString(network, addr string) string { + if network == "tcp" { + network = "redis" + } + return fmt.Sprintf("%s://%s", network, addr) +} diff --git a/main_test.go b/main_test.go index e35b17c..ec2439f 100644 --- a/main_test.go +++ b/main_test.go @@ -1,7 +1,6 @@ package redis_test import ( - "context" "fmt" "net" "os" @@ -416,37 +415,28 @@ func (cn *badConn) Write([]byte) (int, error) { //------------------------------------------------------------------------------ type hook struct { - beforeProcess func(ctx context.Context, cmd redis.Cmder) (context.Context, error) - afterProcess func(ctx context.Context, cmd redis.Cmder) error - - beforeProcessPipeline func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) - afterProcessPipeline func(ctx context.Context, cmds []redis.Cmder) error + dialHook func(hook redis.DialHook) redis.DialHook + processHook func(hook redis.ProcessHook) redis.ProcessHook + processPipelineHook func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook } -func (h *hook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { - if h.beforeProcess != nil { - return h.beforeProcess(ctx, cmd) +func (h *hook) DialHook(hook redis.DialHook) redis.DialHook { + if h.dialHook != nil { + return h.dialHook(hook) } - return ctx, nil + return hook } -func (h *hook) AfterProcess(ctx context.Context, cmd redis.Cmder) error { - if h.afterProcess != nil { - return h.afterProcess(ctx, cmd) +func (h *hook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { + if h.processHook != nil { + return h.processHook(hook) } - return nil + return hook } -func (h *hook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - if h.beforeProcessPipeline != nil { - return h.beforeProcessPipeline(ctx, cmds) +func (h *hook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + if h.processPipelineHook != nil { + return h.processPipelineHook(hook) } - return ctx, nil -} - -func (h *hook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error { - if h.afterProcessPipeline != nil { - return h.afterProcessPipeline(ctx, cmds) - } - return nil + return hook } diff --git a/options.go b/options.go index 6c2f853..56eab53 100644 --- a/options.go +++ b/options.go @@ -470,10 +470,13 @@ func getUserPassword(u *url.URL) (string, string) { return user, password } -func newConnPool(opt *Options) *pool.ConnPool { +func newConnPool( + opt *Options, + dialer func(ctx context.Context, network, addr string) (net.Conn, error), +) *pool.ConnPool { return pool.NewConnPool(&pool.Options{ Dialer: func(ctx context.Context) (net.Conn, error) { - return opt.Dialer(ctx, opt.Network, opt.Addr) + return dialer(ctx, opt.Network, opt.Addr) }, PoolFIFO: opt.PoolFIFO, PoolSize: opt.PoolSize, diff --git a/redis.go b/redis.go index a800f45..2f3585e 100644 --- a/redis.go +++ b/redis.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "strings" "sync/atomic" "time" @@ -24,102 +25,104 @@ func SetLogger(logger internal.Logging) { //------------------------------------------------------------------------------ type Hook interface { - BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) - AfterProcess(ctx context.Context, cmd Cmder) error - - BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) - AfterProcessPipeline(ctx context.Context, cmds []Cmder) error + DialHook(hook DialHook) DialHook + ProcessHook(hook ProcessHook) ProcessHook + ProcessPipelineHook(hook ProcessPipelineHook) ProcessPipelineHook } +type ( + DialHook func(ctx context.Context, network, addr string) (net.Conn, error) + ProcessHook func(ctx context.Context, cmd Cmder) error + ProcessPipelineHook func(ctx context.Context, cmds []Cmder) error +) + type hooks struct { - hooks []Hook -} - -func (hs *hooks) lock() { - hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)] -} - -func (hs hooks) clone() hooks { - clone := hs - clone.lock() - return clone + slice []Hook + dial DialHook + process ProcessHook + processPipeline ProcessPipelineHook + processTxPipeline ProcessPipelineHook } func (hs *hooks) AddHook(hook Hook) { - hs.hooks = append(hs.hooks, hook) + if hs.process == nil { + panic("hs.process == nil") + } + if hs.processPipeline == nil { + panic("hs.processPipeline == nil") + } + if hs.processTxPipeline == nil { + panic("hs.processTxPipeline == nil") + } + + hs.slice = append(hs.slice, hook) + hs.dial = hook.DialHook(hs.dial) + hs.process = hook.ProcessHook(hs.process) + hs.processPipeline = hook.ProcessPipelineHook(hs.processPipeline) + hs.processTxPipeline = hook.ProcessPipelineHook(hs.processTxPipeline) } -func (hs hooks) process( - ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, -) error { - if len(hs.hooks) == 0 { - err := fn(ctx, cmd) - cmd.SetErr(err) - return err - } - - var hookIndex int - var retErr error - - for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { - ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd) - if retErr != nil { - cmd.SetErr(retErr) - } - } - - if retErr == nil { - retErr = fn(ctx, cmd) - cmd.SetErr(retErr) - } - - for hookIndex--; hookIndex >= 0; hookIndex-- { - if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil { - retErr = err - cmd.SetErr(retErr) - } - } - - return retErr +func (hs *hooks) clone() hooks { + clone := *hs + l := len(clone.slice) + clone.slice = clone.slice[:l:l] + return clone } -func (hs hooks) processPipeline( - ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, -) error { - if len(hs.hooks) == 0 { - err := fn(ctx, cmds) - return err - } - - var hookIndex int - var retErr error - - for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { - ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds) - if retErr != nil { - setCmdsErr(cmds, retErr) +func (hs *hooks) setDial(dial DialHook) { + hs.dial = dial + for _, h := range hs.slice { + if wrapped := h.DialHook(hs.dial); wrapped != nil { + hs.dial = wrapped } } - - if retErr == nil { - retErr = fn(ctx, cmds) - } - - for hookIndex--; hookIndex >= 0; hookIndex-- { - if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil { - retErr = err - setCmdsErr(cmds, retErr) - } - } - - return retErr } -func (hs hooks) processTxPipeline( - ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, +func (hs *hooks) setProcess(process ProcessHook) { + hs.process = process + for _, h := range hs.slice { + if wrapped := h.ProcessHook(hs.process); wrapped != nil { + hs.process = wrapped + } + } +} + +func (hs *hooks) setProcessPipeline(processPipeline ProcessPipelineHook) { + hs.processPipeline = processPipeline + for _, h := range hs.slice { + if wrapped := h.ProcessPipelineHook(hs.processPipeline); wrapped != nil { + hs.processPipeline = wrapped + } + } +} + +func (hs *hooks) setProcessTxPipeline(processTxPipeline ProcessPipelineHook) { + hs.processTxPipeline = processTxPipeline + for _, h := range hs.slice { + if wrapped := h.ProcessPipelineHook(hs.processTxPipeline); wrapped != nil { + hs.processTxPipeline = wrapped + } + } +} + +func (hs *hooks) withProcessHook(ctx context.Context, cmd Cmder, hook ProcessHook) error { + for _, h := range hs.slice { + if wrapped := h.ProcessHook(hook); wrapped != nil { + hook = wrapped + } + } + return hook(ctx, cmd) +} + +func (hs *hooks) withProcessPipelineHook( + ctx context.Context, cmds []Cmder, hook ProcessPipelineHook, ) error { - cmds = wrapMultiExec(ctx, cmds) - return hs.processPipeline(ctx, cmds, fn) + for _, h := range hs.slice { + if wrapped := h.ProcessPipelineHook(hook); wrapped != nil { + hook = wrapped + } + } + return hook(ctx, cmds) } //------------------------------------------------------------------------------ @@ -131,13 +134,6 @@ type baseClient struct { onClose func() error // hook called when client is closed } -func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient { - return &baseClient{ - opt: opt, - connPool: connPool, - } -} - func (c *baseClient) clone() *baseClient { clone := *c return &clone @@ -293,6 +289,10 @@ func (c *baseClient) withConn( return fn(ctx, cn) } +func (c *baseClient) dial(ctx context.Context, network, addr string) (net.Conn, error) { + return c.opt.Dialer(ctx, network, addr) +} + func (c *baseClient) process(ctx context.Context, cmd Cmder) error { var lastErr error for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { @@ -379,26 +379,22 @@ func (c *baseClient) getAddr() string { } func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error { - return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds) -} - -func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds) -} - -type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error) - -func (c *baseClient) generalProcessPipeline( - ctx context.Context, cmds []Cmder, p pipelineProcessor, -) error { - err := c._generalProcessPipeline(ctx, cmds, p) - if err != nil { + if err := c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds); err != nil { return err } return cmdsFirstErr(cmds) } -func (c *baseClient) _generalProcessPipeline( +func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { + if err := c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds); err != nil { + return err + } + return cmdsFirstErr(cmds) +} + +type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error) + +func (c *baseClient) generalProcessPipeline( ctx context.Context, cmds []Cmder, p pipelineProcessor, ) error { var lastErr error @@ -484,17 +480,6 @@ func (c *baseClient) txPipelineProcessCmds( return false, nil } -func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder { - if len(cmds) == 0 { - panic("not reached") - } - cmdCopy := make([]Cmder, len(cmds)+2) - cmdCopy[0] = NewStatusCmd(ctx, "multi") - copy(cmdCopy[1:], cmds) - cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec") - return cmdCopy -} - func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { // Parse +OK. if err := statusCmd.readReply(rd); err != nil { @@ -549,24 +534,29 @@ func NewClient(opt *Options) *Client { opt.init() c := Client{ - baseClient: newBaseClient(opt, newConnPool(opt)), + baseClient: &baseClient{ + opt: opt, + }, } - c.cmdable = c.Process + c.connPool = newConnPool(opt, c.baseClient.dial) + c.init() return &c } -func (c *Client) clone() *Client { - clone := *c - clone.cmdable = clone.Process - clone.hooks.lock() - return &clone +func (c *Client) init() { + c.cmdable = c.Process + c.hooks.setDial(c.baseClient.dial) + c.hooks.setProcess(c.baseClient.process) + c.hooks.setProcessPipeline(c.baseClient.processPipeline) + c.hooks.setProcessTxPipeline(c.baseClient.processTxPipeline) } func (c *Client) WithTimeout(timeout time.Duration) *Client { - clone := c.clone() + clone := *c clone.baseClient = c.baseClient.withTimeout(timeout) - return clone + clone.init() + return &clone } func (c *Client) Conn() *Conn { @@ -581,15 +571,9 @@ func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd { } func (c *Client) Process(ctx context.Context, cmd Cmder) error { - return c.hooks.process(ctx, cmd, c.baseClient.process) -} - -func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) -} - -func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) + err := c.hooks.process(ctx, cmd) + cmd.SetErr(err) + return err } // Options returns read-only Options that were used to create the client. @@ -611,7 +595,7 @@ func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmd func (c *Client) Pipeline() Pipeliner { pipe := Pipeline{ - exec: c.processPipeline, + exec: pipelineExecer(c.hooks.processPipeline), } pipe.init() return &pipe @@ -624,7 +608,10 @@ func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]C // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. func (c *Client) TxPipeline() Pipeliner { pipe := Pipeline{ - exec: c.processTxPipeline, + exec: func(ctx context.Context, cmds []Cmder) error { + cmds = wrapMultiExec(ctx, cmds) + return c.hooks.processTxPipeline(ctx, cmds) + }, } pipe.init() return &pipe @@ -699,44 +686,47 @@ func (c *Client) SSubscribe(ctx context.Context, channels ...string) *PubSub { //------------------------------------------------------------------------------ -type conn struct { - baseClient - cmdable - statefulCmdable - hooks // TODO: inherit hooks -} - // Conn represents a single Redis connection rather than a pool of connections. // Prefer running commands from Client unless there is a specific need // for a continuous single Redis connection. type Conn struct { - *conn + baseClient + cmdable + statefulCmdable + hooks } func newConn(opt *Options, connPool pool.Pooler) *Conn { c := Conn{ - conn: &conn{ - baseClient: baseClient{ - opt: opt, - connPool: connPool, - }, + baseClient: baseClient{ + opt: opt, + connPool: connPool, }, } + c.cmdable = c.Process c.statefulCmdable = c.Process + + c.hooks.setDial(c.baseClient.dial) + c.hooks.setProcess(c.baseClient.process) + c.hooks.setProcessPipeline(c.baseClient.processPipeline) + c.hooks.setProcessTxPipeline(c.baseClient.processTxPipeline) + return &c } func (c *Conn) Process(ctx context.Context, cmd Cmder) error { - return c.hooks.process(ctx, cmd, c.baseClient.process) + err := c.hooks.process(ctx, cmd) + cmd.SetErr(err) + return err } func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) + return c.hooks.processPipeline(ctx, cmds) } func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) + return c.hooks.processTxPipeline(ctx, cmds) } func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { diff --git a/redis_test.go b/redis_test.go index a7029b8..d0ccb45 100644 --- a/redis_test.go +++ b/redis_test.go @@ -14,18 +14,22 @@ import ( "github.com/go-redis/redis/v9" ) -type redisHookError struct { - redis.Hook -} +type redisHookError struct{} var _ redis.Hook = redisHookError{} -func (redisHookError) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { - return ctx, nil +func (redisHookError) DialHook(hook redis.DialHook) redis.DialHook { + return hook } -func (redisHookError) AfterProcess(ctx context.Context, cmd redis.Cmder) error { - return errors.New("hook error") +func (redisHookError) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + return errors.New("hook error") + } +} + +func (redisHookError) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return hook } func TestHookError(t *testing.T) { diff --git a/ring.go b/ring.go index fa916c6..a4f0b06 100644 --- a/ring.go +++ b/ring.go @@ -213,11 +213,12 @@ func (shard *ringShard) Vote(up bool) bool { type ringSharding struct { opt *RingOptions - mu sync.RWMutex - shards *ringShards - closed bool - hash ConsistentHash - numShard int + mu sync.RWMutex + shards *ringShards + closed bool + hash ConsistentHash + numShard int + onNewNode []func(rdb *Client) } type ringShards struct { @@ -234,6 +235,12 @@ func newRingSharding(opt *RingOptions) *ringSharding { return c } +func (c *ringSharding) OnNewNode(fn func(rdb *Client)) { + c.mu.Lock() + c.onNewNode = append(c.onNewNode, fn) + c.mu.Unlock() +} + // SetAddrs replaces the shards in use, such that you can increase and // decrease number of shards, that you use. It will reuse shards that // existed before and close the ones that will not be used anymore. @@ -245,7 +252,7 @@ func (c *ringSharding) SetAddrs(addrs map[string]string) { return } - shards, cleanup := newRingShards(c.opt, addrs, c.shards) + shards, cleanup := c.newRingShards(addrs, c.shards) c.shards = shards c.mu.Unlock() @@ -253,8 +260,8 @@ func (c *ringSharding) SetAddrs(addrs map[string]string) { cleanup() } -func newRingShards( - opt *RingOptions, addrs map[string]string, existingShards *ringShards, +func (c *ringSharding) newRingShards( + addrs map[string]string, existingShards *ringShards, ) (*ringShards, func()) { shardMap := make(map[string]*ringShard) // indexed by addr unusedShards := make(map[string]*ringShard) // indexed by addr @@ -276,7 +283,12 @@ func newRingShards( shards.m[name] = shard delete(unusedShards, addr) } else { - shards.m[name] = newRingShard(opt, addr) + shard := newRingShard(c.opt, addr) + shards.m[name] = shard + + for _, fn := range c.onNewNode { + fn(shard.Client) + } } } @@ -460,13 +472,13 @@ func (c *ringSharding) Close() error { // and can tolerate losing data when one of the servers dies. // Otherwise you should use Redis Cluster. type Ring struct { + cmdable + hooks + opt *RingOptions sharding *ringSharding cmdsInfoCache *cmdsInfoCache heartbeatCancelFn context.CancelFunc - - cmdable - hooks } func NewRing(opt *RingOptions) *Ring { @@ -483,6 +495,14 @@ func NewRing(opt *RingOptions) *Ring { ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) ring.cmdable = ring.Process + ring.hooks.process = ring.process + ring.hooks.processPipeline = func(ctx context.Context, cmds []Cmder) error { + return ring.generalProcessPipeline(ctx, cmds, false) + } + ring.hooks.processTxPipeline = func(ctx context.Context, cmds []Cmder) error { + return ring.generalProcessPipeline(ctx, cmds, true) + } + go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency) return &ring @@ -500,7 +520,9 @@ func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd { } func (c *Ring) Process(ctx context.Context, cmd Cmder) error { - return c.hooks.process(ctx, cmd, c.process) + err := c.hooks.process(ctx, cmd) + cmd.SetErr(err) + return err } // Options returns read-only Options that were used to create the client. @@ -573,6 +595,10 @@ func (c *Ring) SSubscribe(ctx context.Context, channels ...string) *PubSub { return shard.Client.SSubscribe(ctx, channels...) } +func (c *Ring) OnNewNode(fn func(rdb *Client)) { + c.sharding.OnNewNode(fn) +} + // ForEachShard concurrently calls the fn on each live shard in the ring. // It returns the first error if any. func (c *Ring) ForEachShard( @@ -677,40 +703,37 @@ func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder func (c *Ring) Pipeline() Pipeliner { pipe := Pipeline{ - exec: c.processPipeline, + exec: pipelineExecer(c.hooks.processPipeline), } pipe.init() return &pipe } -func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - return c.generalProcessPipeline(ctx, cmds, false) - }) -} - func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { return c.TxPipeline().Pipelined(ctx, fn) } func (c *Ring) TxPipeline() Pipeliner { pipe := Pipeline{ - exec: c.processTxPipeline, + exec: func(ctx context.Context, cmds []Cmder) error { + cmds = wrapMultiExec(ctx, cmds) + return c.hooks.processTxPipeline(ctx, cmds) + }, } pipe.init() return &pipe } -func (c *Ring) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - return c.generalProcessPipeline(ctx, cmds, true) - }) -} - func (c *Ring) generalProcessPipeline( ctx context.Context, cmds []Cmder, tx bool, ) error { + if tx { + // Trim multi .. exec. + cmds = cmds[1 : len(cmds)-1] + } + cmdsMap := make(map[string][]Cmder) + for _, cmd := range cmds { cmdInfo := c.cmdInfo(ctx, cmd.Name()) hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo)) @@ -726,7 +749,19 @@ func (c *Ring) generalProcessPipeline( go func(hash string, cmds []Cmder) { defer wg.Done() - _ = c.processShardPipeline(ctx, hash, cmds, tx) + // TODO: retry? + shard, err := c.sharding.GetByName(hash) + if err != nil { + setCmdsErr(cmds, err) + return + } + + if tx { + cmds = wrapMultiExec(ctx, cmds) + shard.Client.hooks.processTxPipeline(ctx, cmds) + } else { + shard.Client.hooks.processPipeline(ctx, cmds) + } }(hash, cmds) } @@ -734,28 +769,13 @@ func (c *Ring) generalProcessPipeline( return cmdsFirstErr(cmds) } -func (c *Ring) processShardPipeline( - ctx context.Context, hash string, cmds []Cmder, tx bool, -) error { - // TODO: retry? - shard, err := c.sharding.GetByName(hash) - if err != nil { - setCmdsErr(cmds, err) - return err - } - - if tx { - return shard.Client.processTxPipeline(ctx, cmds) - } - return shard.Client.processPipeline(ctx, cmds) -} - func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error { if len(keys) == 0 { return fmt.Errorf("redis: Watch requires at least one key") } var shards []*ringShard + for _, key := range keys { if key != "" { shard, err := c.sharding.GetByKey(hashtag.Key(key)) diff --git a/ring_test.go b/ring_test.go index 9610066..c64e107 100644 --- a/ring_test.go +++ b/ring_test.go @@ -277,29 +277,35 @@ var _ = Describe("Redis Ring", func() { var stack []string ring.AddHook(&hook{ - beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { - Expect(cmd.String()).To(Equal("ping: ")) - stack = append(stack, "ring.BeforeProcess") - return ctx, nil - }, - afterProcess: func(ctx context.Context, cmd redis.Cmder) error { - Expect(cmd.String()).To(Equal("ping: PONG")) - stack = append(stack, "ring.AfterProcess") - return nil + processHook: func(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcess") + + err := hook(ctx, cmd) + + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcess") + + return err + } }, }) ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error { shard.AddHook(&hook{ - beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { - Expect(cmd.String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcess") - return ctx, nil - }, - afterProcess: func(ctx context.Context, cmd redis.Cmder) error { - Expect(cmd.String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcess") - return nil + processHook: func(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcess") + + err := hook(ctx, cmd) + + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcess") + + return err + } }, }) return nil @@ -322,33 +328,39 @@ var _ = Describe("Redis Ring", func() { var stack []string ring.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "ring.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = append(stack, "ring.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcessPipeline") + + return err + } }, }) ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error { shard.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + + return err + } }, }) return nil @@ -374,33 +386,43 @@ var _ = Describe("Redis Ring", func() { var stack []string ring.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "ring.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = append(stack, "ring.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + defer GinkgoRecover() + + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcessPipeline") + + return err + } }, }) ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error { shard.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + defer GinkgoRecover() + + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + + return err + } }, }) return nil diff --git a/sentinel.go b/sentinel.go index 8b5fe88..41228d4 100644 --- a/sentinel.go +++ b/sentinel.go @@ -198,7 +198,17 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt.Dialer = masterReplicaDialer(failover) opt.init() - connPool := newConnPool(opt) + var connPool *pool.ConnPool + + rdb := &Client{ + baseClient: &baseClient{ + opt: opt, + }, + } + connPool = newConnPool(opt, rdb.baseClient.dial) + rdb.connPool = connPool + rdb.onClose = failover.Close + rdb.init() failover.mu.Lock() failover.onFailover = func(ctx context.Context, addr string) { @@ -208,13 +218,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { } failover.mu.Unlock() - c := Client{ - baseClient: newBaseClient(opt, connPool), - } - c.cmdable = c.Process - c.onClose = failover.Close - - return &c + return rdb } func masterReplicaDialer( @@ -262,15 +266,21 @@ func NewSentinelClient(opt *Options) *SentinelClient { opt.init() c := &SentinelClient{ baseClient: &baseClient{ - opt: opt, - connPool: newConnPool(opt), + opt: opt, }, } + c.connPool = newConnPool(opt, c.baseClient.dial) + + c.hooks.setDial(c.baseClient.dial) + c.hooks.setProcess(c.baseClient.process) + return c } func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error { - return c.hooks.process(ctx, cmd, c.baseClient.process) + err := c.hooks.process(ctx, cmd) + cmd.SetErr(err) + return err } func (c *SentinelClient) pubSub() *PubSub { diff --git a/tx.go b/tx.go index 61375e0..e720e68 100644 --- a/tx.go +++ b/tx.go @@ -37,10 +37,17 @@ func (c *Client) newTx() *Tx { func (c *Tx) init() { c.cmdable = c.Process c.statefulCmdable = c.Process + + c.hooks.setDial(c.baseClient.dial) + c.hooks.setProcess(c.baseClient.process) + c.hooks.setProcessPipeline(c.baseClient.processPipeline) + c.hooks.setProcessTxPipeline(c.baseClient.processTxPipeline) } func (c *Tx) Process(ctx context.Context, cmd Cmder) error { - return c.hooks.process(ctx, cmd, c.baseClient.process) + err := c.hooks.process(ctx, cmd) + cmd.SetErr(err) + return err } // Watch prepares a transaction and marks the keys to be watched @@ -93,7 +100,7 @@ func (c *Tx) Unwatch(ctx context.Context, keys ...string) *StatusCmd { func (c *Tx) Pipeline() Pipeliner { pipe := Pipeline{ exec: func(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) + return c.hooks.processPipeline(ctx, cmds) }, } pipe.init() @@ -122,9 +129,21 @@ func (c *Tx) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder func (c *Tx) TxPipeline() Pipeliner { pipe := Pipeline{ exec: func(ctx context.Context, cmds []Cmder) error { - return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) + cmds = wrapMultiExec(ctx, cmds) + return c.hooks.processTxPipeline(ctx, cmds) }, } pipe.init() return &pipe } + +func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder { + if len(cmds) == 0 { + panic("not reached") + } + cmdsCopy := make([]Cmder, len(cmds)+2) + cmdsCopy[0] = NewStatusCmd(ctx, "multi") + copy(cmdsCopy[1:], cmds) + cmdsCopy[len(cmdsCopy)-1] = NewSliceCmd(ctx, "exec") + return cmdsCopy +}