diff --git a/cluster_commands.go b/cluster_commands.go index 1f0bae06..336ea98d 100644 --- a/cluster_commands.go +++ b/cluster_commands.go @@ -2,6 +2,7 @@ package redis import ( "context" + "sync" "sync/atomic" ) @@ -23,3 +24,76 @@ func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd { cmd.val = size return cmd } + +func (c *ClusterClient) ScriptLoad(ctx context.Context, script string) *StringCmd { + cmd := NewStringCmd(ctx, "script", "load", script) + mu := &sync.Mutex{} + err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { + val, err := shard.ScriptLoad(ctx, script).Result() + if err != nil { + return err + } + + mu.Lock() + if cmd.Val() == "" { + cmd.val = val + } + mu.Unlock() + + return nil + }) + if err != nil { + cmd.SetErr(err) + } + + return cmd +} + +func (c *ClusterClient) ScriptFlush(ctx context.Context) *StatusCmd { + cmd := NewStatusCmd(ctx, "script", "flush") + _ = c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { + shard.ScriptFlush(ctx) + + return nil + }) + + return cmd +} + +func (c *ClusterClient) ScriptExists(ctx context.Context, hashes ...string) *BoolSliceCmd { + args := make([]interface{}, 2+len(hashes)) + args[0] = "script" + args[1] = "exists" + for i, hash := range hashes { + args[2+i] = hash + } + cmd := NewBoolSliceCmd(ctx, args...) + + result := make([]bool, len(hashes)) + for i := range result { + result[i] = true + } + + mu := &sync.Mutex{} + err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { + val, err := shard.ScriptExists(ctx, hashes...).Result() + if err != nil { + return err + } + + mu.Lock() + for i, v := range val { + result[i] = result[i] && v + } + mu.Unlock() + + return nil + }) + if err != nil { + cmd.SetErr(err) + } + + cmd.val = result + + return cmd +} diff --git a/cluster_test.go b/cluster_test.go index 561832a0..3880d437 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -344,6 +344,50 @@ var _ = Describe("ClusterClient", func() { }) }) + It("distributes scripts when using Script Load", func() { + client.ScriptFlush(ctx) + + script := redis.NewScript(`return 'Unique script'`) + + script.Load(ctx, client) + + client.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error { + defer GinkgoRecover() + + val, _ := script.Exists(ctx, shard).Result() + Expect(val[0]).To(Equal(true)) + return nil + }) + }) + + It("checks all shards when using Script Exists", func() { + client.ScriptFlush(ctx) + + script := redis.NewScript(`return 'First script'`) + lostScriptSrc := `return 'Lost script'` + lostScript := redis.NewScript(lostScriptSrc) + + script.Load(ctx, client) + client.Do(ctx, "script", "load", lostScriptSrc) + + val, _ := client.ScriptExists(ctx, script.Hash(), lostScript.Hash()).Result() + + Expect(val).To(Equal([]bool{true, false})) + }) + + It("flushes scripts from all shards when using ScriptFlush", func() { + script := redis.NewScript(`return 'Unnecessary script'`) + script.Load(ctx, client) + + val, _ := client.ScriptExists(ctx, script.Hash()).Result() + Expect(val).To(Equal([]bool{true})) + + client.ScriptFlush(ctx) + + val, _ = client.ScriptExists(ctx, script.Hash()).Result() + Expect(val).To(Equal([]bool{false})) + }) + It("supports Watch", func() { var incr func(string) error