diff --git a/command.go b/command.go index 4879cfa0..f21e2dbd 100644 --- a/command.go +++ b/command.go @@ -18,6 +18,8 @@ type Cmder interface { Args() []interface{} String() string stringArg(int) string + firstKeyPos() int + addKeyPos(int) readTimeout() *time.Duration readReply(rd *proto.Reader) error @@ -73,6 +75,9 @@ func cmdFirstKeyPos(cmd Cmder, info *CommandInfo) int { } } + if pos := cmd.firstKeyPos(); pos != 0 { + return pos + } if info == nil { return 0 } @@ -106,6 +111,7 @@ type baseCmd struct { ctx context.Context args []interface{} err error + keyPos int _readTimeout *time.Duration } @@ -147,6 +153,14 @@ func (cmd *baseCmd) stringArg(pos int) string { return s } +func (cmd *baseCmd) firstKeyPos() int { + return cmd.keyPos +} + +func (cmd *baseCmd) addKeyPos(offset int) { + cmd.keyPos += offset +} + func (cmd *baseCmd) SetErr(e error) { cmd.err = e } diff --git a/commands.go b/commands.go index 4c7113d2..e70eac4e 100644 --- a/commands.go +++ b/commands.go @@ -1504,16 +1504,20 @@ type XReadArgs struct { func (c cmdable) XRead(ctx context.Context, a *XReadArgs) *XStreamSliceCmd { args := make([]interface{}, 0, 5+len(a.Streams)) args = append(args, "xread") + + offset := 1 if a.Count > 0 { args = append(args, "count") args = append(args, a.Count) + offset += 2 } if a.Block >= 0 { args = append(args, "block") args = append(args, int64(a.Block/time.Millisecond)) + offset += 2 } - args = append(args, "streams") + offset += 1 for _, s := range a.Streams { args = append(args, s) } @@ -1522,6 +1526,7 @@ func (c cmdable) XRead(ctx context.Context, a *XReadArgs) *XStreamSliceCmd { if a.Block >= 0 { cmd.setReadTimeout(a.Block) } + cmd.addKeyPos(offset) _ = c(ctx, cmd) return cmd } @@ -1575,16 +1580,22 @@ type XReadGroupArgs struct { func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSliceCmd { args := make([]interface{}, 0, 8+len(a.Streams)) args = append(args, "xreadgroup", "group", a.Group, a.Consumer) + + offset := 1 if a.Count > 0 { args = append(args, "count", a.Count) + offset += 2 } if a.Block >= 0 { args = append(args, "block", int64(a.Block/time.Millisecond)) + offset += 2 } if a.NoAck { args = append(args, "noack") + offset += 1 } args = append(args, "streams") + offset += 1 for _, s := range a.Streams { args = append(args, s) } @@ -1593,6 +1604,7 @@ func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSlic if a.Block >= 0 { cmd.setReadTimeout(a.Block) } + cmd.addKeyPos(offset) _ = c(ctx, cmd) return cmd } @@ -1871,6 +1883,7 @@ func (c cmdable) ZInterStore(ctx context.Context, destination string, store *ZSt args = append(args, "aggregate", store.Aggregate) } cmd := NewIntCmd(ctx, args...) + cmd.addKeyPos(3) _ = c(ctx, cmd) return cmd } @@ -2106,6 +2119,7 @@ func (c cmdable) ZUnionStore(ctx context.Context, dest string, store *ZStore) *I args = append(args, "aggregate", store.Aggregate) } cmd := NewIntCmd(ctx, args...) + cmd.addKeyPos(3) _ = c(ctx, cmd) return cmd }