diff --git a/redis.go b/redis.go index 98d6034..3e7cc9d 100644 --- a/redis.go +++ b/redis.go @@ -29,8 +29,22 @@ type Hook interface { AfterProcessPipeline(ctx context.Context, cmds []Cmder) error } +type ConnectHook interface { + BeforeConnect(ctx context.Context, event ConnectEvent) context.Context + AfterConnect(ctx context.Context, event ConnectEvent) +} + +type fullHook interface { + Hook + ConnectHook +} + +type ConnectEvent struct { + Err error +} + type hooks struct { - hooks []Hook + hooks []fullHook } func (hs *hooks) lock() { @@ -44,7 +58,11 @@ func (hs hooks) clone() hooks { } func (hs *hooks) AddHook(hook Hook) { - hs.hooks = append(hs.hooks, hook) + if hook, ok := hook.(fullHook); ok { + hs.hooks = append(hs.hooks, hook) + } else { + hs.hooks = append(hs.hooks, dummyConnectHook{Hook: hook}) + } } func (hs hooks) process( @@ -132,6 +150,16 @@ func (hs hooks) withContext(ctx context.Context, fn func() error) error { return fn() } +type dummyConnectHook struct { + Hook +} + +func (dummyConnectHook) BeforeConnect(ctx context.Context, event ConnectEvent) context.Context { + return ctx +} + +func (dummyConnectHook) AfterConnect(ctx context.Context, event ConnectEvent) {} + //------------------------------------------------------------------------------ type baseClient struct {