diff --git a/commands.go b/commands.go index 516c1d9..37a0429 100644 --- a/commands.go +++ b/commands.go @@ -254,6 +254,8 @@ type Cmdable interface { ZCount(ctx context.Context, key, min, max string) *IntCmd ZLexCount(ctx context.Context, key, min, max string) *IntCmd ZIncrBy(ctx context.Context, key string, increment float64, member string) *FloatCmd + ZInter(ctx context.Context, store *ZStore) *StringSliceCmd + ZInterWithScores(ctx context.Context, store *ZStore) *ZSliceCmd ZInterStore(ctx context.Context, destination string, store *ZStore) *IntCmd ZMScore(ctx context.Context, key string, members ...string) *FloatSliceCmd ZPopMax(ctx context.Context, key string, count ...int64) *ZSliceCmd @@ -1943,6 +1945,17 @@ type ZStore struct { Aggregate string } +func (z *ZStore) len() (n int) { + n = len(z.Keys) + if len(z.Weights) > 0 { + n += 1 + len(z.Weights) + } + if z.Aggregate != "" { + n += 2 + } + return n +} + // Redis `BZPOPMAX key [key ...] timeout` command. func (c cmdable) BZPopMax(ctx context.Context, timeout time.Duration, keys ...string) *ZWithKeyCmd { args := make([]interface{}, 1+len(keys)+1) @@ -2088,7 +2101,7 @@ func (c cmdable) ZIncrBy(ctx context.Context, key string, increment float64, mem } func (c cmdable) ZInterStore(ctx context.Context, destination string, store *ZStore) *IntCmd { - args := make([]interface{}, 0, 3+len(store.Keys)) + args := make([]interface{}, 0, 3+store.len()) args = append(args, "zinterstore", destination, len(store.Keys)) for _, key := range store.Keys { args = append(args, key) @@ -2108,6 +2121,50 @@ func (c cmdable) ZInterStore(ctx context.Context, destination string, store *ZSt return cmd } +func (c cmdable) ZInter(ctx context.Context, store *ZStore) *StringSliceCmd { + args := make([]interface{}, 0, 2+store.len()) + args = append(args, "zinter", len(store.Keys)) + for _, key := range store.Keys { + args = append(args, key) + } + if len(store.Weights) > 0 { + args = append(args, "weights") + for _, weights := range store.Weights { + args = append(args, weights) + } + } + + if store.Aggregate != "" { + args = append(args, "aggregate", store.Aggregate) + } + cmd := NewStringSliceCmd(ctx, args...) + cmd.setFirstKeyPos(2) + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) ZInterWithScores(ctx context.Context, store *ZStore) *ZSliceCmd { + args := make([]interface{}, 0, 3+store.len()) + args = append(args, "zinter", len(store.Keys)) + for _, key := range store.Keys { + args = append(args, key) + } + if len(store.Weights) > 0 { + args = append(args, "weights") + for _, weights := range store.Weights { + args = append(args, weights) + } + } + if store.Aggregate != "" { + args = append(args, "aggregate", store.Aggregate) + } + args = append(args, "withscores") + cmd := NewZSliceCmd(ctx, args...) + cmd.setFirstKeyPos(2) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) ZMScore(ctx context.Context, key string, members ...string) *FloatSliceCmd { args := make([]interface{}, 2+len(members)) args[0] = "zmscore" @@ -2334,7 +2391,7 @@ func (c cmdable) ZScore(ctx context.Context, key, member string) *FloatCmd { } func (c cmdable) ZUnionStore(ctx context.Context, dest string, store *ZStore) *IntCmd { - args := make([]interface{}, 0, 3+len(store.Keys)) + args := make([]interface{}, 0, 3+store.len()) args = append(args, "zunionstore", dest, len(store.Keys)) for _, key := range store.Keys { args = append(args, key) diff --git a/commands_test.go b/commands_test.go index 1ce8872..d2b5b5b 100644 --- a/commands_test.go +++ b/commands_test.go @@ -3999,6 +3999,55 @@ var _ = Describe("Commands", func() { }, })) }) + + It("should ZInter", func() { + err := client.ZAdd(ctx, "zset1", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset1", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 3, Member: "three"}).Err() + Expect(err).NotTo(HaveOccurred()) + + v, err := client.ZInter(ctx, &redis.ZStore{ + Keys: []string{"zset1", "zset2"}, + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal([]string{"one", "two"})) + }) + + It("should ZInterWithScores", func() { + err := client.ZAdd(ctx, "zset1", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset1", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 3, Member: "three"}).Err() + Expect(err).NotTo(HaveOccurred()) + + v, err := client.ZInterWithScores(ctx, &redis.ZStore{ + Keys: []string{"zset1", "zset2"}, + Weights: []float64{2, 3}, + Aggregate: "Max", + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal([]redis.Z{ + { + Member: "one", + Score: 3, + }, + { + Member: "two", + Score: 6, + }, + })) + }) }) Describe("streams", func() {