This commit is contained in:
Marco 2024-11-22 00:17:39 -05:00 committed by GitHub
commit 54e8a0368d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 111 additions and 0 deletions

View File

@ -409,6 +409,20 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
return &Pong{
Payload: reply[1].(string),
}, nil
case "invalidate":
switch payload := reply[1].(type) {
case []interface{}:
s := make([]string, len(payload))
for idx := range payload {
s[idx] = payload[idx].(string)
}
return &Message{
Channel: "invalidate",
PayloadSlice: s,
}, nil
default:
return nil, fmt.Errorf("redis: unsupported invalidate message payload: %#v", payload)
}
default:
return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
}

View File

@ -1,6 +1,8 @@
package redis_test
import (
"context"
"fmt"
"io"
"net"
"sync"
@ -567,4 +569,99 @@ var _ = Describe("PubSub", func() {
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal(text))
})
It("supports client-cache invalidation messages", func() {
ch := make(chan []string, 2)
defer close(ch)
client := redis.NewClient(getOptsWithTracking(redisOptions(), func(keys []string) error {
ch <- keys
return nil
}))
defer client.Close()
v1 := client.Get(context.Background(), "foo")
Expect(v1.Val()).To(Equal(""))
s1 := client.Set(context.Background(), "foo", "bar", time.Duration(time.Minute))
Expect(s1.Val()).To(Equal("OK"))
v2 := client.Get(context.Background(), "foo")
Expect(v2.Val()).To(Equal("bar"))
// sleep a little to allow time for the first invalidation message to come through
time.Sleep(time.Second)
s2 := client.Set(context.Background(), "foo", "foobar", time.Duration(time.Minute))
Expect(s2.Val()).To(Equal("OK"))
for i := 0; i < 2; i++ {
select {
case keys := <-ch:
Expect(keys).ToNot(BeEmpty())
Expect(keys[0]).To(Equal("foo"))
case <-time.After(10 * time.Second):
// fail on timeouts
Fail("invalidation message wait timed out")
}
}
})
})
func getOptsWithTracking(opt *redis.Options, processInvalidKeysFunc func([]string) error) *redis.Options {
var mu sync.Mutex
invalidateClientID := int64(-1)
invalidateOpts := *opt
invalidateOpts.OnConnect = func(ctx context.Context, conn *redis.Conn) (err error) {
invalidateClientID, err = conn.ClientID(ctx).Result()
return
}
startBackgroundInvalidationSubscription := func(ctx context.Context) int64 {
mu.Lock()
defer mu.Unlock()
if invalidateClientID != -1 {
return invalidateClientID
}
invalidateClient := redis.NewClient(&invalidateOpts)
invalidations := invalidateClient.Subscribe(ctx, "__redis__:invalidate")
go func() {
defer func() {
invalidations.Close()
invalidateClient.Close()
mu.Lock()
invalidateClientID = -1
mu.Unlock()
}()
for {
msg, err := invalidations.ReceiveMessage(context.Background())
if err == io.EOF || err == context.Canceled {
return
} else if err != nil {
fmt.Printf("warning: subscription on key invalidations aborted: %s\n", err.Error())
// send back empty []string to fail the test
processInvalidKeysFunc([]string{})
return
}
processInvalidKeysFunc(msg.PayloadSlice)
}
}()
return invalidateClientID
}
opt.OnConnect = func(ctx context.Context, conn *redis.Conn) error {
invalidateClientID := startBackgroundInvalidationSubscription(ctx)
return conn.Process(
ctx,
redis.NewBoolCmd(
ctx,
"CLIENT", "TRACKING", "on",
"REDIRECT", fmt.Sprintf("%d", invalidateClientID),
),
)
}
return opt
}