diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index caef4b31..707670d0 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1 +1 @@ -custom: ['https://uptrace.dev'] +custom: ['https://uptrace.dev/sponsor'] diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index fbaa570b..e86d7a66 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -3,6 +3,3 @@ contact_links: - name: Discussions url: https://github.com/go-redis/redis/discussions about: Ask a question via GitHub Discussions - - name: Discord - url: https://discord.gg/rWtp5Aj - about: Ask a question via Discord diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7ca8d61c..c50b7747 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,9 +2,12 @@ name: Go on: push: - branches: [master] + branches: [master, v9] pull_request: - branches: [master] + branches: [master, v9] + +permissions: + contents: read jobs: build: @@ -13,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - go-version: [1.16.x, 1.17.x] + go-version: [1.18.x, 1.19.x] services: redis: @@ -25,12 +28,12 @@ jobs: steps: - name: Set up ${{ matrix.go-version }} - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Test run: make test diff --git a/.github/workflows/commitlint.yml b/.github/workflows/commitlint.yml index 67e6df3b..af8a615e 100644 --- a/.github/workflows/commitlint.yml +++ b/.github/workflows/commitlint.yml @@ -5,7 +5,7 @@ jobs: commitlint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 0 - - uses: wagoid/commitlint-github-action@v4 + - uses: wagoid/commitlint-github-action@v5 diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 6c12b83b..d3232ecb 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -7,13 +7,20 @@ on: branches: - master - main + - v9 pull_request: +permissions: + contents: read + jobs: golangci: + permissions: + contents: read # for actions/checkout to fetch code + pull-requests: read # for golangci/golangci-lint-action to fetch pull requests name: lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: golangci-lint - uses: golangci/golangci-lint-action@v2 + uses: golangci/golangci-lint-action@v3 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9168cad2..685693ae 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: ncipollo/release-action@v1 with: body: diff --git a/.gitignore b/.gitignore index b975a7b4..dc322f9b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ *.rdb -testdata/*/ +testdata/* .idea/ diff --git a/CHANGELOG.md b/CHANGELOG.md index c575c568..7b117894 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,149 +1,55 @@ -## [8.11.4](https://github.com/go-redis/redis/compare/v8.11.3...v8.11.4) (2021-10-04) +# [9.0.0-rc.2](https://github.com/go-redis/redis/compare/v9.0.0-rc.1...v9.0.0-rc.2) (2022-11-26) + + +### Bug Fixes + +* capture error correctly in withConn ([d1bfaba](https://github.com/go-redis/redis/commit/d1bfaba549fe380d269c26cea0a0183ed1520a85)) +* fixes ring.SetAddrs and rebalance race 
([#2283](https://github.com/go-redis/redis/issues/2283)) ([d83436b](https://github.com/go-redis/redis/commit/d83436b321cd9ed52ba33c3edbe8f63bb0444c59))
+* read in route_randomly query param correctly ([f236053](https://github.com/go-redis/redis/commit/f236053735d10aec5e6e31fc3ced1b2e53292554))
+* reduce `SetAddrs` shards lock contention ([6c05a9f](https://github.com/go-redis/redis/commit/6c05a9f6b17f8e32593d3f7d594f82ba3dbcafb1)), closes [/github.com/go-redis/redis/pull/2190#discussion_r953040289](https://github.com//github.com/go-redis/redis/pull/2190/issues/discussion_r953040289) [#2077](https://github.com/go-redis/redis/issues/2077)
+* wrap cmds in Conn.TxPipeline ([5053db2](https://github.com/go-redis/redis/commit/5053db2f9c8b3ca25f497a75f70012c7ad6cd775))


 ### Features

-* add acl auth support for sentinels ([f66582f](https://github.com/go-redis/redis/commit/f66582f44f3dc3a4705a5260f982043fde4aa634))
-* add Cmd.{String,Int,Float,Bool}Slice helpers and an example ([5d3d293](https://github.com/go-redis/redis/commit/5d3d293cc9c60b90871e2420602001463708ce24))
-* add SetVal method for each command ([168981d](https://github.com/go-redis/redis/commit/168981da2d84ee9e07d15d3e74d738c162e264c4))
+* add HasErrorPrefix ([d3d8002](https://github.com/go-redis/redis/commit/d3d8002e894a1eab5bab2c9fff13439527e330d8))
+* add support for SINTERCARD command ([bc51c61](https://github.com/go-redis/redis/commit/bc51c61a458d1bc4fb4424c7c3e912325ef980cc))

-## v8.11
+## v9 UNRELEASED

-- Remove OpenTelemetry metrics.
-- Supports more redis commands and options.
+### Added

-## v8.10
+- Added support for the [RESP3](https://github.com/antirez/RESP3/blob/master/spec.md) protocol.
+  Contributed by @monkey92t, who has done a lot of work recently.
+- Added `ContextTimeoutEnabled` option that controls whether the client respects context timeouts
+  and deadlines. See
+  [Redis Timeouts](https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts) for details.
+- Added `ParseClusterURL` to parse URLs into `ClusterOptions`, for example,
+  `redis://user:password@localhost:6789?dial_timeout=3&read_timeout=6s&addr=localhost:6790&addr=localhost:6791`.
+- Added metrics instrumentation using `redisotel.InstrumentMetrics`. See the
+  [documentation](https://redis.uptrace.dev/guide/go-redis-monitoring.html).

-- Removed extra OpenTelemetry spans from go-redis core. Now go-redis instrumentation only adds a
-  single span with a Redis command (instead of 4 spans). There are multiple reasons behind this
-  decision:
+### Changed

-  - Traces become smaller and less noisy.
-  - It may be costly to process those 3 extra spans for each query.
-  - go-redis no longer depends on OpenTelemetry.
+- Removed asynchronous cancellation based on the context timeout. It was racy in v8 and is
+  completely gone in v9.
+- Reworked the hook interface and added `DialHook`.
+- Replaced `redisotel.NewTracingHook` with `redisotel.InstrumentTracing`. See the
+  [example](example/otel) and
+  [documentation](https://redis.uptrace.dev/guide/go-redis-monitoring.html).
+- Replaced `*redis.Z` with `redis.Z` since it is small enough to be passed as a value without
+  making an allocation.
+- Renamed the option `MaxConnAge` to `ConnMaxLifetime`.
+- Renamed the option `IdleTimeout` to `ConnMaxIdleTime`.
+- Removed the connection reaper in favor of `MaxIdleConns`.
+- Removed `WithContext`, since `context.Context` can be passed directly as an arg.
+- Removed `Pipeline.Close`, since there is no real need to explicitly manage pipeline resources and
+  it can be safely reused via `sync.Pool` etc.
`Pipeline.Discard` is still available if you want to + reset commands for some reason. - Eventually we hope to replace the information that we no longer collect with OpenTelemetry - Metrics. +### Fixed -## v8.9 - -- Changed `PubSub.Channel` to only rely on `Ping` result. You can now use `WithChannelSize`, - `WithChannelHealthCheckInterval`, and `WithChannelSendTimeout` to override default settings. - -## v8.8 - -- To make updating easier, extra modules now have the same version as go-redis does. That means that - you need to update your imports: - -``` -github.com/go-redis/redis/extra/redisotel -> github.com/go-redis/redis/extra/redisotel/v8 -github.com/go-redis/redis/extra/rediscensus -> github.com/go-redis/redis/extra/rediscensus/v8 -``` - -## v8.5 - -- [knadh](https://github.com/knadh) contributed long-awaited ability to scan Redis Hash into a - struct: - -```go -err := rdb.HGetAll(ctx, "hash").Scan(&data) - -err := rdb.MGet(ctx, "key1", "key2").Scan(&data) -``` - -- Please check [redismock](https://github.com/go-redis/redismock) by - [monkey92t](https://github.com/monkey92t) if you are looking for mocking Redis Client. - -## v8 - -- All commands require `context.Context` as a first argument, e.g. `rdb.Ping(ctx)`. If you are not - using `context.Context` yet, the simplest option is to define global package variable - `var ctx = context.TODO()` and use it when `ctx` is required. - -- Full support for `context.Context` canceling. - -- Added `redis.NewFailoverClusterClient` that supports routing read-only commands to a slave node. - -- Added `redisext.OpenTemetryHook` that adds - [Redis OpenTelemetry instrumentation](https://redis.uptrace.dev/tracing/). - -- Redis slow log support. - -- Ring uses Rendezvous Hashing by default which provides better distribution. You need to move - existing keys to a new location or keys will be inaccessible / lost. To use old hashing scheme: - -```go -import "github.com/golang/groupcache/consistenthash" - -ring := redis.NewRing(&redis.RingOptions{ - NewConsistentHash: func() { - return consistenthash.New(100, crc32.ChecksumIEEE) - }, -}) -``` - -- `ClusterOptions.MaxRedirects` default value is changed from 8 to 3. -- `Options.MaxRetries` default value is changed from 0 to 3. - -- `Cluster.ForEachNode` is renamed to `ForEachShard` for consistency with `Ring`. - -## v7.3 - -- New option `Options.Username` which causes client to use `AuthACL`. Be aware if your connection - URL contains username. - -## v7.2 - -- Existing `HMSet` is renamed to `HSet` and old deprecated `HMSet` is restored for Redis 3 users. - -## v7.1 - -- Existing `Cmd.String` is renamed to `Cmd.Text`. New `Cmd.String` implements `fmt.Stringer` - interface. - -## v7 - -- _Important_. Tx.Pipeline now returns a non-transactional pipeline. Use Tx.TxPipeline for a - transactional pipeline. -- WrapProcess is replaced with more convenient AddHook that has access to context.Context. -- WithContext now can not be used to create a shallow copy of the client. -- New methods ProcessContext, DoContext, and ExecContext. -- Client respects Context.Deadline when setting net.Conn deadline. -- Client listens on Context.Done while waiting for a connection from the pool and returns an error - when context context is cancelled. -- Add PubSub.ChannelWithSubscriptions that sends `*Subscription` in addition to `*Message` to allow - detecting reconnections. -- `time.Time` is now marshalled in RFC3339 format. `rdb.Get("foo").Time()` helper is added to parse - the time. 
-- `SetLimiter` is removed and added `Options.Limiter` instead. -- `HMSet` is deprecated as of Redis v4. - -## v6.15 - -- Cluster and Ring pipelines process commands for each node in its own goroutine. - -## 6.14 - -- Added Options.MinIdleConns. -- Added Options.MaxConnAge. -- PoolStats.FreeConns is renamed to PoolStats.IdleConns. -- Add Client.Do to simplify creating custom commands. -- Add Cmd.String, Cmd.Int, Cmd.Int64, Cmd.Uint64, Cmd.Float64, and Cmd.Bool helpers. -- Lower memory usage. - -## v6.13 - -- Ring got new options called `HashReplicas` and `Hash`. It is recommended to set - `HashReplicas = 1000` for better keys distribution between shards. -- Cluster client was optimized to use much less memory when reloading cluster state. -- PubSub.ReceiveMessage is re-worked to not use ReceiveTimeout so it does not lose data when timeout - occurres. In most cases it is recommended to use PubSub.Channel instead. -- Dialer.KeepAlive is set to 5 minutes by default. - -## v6.12 - -- ClusterClient got new option called `ClusterSlots` which allows to build cluster of normal Redis - Servers that don't have cluster mode enabled. See - https://godoc.org/github.com/go-redis/redis#example-NewClusterClient--ManualSetup +- Improved and fixed pipeline retries. +- As usual, added more commands and fixed some bugs. diff --git a/Makefile b/Makefile index a4cfe057..d660763b 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,8 @@ test: testdeps go test ./... -run=NONE -bench=. -benchmem env GOOS=linux GOARCH=386 go test ./... go vet + cd internal/customvet && go build . + go vet -vettool ./internal/customvet/customvet testdeps: testdata/redis/src/redis-server @@ -16,7 +18,7 @@ bench: testdeps testdata/redis: mkdir -p $@ - wget -qO- https://download.redis.io/releases/redis-6.2.5.tar.gz | tar xvz --strip-components=1 -C $@ + wget -qO- https://download.redis.io/releases/redis-7.0.7.tar.gz | tar xvz --strip-components=1 -C $@ testdata/redis/src/redis-server: testdata/redis cd $< && make all @@ -26,10 +28,9 @@ fmt: goimports -w -local github.com/go-redis/redis ./ go_mod_tidy: - go get -u && go mod tidy set -e; for dir in $(PACKAGE_DIRS); do \ echo "go mod tidy in $${dir}"; \ (cd "$${dir}" && \ - go get -u && \ - go mod tidy); \ + go get -u ./... && \ + go mod tidy -compat=1.17); \ done diff --git a/README.md b/README.md index 0419f35b..d7bec205 100644 --- a/README.md +++ b/README.md @@ -1,53 +1,47 @@ -

-  <img src="..." alt="All-in-one tool to optimize performance and monitor errors & logs">

+# Redis client for Go

-# Redis client for Golang
-
-![build workflow](https://github.com/go-redis/redis/actions/workflows/build.yml/badge.svg)
+[![build workflow](https://github.com/go-redis/redis/actions/workflows/build.yml/badge.svg)](https://github.com/go-redis/redis/actions)
 [![PkgGoDev](https://pkg.go.dev/badge/github.com/go-redis/redis/v8)](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc)
 [![Documentation](https://img.shields.io/badge/redis-documentation-informational)](https://redis.uptrace.dev/)
 [![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj)

-- To ask questions, join [Discord](https://discord.gg/rWtp5Aj) or use
-  [Discussions](https://github.com/go-redis/redis/discussions).
-- [Newsletter](https://blog.uptrace.dev/pages/newsletter.html) to get latest updates.
+> go-redis is brought to you by :star: [**uptrace/uptrace**](https://github.com/uptrace/uptrace).
+> Uptrace is an open-source APM tool that supports distributed tracing, metrics, and logs. You can
+> use it to monitor applications and set up automatic alerts to receive notifications via email,
+> Slack, Telegram, and others.
+>
+> See the [OpenTelemetry](example/otel) example, which demonstrates how you can use Uptrace to
+> monitor go-redis.
+
+## Resources
+
 - [Documentation](https://redis.uptrace.dev)
-- [Reference](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc)
-- [Examples](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#pkg-examples)
-- [RealWorld example app](https://github.com/uptrace/go-treemux-realworld-example-app)
-
-Other projects you may like:
-
-- [Bun](https://bun.uptrace.dev/) - fast and simple SQL client for PostgreSQL, MySQL, and SQLite.
-- [BunRouter](https://bunrouter.uptrace.dev/) - fast and flexible HTTP router for Go.
+- [Discussions](https://github.com/go-redis/redis/discussions)
+- [Chat](https://discord.gg/rWtp5Aj)
+- [Reference](https://pkg.go.dev/github.com/go-redis/redis/v9)
+- [Examples](https://pkg.go.dev/github.com/go-redis/redis/v9#pkg-examples)

 ## Ecosystem

-- [Redis Mock](https://github.com/go-redis/redismock).
-- [Distributed Locks](https://github.com/bsm/redislock).
-- [Redis Cache](https://github.com/go-redis/cache).
-- [Rate limiting](https://github.com/go-redis/redis_rate).
+- [Redis Mock](https://github.com/go-redis/redismock)
+- [Distributed Locks](https://github.com/bsm/redislock)
+- [Redis Cache](https://github.com/go-redis/cache)
+- [Rate limiting](https://github.com/go-redis/redis_rate)
+
+This client also works with [Kvrocks](https://github.com/apache/incubator-kvrocks), a distributed
+key-value NoSQL database that uses RocksDB as its storage engine and is compatible with the Redis
+protocol.

 ## Features

 - Redis 3 commands except QUIT, MONITOR, and SYNC.
-- Automatic connection pooling with
-  [circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support.
+- Automatic connection pooling.
-- [Pub/Sub](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#PubSub).
-- [Transactions](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client-TxPipeline).
-- [Pipeline](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client.Pipeline) and
-  [TxPipeline](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client.TxPipeline).
-- [Scripting](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#Script).
-- [Timeouts](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#Options).
-- [Redis Sentinel](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewFailoverClient).
-- [Redis Cluster](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewClusterClient). -- [Cluster of Redis Servers](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-NewClusterClient-ManualSetup) - without using cluster mode and Redis Sentinel. -- [Ring](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewRing). -- [Instrumentation](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-package-Instrumentation). +- [Pub/Sub](https://redis.uptrace.dev/guide/go-redis-pubsub.html). +- [Pipelines and transactions](https://redis.uptrace.dev/guide/go-redis-pipelines.html). +- [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html). +- [Redis Sentinel](https://redis.uptrace.dev/guide/go-redis-sentinel.html). +- [Redis Cluster](https://redis.uptrace.dev/guide/go-redis-cluster.html). +- [Redis Ring](https://redis.uptrace.dev/guide/ring.html). +- [Redis Performance Monitoring](https://redis.uptrace.dev/guide/redis-performance-monitoring.html). ## Installation @@ -59,18 +53,25 @@ module: go mod init github.com/my/repo ``` -And then install go-redis/v8 (note _v8_ in the import; omitting it is a popular mistake): +If you are using **Redis 6**, install go-redis/**v8**: ```shell go get github.com/go-redis/redis/v8 ``` +If you are using **Redis 7**, install go-redis/**v9**: + +```shell +go get github.com/go-redis/redis/v9 +``` + ## Quickstart ```go import ( "context" "github.com/go-redis/redis/v8" + "fmt" ) var ctx = context.Background() @@ -147,7 +148,7 @@ go-redis will start a redis-server and run the test cases. The paths of redis-server bin file and redis config file are defined in `main_test.go`: -``` +```go var ( redisServerBin, _ = filepath.Abs(filepath.Join("testdata", "redis", "src", "redis-server")) redisServerConf, _ = filepath.Abs(filepath.Join("testdata", "redis", "redis.conf")) @@ -157,17 +158,24 @@ var ( For local testing, you can change the variables to refer to your local files, or create a soft link to the corresponding folder for redis-server and copy the config file to `testdata/redis/`: -``` +```shell ln -s /usr/bin/redis-server ./go-redis/testdata/redis/src cp ./go-redis/testdata/redis.conf ./go-redis/testdata/redis/ ``` Lastly, run: -``` +```shell go test ``` +## See also + +- [Golang ORM](https://bun.uptrace.dev) for PostgreSQL, MySQL, MSSQL, and SQLite +- [Golang PostgreSQL](https://bun.uptrace.dev/postgres/) +- [Golang HTTP router](https://bunrouter.uptrace.dev/) +- [Golang ClickHouse ORM](https://github.com/uptrace/go-clickhouse) + ## Contributors Thanks to all the people who already contributed! 
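For quick reference alongside the README changes above, here is a minimal v9 quickstart sketch. It only assumes the `/v9` import path introduced by this patch and the standard `NewClient`/`Set`/`Get` API, which this patch does not change; the address points at a local redis-server like the one the test targets start:

```go
package main

import (
	"context"
	"fmt"

	"github.com/go-redis/redis/v9"
)

func main() {
	ctx := context.Background()

	// Connect to a locally running redis-server.
	rdb := redis.NewClient(&redis.Options{
		Addr: "localhost:6379",
	})
	defer rdb.Close()

	// SET key value (0 = no expiration).
	if err := rdb.Set(ctx, "key", "value", 0).Err(); err != nil {
		panic(err)
	}

	// GET key.
	val, err := rdb.Get(ctx, "key").Result()
	if err != nil {
		panic(err)
	}
	fmt.Println("key =", val)
}
```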
diff --git a/bench_decode_test.go b/bench_decode_test.go index 83828064..fc929e52 100644 --- a/bench_decode_test.go +++ b/bench_decode_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/go-redis/redis/v8/internal/proto" + "github.com/go-redis/redis/v9/internal/proto" ) var ctx = context.TODO() @@ -18,14 +18,17 @@ type ClientStub struct { resp []byte } +var initHello = []byte("%1\r\n+proto\r\n:3\r\n") + func NewClientStub(resp []byte) *ClientStub { stub := &ClientStub{ resp: resp, } + stub.Cmdable = NewClient(&Options{ PoolSize: 128, Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { - return stub.stubConn(), nil + return stub.stubConn(initHello), nil }, }) return stub @@ -38,9 +41,9 @@ func NewClusterClientStub(resp []byte) *ClientStub { client := NewClusterClient(&ClusterOptions{ PoolSize: 128, - Addrs: []string{"127.0.0.1:6379"}, + Addrs: []string{":6379"}, Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { - return stub.stubConn(), nil + return stub.stubConn(initHello), nil }, ClusterSlots: func(_ context.Context) ([]ClusterSlot, error) { return []ClusterSlot{ @@ -65,18 +68,27 @@ func NewClusterClientStub(resp []byte) *ClientStub { return stub } -func (c *ClientStub) stubConn() *ConnStub { +func (c *ClientStub) stubConn(init []byte) *ConnStub { return &ConnStub{ + init: init, resp: c.resp, } } type ConnStub struct { + init []byte resp []byte pos int } func (c *ConnStub) Read(b []byte) (n int, err error) { + // Return conn.init() + if len(c.init) > 0 { + n = copy(b, c.init) + c.init = c.init[n:] + return n, nil + } + if len(c.resp) == 0 { return 0, io.EOF } @@ -106,7 +118,7 @@ func BenchmarkDecode(b *testing.B) { } benchmarks := []Benchmark{ - {"single", NewClientStub}, + {"server", NewClientStub}, {"cluster", NewClusterClientStub}, } diff --git a/bench_test.go b/bench_test.go index 5644f50c..cb29a7a8 100644 --- a/bench_test.go +++ b/bench_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - "github.com/go-redis/redis/v8" + "github.com/go-redis/redis/v9" ) func benchmarkRedisClient(ctx context.Context, poolSize int) *redis.Client { @@ -223,7 +223,7 @@ func BenchmarkZAdd(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - err := client.ZAdd(ctx, "key", &redis.Z{ + err := client.ZAdd(ctx, "key", redis.Z{ Score: float64(1), Member: "hello", }).Err() @@ -273,36 +273,6 @@ func BenchmarkXRead(b *testing.B) { }) } -var clientSink *redis.Client - -func BenchmarkWithContext(b *testing.B) { - ctx := context.Background() - rdb := benchmarkRedisClient(ctx, 10) - defer rdb.Close() - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - clientSink = rdb.WithContext(ctx) - } -} - -var ringSink *redis.Ring - -func BenchmarkRingWithContext(b *testing.B) { - ctx := context.Background() - rdb := redis.NewRing(&redis.RingOptions{}) - defer rdb.Close() - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - ringSink = rdb.WithContext(ctx) - } -} - //------------------------------------------------------------------------------ func newClusterScenario() *clusterScenario { @@ -341,6 +311,32 @@ func BenchmarkClusterPing(b *testing.B) { }) } +func BenchmarkClusterDoInt(b *testing.B) { + if testing.Short() { + b.Skip("skipping in short mode") + } + + ctx := context.Background() + cluster := newClusterScenario() + if err := startCluster(ctx, cluster); err != nil { + b.Fatal(err) + } + defer cluster.Close() + + client := cluster.newClusterClient(ctx, redisClusterOptions()) + defer client.Close() + + 
b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + err := client.Do(ctx, "SET", 10, 10).Err() + if err != nil { + b.Fatal(err) + } + } + }) +} + func BenchmarkClusterSetString(b *testing.B) { if testing.Short() { b.Skip("skipping in short mode") @@ -370,17 +366,77 @@ func BenchmarkClusterSetString(b *testing.B) { }) } -var clusterSink *redis.ClusterClient +func BenchmarkExecRingSetAddrsCmd(b *testing.B) { + const ( + ringShard1Name = "ringShardOne" + ringShard2Name = "ringShardTwo" + ) -func BenchmarkClusterWithContext(b *testing.B) { - ctx := context.Background() - rdb := redis.NewClusterClient(&redis.ClusterOptions{}) - defer rdb.Close() + for _, port := range []string{ringShard1Port, ringShard2Port} { + if _, err := startRedis(port); err != nil { + b.Fatal(err) + } + } + + b.Cleanup(func() { + for _, p := range processes { + if err := p.Close(); err != nil { + b.Errorf("Failed to stop redis process: %v", err) + } + } + processes = nil + }) + + ring := redis.NewRing(&redis.RingOptions{ + Addrs: map[string]string{ + "ringShardOne": ":" + ringShard1Port, + }, + NewClient: func(opt *redis.Options) *redis.Client { + // Simulate slow shard creation + time.Sleep(100 * time.Millisecond) + return redis.NewClient(opt) + }, + }) + defer ring.Close() + + if _, err := ring.Ping(context.Background()).Result(); err != nil { + b.Fatal(err) + } + + // Continuously update addresses by adding and removing one address + updatesDone := make(chan struct{}) + defer func() { close(updatesDone) }() + go func() { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for i := 0; ; i++ { + select { + case <-ticker.C: + if i%2 == 0 { + ring.SetAddrs(map[string]string{ + ringShard1Name: ":" + ringShard1Port, + }) + } else { + ring.SetAddrs(map[string]string{ + ringShard1Name: ":" + ringShard1Port, + ringShard2Name: ":" + ringShard2Port, + }) + } + case <-updatesDone: + return + } + } + }() b.ResetTimer() - b.ReportAllocs() - for i := 0; i < b.N; i++ { - clusterSink = rdb.WithContext(ctx) + if _, err := ring.Ping(context.Background()).Result(); err != nil { + if err == redis.ErrClosed { + // The shard client could be closed while ping command is in progress + continue + } else { + b.Fatal(err) + } + } } } diff --git a/cluster.go b/cluster.go index a54f2f37..872baad4 100644 --- a/cluster.go +++ b/cluster.go @@ -6,17 +6,19 @@ import ( "fmt" "math" "net" + "net/url" "runtime" "sort" + "strings" "sync" "sync/atomic" "time" - "github.com/go-redis/redis/v8/internal" - "github.com/go-redis/redis/v8/internal/hashtag" - "github.com/go-redis/redis/v8/internal/pool" - "github.com/go-redis/redis/v8/internal/proto" - "github.com/go-redis/redis/v8/internal/rand" + "github.com/go-redis/redis/v9/internal" + "github.com/go-redis/redis/v9/internal/hashtag" + "github.com/go-redis/redis/v9/internal/pool" + "github.com/go-redis/redis/v9/internal/proto" + "github.com/go-redis/redis/v9/internal/rand" ) var errClusterNoNodes = fmt.Errorf("redis: cluster has no nodes") @@ -27,6 +29,9 @@ type ClusterOptions struct { // A seed list of host:port addresses of cluster nodes. Addrs []string + // ClientName will execute the `CLIENT SETNAME ClientName` command for each conn. + ClientName string + // NewClient creates a cluster node client with provided name and options. 
	NewClient func(opt *Options) *Client

@@ -64,20 +69,18 @@ type ClusterOptions struct {
 	MinRetryBackoff time.Duration
 	MaxRetryBackoff time.Duration

-	DialTimeout  time.Duration
-	ReadTimeout  time.Duration
-	WriteTimeout time.Duration
+	DialTimeout           time.Duration
+	ReadTimeout           time.Duration
+	WriteTimeout          time.Duration
+	ContextTimeoutEnabled bool

-	// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
-	PoolFIFO bool
-
-	// PoolSize applies per cluster node and not for the whole cluster.
-	PoolSize           int
-	MinIdleConns       int
-	MaxConnAge         time.Duration
-	PoolTimeout        time.Duration
-	IdleTimeout        time.Duration
-	IdleCheckFrequency time.Duration
+	PoolFIFO        bool
+	PoolSize        int // applies per cluster node and not for the whole cluster
+	PoolTimeout     time.Duration
+	MinIdleConns    int
+	MaxIdleConns    int
+	ConnMaxIdleTime time.Duration
+	ConnMaxLifetime time.Duration

 	TLSConfig *tls.Config
 }
@@ -131,12 +134,134 @@ func (opt *ClusterOptions) init() {
 	}
 }

-func (opt *ClusterOptions) clientOptions() *Options {
-	const disableIdleCheck = -1
+// ParseClusterURL parses a URL into ClusterOptions that can be used to connect to Redis.
+// The URL must be in the form:
+//
+//	redis://<user>:<password>@<host>:<port>
+//	or
+//	rediss://<user>:<password>@<host>:<port>
+//
+// To add additional addresses, specify the query parameter "addr" one or more times. e.g:
+//
+//	redis://<user>:<password>@<host>:<port>?addr=<host2>:<port2>&addr=<host3>:<port3>
+//	or
+//	rediss://<user>:<password>@<host>:<port>?addr=<host2>:<port2>&addr=<host3>:<port3>
+//
+// Most Option fields can be set using query parameters, with the following restrictions:
+//   - field names are mapped using snake-case conversion: to set MaxRetries, use max_retries
+//   - only scalar type fields are supported (bool, int, time.Duration)
+//   - for time.Duration fields, values must be a valid input for time.ParseDuration();
+//     additionally a plain integer as value (i.e. without unit) is interpreted as seconds
+//   - to disable a duration field, use a value less than or equal to 0; to use the default
+//     value, leave the value blank or remove the parameter
+//   - only the last value is interpreted if a parameter is given multiple times
+//   - fields "network", "addr", "username" and "password" can only be set using other
+//     URL attributes (scheme, host, userinfo, resp.); query parameters using these
+//     names will be treated as unknown parameters
+//   - unknown parameter names will result in an error
+//
+// Example:
+//
+//	redis://user:password@localhost:6789?dial_timeout=3&read_timeout=6s&addr=localhost:6790&addr=localhost:6791
+//	is equivalent to:
+//	&ClusterOptions{
+//		Addrs:       ["localhost:6789", "localhost:6790", "localhost:6791"]
+//		DialTimeout: 3 * time.Second, // no time unit = seconds
+//		ReadTimeout: 6 * time.Second,
+//	}
+func ParseClusterURL(redisURL string) (*ClusterOptions, error) {
+	o := &ClusterOptions{}
+	u, err := url.Parse(redisURL)
+	if err != nil {
+		return nil, err
+	}
+
+	// Add the base URL to the array of addresses;
+	// more addresses may be added through the URL params.
+	h, p := getHostPortWithDefaults(u)
+	o.Addrs = append(o.Addrs, net.JoinHostPort(h, p))
+
+	// Set up username, password, and other configurations.
+	o, err = setupClusterConn(u, h, o)
+	if err != nil {
+		return nil, err
+	}
+
+	return o, nil
+}
+
+// setupClusterConn gets the username and password from the URL and the query parameters.
+func setupClusterConn(u *url.URL, host string, o *ClusterOptions) (*ClusterOptions, error) { + switch u.Scheme { + case "rediss": + o.TLSConfig = &tls.Config{ServerName: host} + fallthrough + case "redis": + o.Username, o.Password = getUserPassword(u) + default: + return nil, fmt.Errorf("redis: invalid URL scheme: %s", u.Scheme) + } + + // retrieve the configuration from the query parameters + o, err := setupClusterQueryParams(u, o) + if err != nil { + return nil, err + } + + return o, nil +} + +// setupClusterQueryParams converts query parameters in u to option value in o. +func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, error) { + q := queryOptions{q: u.Query()} + + o.ClientName = q.string("client_name") + o.MaxRedirects = q.int("max_redirects") + o.ReadOnly = q.bool("read_only") + o.RouteByLatency = q.bool("route_by_latency") + o.RouteRandomly = q.bool("route_randomly") + o.MaxRetries = q.int("max_retries") + o.MinRetryBackoff = q.duration("min_retry_backoff") + o.MaxRetryBackoff = q.duration("max_retry_backoff") + o.DialTimeout = q.duration("dial_timeout") + o.ReadTimeout = q.duration("read_timeout") + o.WriteTimeout = q.duration("write_timeout") + o.PoolFIFO = q.bool("pool_fifo") + o.PoolSize = q.int("pool_size") + o.MinIdleConns = q.int("min_idle_conns") + o.PoolTimeout = q.duration("pool_timeout") + o.ConnMaxLifetime = q.duration("conn_max_lifetime") + o.ConnMaxIdleTime = q.duration("conn_max_idle_time") + + if q.err != nil { + return nil, q.err + } + + // addr can be specified as many times as needed + addrs := q.strings("addr") + for _, addr := range addrs { + h, p, err := net.SplitHostPort(addr) + if err != nil || h == "" || p == "" { + return nil, fmt.Errorf("redis: unable to parse addr param: %s", addr) + } + + o.Addrs = append(o.Addrs, net.JoinHostPort(h, p)) + } + + // any parameters left? 
+ if r := q.remaining(); len(r) > 0 { + return nil, fmt.Errorf("redis: unexpected option: %s", strings.Join(r, ", ")) + } + + return o, nil +} + +func (opt *ClusterOptions) clientOptions() *Options { return &Options{ - Dialer: opt.Dialer, - OnConnect: opt.OnConnect, + ClientName: opt.ClientName, + Dialer: opt.Dialer, + OnConnect: opt.OnConnect, Username: opt.Username, Password: opt.Password, @@ -149,13 +274,13 @@ func (opt *ClusterOptions) clientOptions() *Options { ReadTimeout: opt.ReadTimeout, WriteTimeout: opt.WriteTimeout, - PoolFIFO: opt.PoolFIFO, - PoolSize: opt.PoolSize, - MinIdleConns: opt.MinIdleConns, - MaxConnAge: opt.MaxConnAge, - PoolTimeout: opt.PoolTimeout, - IdleTimeout: opt.IdleTimeout, - IdleCheckFrequency: disableIdleCheck, + PoolFIFO: opt.PoolFIFO, + PoolSize: opt.PoolSize, + PoolTimeout: opt.PoolTimeout, + MinIdleConns: opt.MinIdleConns, + MaxIdleConns: opt.MaxIdleConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, TLSConfig: opt.TLSConfig, // If ClusterSlots is populated, then we probably have an artificial @@ -204,15 +329,26 @@ func (n *clusterNode) updateLatency() { const numProbe = 10 var dur uint64 + successes := 0 for i := 0; i < numProbe; i++ { time.Sleep(time.Duration(10+rand.Intn(10)) * time.Millisecond) start := time.Now() - n.Client.Ping(context.TODO()) - dur += uint64(time.Since(start) / time.Microsecond) + err := n.Client.Ping(context.TODO()).Err() + if err == nil { + dur += uint64(time.Since(start) / time.Microsecond) + successes++ + } } - latency := float64(dur) / float64(numProbe) + var latency float64 + if successes == 0 { + // If none of the pings worked, set latency to some arbitrarily high value so this node gets + // least priority. + latency = float64((1 * time.Minute) / time.Microsecond) + } else { + latency = float64(dur) / float64(successes) + } atomic.StoreUint32(&n.latency, uint32(latency+0.5)) } @@ -262,6 +398,7 @@ type clusterNodes struct { nodes map[string]*clusterNode activeAddrs []string closed bool + onNewNode []func(rdb *Client) _generation uint32 // atomic } @@ -297,6 +434,12 @@ func (c *clusterNodes) Close() error { return firstErr } +func (c *clusterNodes) OnNewNode(fn func(rdb *Client)) { + c.mu.Lock() + c.onNewNode = append(c.onNewNode, fn) + c.mu.Unlock() +} + func (c *clusterNodes) Addrs() ([]string, error) { var addrs []string @@ -374,6 +517,9 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { } node = newClusterNode(c.opt, addr) + for _, fn := range c.onNewNode { + fn(node.Client) + } c.addrs = appendIfNotExists(c.addrs, addr) c.nodes[addr] = node @@ -683,21 +829,16 @@ func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, er //------------------------------------------------------------------------------ -type clusterClient struct { - opt *ClusterOptions - nodes *clusterNodes - state *clusterStateHolder //nolint:structcheck - cmdsInfoCache *cmdsInfoCache //nolint:structcheck -} - // ClusterClient is a Redis Cluster client representing a pool of zero // or more underlying connections. It's safe for concurrent use by // multiple goroutines. 
type ClusterClient struct {
-	*clusterClient
+	opt           *ClusterOptions
+	nodes         *clusterNodes
+	state         *clusterStateHolder
+	cmdsInfoCache *cmdsInfoCache
 	cmdable
 	hooks
-	ctx context.Context
 }

 // NewClusterClient returns a Redis Cluster client as described in
@@ -706,38 +847,21 @@
 func NewClusterClient(opt *ClusterOptions) *ClusterClient {
 	opt.init()

 	c := &ClusterClient{
-		clusterClient: &clusterClient{
-			opt:   opt,
-			nodes: newClusterNodes(opt),
-		},
-		ctx: context.Background(),
+		opt:   opt,
+		nodes: newClusterNodes(opt),
 	}
+
 	c.state = newClusterStateHolder(c.loadState)
 	c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
 	c.cmdable = c.Process

-	if opt.IdleCheckFrequency > 0 {
-		go c.reaper(opt.IdleCheckFrequency)
-	}
+	c.hooks.setProcess(c.process)
+	c.hooks.setProcessPipeline(c.processPipeline)
+	c.hooks.setProcessTxPipeline(c.processTxPipeline)

 	return c
 }

-func (c *ClusterClient) Context() context.Context {
-	return c.ctx
-}
-
-func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
-	if ctx == nil {
-		panic("nil context")
-	}
-	clone := *c
-	clone.cmdable = clone.Process
-	clone.hooks.lock()
-	clone.ctx = ctx
-	return &clone
-}
-
 // Options returns read-only Options that were used to create the client.
 func (c *ClusterClient) Options() *ClusterOptions {
 	return c.opt
@@ -757,7 +881,7 @@ func (c *ClusterClient) Close() error {
 	return c.nodes.Close()
 }

 // Do creates a Cmd from the args and processes the cmd.
 func (c *ClusterClient) Do(ctx context.Context, args ...interface{}) *Cmd {
 	cmd := NewCmd(ctx, args...)
 	_ = c.Process(ctx, cmd)
@@ -765,13 +889,14 @@
 }

 func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error {
-	return c.hooks.process(ctx, cmd, c.process)
+	err := c.hooks.process(ctx, cmd)
+	cmd.SetErr(err)
+	return err
 }

 func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
-	cmdInfo := c.cmdInfo(cmd.Name())
-	slot := c.cmdSlot(cmd)
-
+	cmdInfo := c.cmdInfo(ctx, cmd.Name())
+	slot := c.cmdSlot(ctx, cmd)
 	var node *clusterNode
 	var ask bool
 	var lastErr error
@@ -791,12 +916,12 @@
 		}

 		if ask {
+			ask = false
+
 			pipe := node.Client.Pipeline()
 			_ = pipe.Process(ctx, NewCmd(ctx, "asking"))
 			_ = pipe.Process(ctx, cmd)
 			_, lastErr = pipe.Exec(ctx)
-			_ = pipe.Close()
-			ask = false
 		} else {
 			lastErr = node.Client.Process(ctx, cmd)
 		}
@@ -851,6 +976,10 @@
 	return lastErr
 }

+func (c *ClusterClient) OnNewNode(fn func(rdb *Client)) {
+	c.nodes.OnNewNode(fn)
+}
+
 // ForEachMaster concurrently calls the fn on each master node in the cluster.
 // It returns the first error if any.
 func (c *ClusterClient) ForEachMaster(
@@ -1056,30 +1185,9 @@ func (c *ClusterClient) loadState(ctx context.Context) (*clusterState, error) {
 	return nil, firstErr
 }

-// reaper closes idle connections to the cluster.
-func (c *ClusterClient) reaper(idleCheckFrequency time.Duration) { - ticker := time.NewTicker(idleCheckFrequency) - defer ticker.Stop() - - for range ticker.C { - nodes, err := c.nodes.All() - if err != nil { - break - } - - for _, node := range nodes { - _, err := node.Client.connPool.(*pool.ConnPool).ReapStaleConns() - if err != nil { - internal.Logger.Printf(c.Context(), "ReapStaleConns failed: %s", err) - } - } - } -} - func (c *ClusterClient) Pipeline() Pipeliner { pipe := Pipeline{ - ctx: c.ctx, - exec: c.processPipeline, + exec: pipelineExecer(c.hooks.processPipeline), } pipe.init() return &pipe @@ -1090,13 +1198,9 @@ func (c *ClusterClient) Pipelined(ctx context.Context, fn func(Pipeliner) error) } func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c._processPipeline) -} - -func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error { cmdsMap := newCmdsMap() - err := c.mapCmdsByNode(ctx, cmdsMap, cmds) - if err != nil { + + if err := c.mapCmdsByNode(ctx, cmdsMap, cmds); err != nil { setCmdsErr(cmds, err) return err } @@ -1116,18 +1220,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro wg.Add(1) go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - - err := c._processPipelineNode(ctx, node, cmds, failedCmds) - if err == nil { - return - } - if attempt < c.opt.MaxRedirects { - if err := c.mapCmdsByNode(ctx, failedCmds, cmds); err != nil { - setCmdsErr(cmds, err) - } - } else { - setCmdsErr(cmds, err) - } + c.processPipelineNode(ctx, node, cmds, failedCmds) }(node, cmds) } @@ -1147,9 +1240,9 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd return err } - if c.opt.ReadOnly && c.cmdsAreReadOnly(cmds) { + if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { for _, cmd := range cmds { - slot := c.cmdSlot(cmd) + slot := c.cmdSlot(ctx, cmd) node, err := c.slotReadOnlyNode(state, slot) if err != nil { return err @@ -1160,7 +1253,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd } for _, cmd := range cmds { - slot := c.cmdSlot(cmd) + slot := c.cmdSlot(ctx, cmd) node, err := state.slotMasterNode(slot) if err != nil { return err @@ -1170,9 +1263,9 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd return nil } -func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool { +func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool { for _, cmd := range cmds { - cmdInfo := c.cmdInfo(cmd.Name()) + cmdInfo := c.cmdInfo(ctx, cmd.Name()) if cmdInfo == nil || !cmdInfo.ReadOnly { return false } @@ -1180,22 +1273,38 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool { return true } -func (c *ClusterClient) _processPipelineNode( +func (c *ClusterClient) processPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, -) error { - return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmds(wr, cmds) - }) - if err != nil { - return err - } +) { + _ = node.Client.hooks.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + cn, err := node.Client.getConn(ctx) + if err != nil { + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + setCmdsErr(cmds, err) + return err + 
} - return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - return c.pipelineReadCmds(ctx, node, rd, cmds, failedCmds) - }) - }) + err = c.processPipelineNodeConn(ctx, node, cn, cmds, failedCmds) + node.Client.releaseConn(ctx, cn, err) + return err + }) +} + +func (c *ClusterClient) processPipelineNodeConn( + ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, +) error { + if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmds(wr, cmds) + }); err != nil { + if shouldRetry(err, true) { + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + } + setCmdsErr(cmds, err) + return err + } + + return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + return c.pipelineReadCmds(ctx, node, rd, cmds, failedCmds) }) } @@ -1206,7 +1315,7 @@ func (c *ClusterClient) pipelineReadCmds( cmds []Cmder, failedCmds *cmdsMap, ) error { - for _, cmd := range cmds { + for i, cmd := range cmds { err := cmd.readReply(rd) cmd.SetErr(err) @@ -1218,15 +1327,24 @@ func (c *ClusterClient) pipelineReadCmds( continue } - if c.opt.ReadOnly && isLoadingError(err) { + if c.opt.ReadOnly { node.MarkAsFailing() + } + + if !isRedisError(err) { + if shouldRetry(err, true) { + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + } + setCmdsErr(cmds[i+1:], err) return err } - if isRedisError(err) { - continue - } + } + + if err := cmds[0].Err(); err != nil && shouldRetry(err, true) { + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) return err } + return nil } @@ -1260,8 +1378,10 @@ func (c *ClusterClient) checkMovedErr( // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. func (c *ClusterClient) TxPipeline() Pipeliner { pipe := Pipeline{ - ctx: c.ctx, - exec: c.processTxPipeline, + exec: func(ctx context.Context, cmds []Cmder) error { + cmds = wrapMultiExec(ctx, cmds) + return c.hooks.processTxPipeline(ctx, cmds) + }, } pipe.init() return &pipe @@ -1272,10 +1392,6 @@ func (c *ClusterClient) TxPipelined(ctx context.Context, fn func(Pipeliner) erro } func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processTxPipeline(ctx, cmds, c._processTxPipeline) -} - -func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error { // Trim multi .. exec. 
cmds = cmds[1 : len(cmds)-1] @@ -1285,7 +1401,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er return err } - cmdsMap := c.mapCmdsBySlot(cmds) + cmdsMap := c.mapCmdsBySlot(ctx, cmds) for slot, cmds := range cmdsMap { node, err := state.slotMasterNode(slot) if err != nil { @@ -1309,19 +1425,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er wg.Add(1) go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - - err := c._processTxPipelineNode(ctx, node, cmds, failedCmds) - if err == nil { - return - } - - if attempt < c.opt.MaxRedirects { - if err := c.mapCmdsByNode(ctx, failedCmds, cmds); err != nil { - setCmdsErr(cmds, err) - } - } else { - setCmdsErr(cmds, err) - } + c.processTxPipelineNode(ctx, node, cmds, failedCmds) }(node, cmds) } @@ -1336,44 +1440,65 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er return cmdsFirstErr(cmds) } -func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { +func (c *ClusterClient) mapCmdsBySlot(ctx context.Context, cmds []Cmder) map[int][]Cmder { cmdsMap := make(map[int][]Cmder) for _, cmd := range cmds { - slot := c.cmdSlot(cmd) + slot := c.cmdSlot(ctx, cmd) cmdsMap[slot] = append(cmdsMap[slot], cmd) } return cmdsMap } -func (c *ClusterClient) _processTxPipelineNode( +func (c *ClusterClient) processTxPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, +) { + cmds = wrapMultiExec(ctx, cmds) + _ = node.Client.hooks.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + cn, err := node.Client.getConn(ctx) + if err != nil { + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + setCmdsErr(cmds, err) + return err + } + + err = c.processTxPipelineNodeConn(ctx, node, cn, cmds, failedCmds) + node.Client.releaseConn(ctx, cn, err) + return err + }) +} + +func (c *ClusterClient) processTxPipelineNodeConn( + ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, ) error { - return node.Client.hooks.processTxPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmds(wr, cmds) - }) - if err != nil { - return err + if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmds(wr, cmds) + }); err != nil { + if shouldRetry(err, true) { + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + } + setCmdsErr(cmds, err) + return err + } + + return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + statusCmd := cmds[0].(*StatusCmd) + // Trim multi and exec. + trimmedCmds := cmds[1 : len(cmds)-1] + + if err := c.txPipelineReadQueued( + ctx, rd, statusCmd, trimmedCmds, failedCmds, + ); err != nil { + setCmdsErr(cmds, err) + + moved, ask, addr := isMovedError(err) + if moved || ask { + return c.cmdsMoved(ctx, trimmedCmds, moved, ask, addr, failedCmds) } - return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - statusCmd := cmds[0].(*StatusCmd) - // Trim multi and exec. 
- cmds = cmds[1 : len(cmds)-1] + return err + } - err := c.txPipelineReadQueued(ctx, rd, statusCmd, cmds, failedCmds) - if err != nil { - moved, ask, addr := isMovedError(err) - if moved || ask { - return c.cmdsMoved(ctx, cmds, moved, ask, addr, failedCmds) - } - return err - } - - return pipelineReadCmds(rd, cmds) - }) - }) + return pipelineReadCmds(rd, trimmedCmds) }) } @@ -1406,12 +1531,7 @@ func (c *ClusterClient) txPipelineReadQueued( return err } - switch line[0] { - case proto.ErrorReply: - return proto.ParseErrorReply(line) - case proto.ArrayReply: - // ok - default: + if line[0] != proto.RespArray { return fmt.Errorf("redis: expected '*', but got line %q", line) } @@ -1568,6 +1688,15 @@ func (c *ClusterClient) PSubscribe(ctx context.Context, channels ...string) *Pub return pubsub } +// SSubscribe Subscribes the client to the specified shard channels. +func (c *ClusterClient) SSubscribe(ctx context.Context, channels ...string) *PubSub { + pubsub := c.pubSub() + if len(channels) > 0 { + _ = pubsub.SSubscribe(ctx, channels...) + } + return pubsub +} + func (c *ClusterClient) retryBackoff(attempt int) time.Duration { return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) } @@ -1614,26 +1743,27 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, return nil, firstErr } -func (c *ClusterClient) cmdInfo(name string) *CommandInfo { - cmdsInfo, err := c.cmdsInfoCache.Get(c.ctx) +func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { + cmdsInfo, err := c.cmdsInfoCache.Get(ctx) if err != nil { + internal.Logger.Printf(context.TODO(), "getting command info: %s", err) return nil } info := cmdsInfo[name] if info == nil { - internal.Logger.Printf(c.Context(), "info for cmd=%s not found", name) + internal.Logger.Printf(context.TODO(), "info for cmd=%s not found", name) } return info } -func (c *ClusterClient) cmdSlot(cmd Cmder) int { +func (c *ClusterClient) cmdSlot(ctx context.Context, cmd Cmder) int { args := cmd.Args() if args[0] == "cluster" && args[1] == "getkeysinslot" { return args[2].(int) } - cmdInfo := c.cmdInfo(cmd.Name()) + cmdInfo := c.cmdInfo(ctx, cmd.Name()) return cmdSlot(cmd, cmdFirstKeyPos(cmd, cmdInfo)) } @@ -1661,7 +1791,7 @@ func (c *ClusterClient) cmdNode( return state.slotMasterNode(slot) } -func (c *clusterClient) slotReadOnlyNode(state *clusterState, slot int) (*clusterNode, error) { +func (c *ClusterClient) slotReadOnlyNode(state *clusterState, slot int) (*clusterNode, error) { if c.opt.RouteByLatency { return state.slotClosestNode(slot) } @@ -1708,6 +1838,13 @@ func (c *ClusterClient) MasterForKey(ctx context.Context, key string) (*Client, return node.Client, err } +func (c *ClusterClient) context(ctx context.Context) context.Context { + if c.opt.ContextTimeoutEnabled { + return ctx + } + return context.Background() +} + func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { for _, n := range nodes { if n == node { diff --git a/cluster_commands.go b/cluster_commands.go index 085bce83..fc0a9cd4 100644 --- a/cluster_commands.go +++ b/cluster_commands.go @@ -8,7 +8,7 @@ import ( func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd { cmd := NewIntCmd(ctx, "dbsize") - _ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { + _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { var size int64 err := c.ForEachMaster(ctx, func(ctx context.Context, master *Client) error { n, err := master.DBSize(ctx).Result() @@ -30,7 
+30,7 @@ func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd { func (c *ClusterClient) ScriptLoad(ctx context.Context, script string) *StringCmd { cmd := NewStringCmd(ctx, "script", "load", script) - _ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { + _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { mu := &sync.Mutex{} err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { val, err := shard.ScriptLoad(ctx, script).Result() @@ -56,7 +56,7 @@ func (c *ClusterClient) ScriptLoad(ctx context.Context, script string) *StringCm func (c *ClusterClient) ScriptFlush(ctx context.Context) *StatusCmd { cmd := NewStatusCmd(ctx, "script", "flush") - _ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { + _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { return shard.ScriptFlush(ctx).Err() }) @@ -82,8 +82,8 @@ func (c *ClusterClient) ScriptExists(ctx context.Context, hashes ...string) *Boo result[i] = true } - _ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { - mu := &sync.Mutex{} + _ = c.hooks.withProcessHook(ctx, cmd, func(ctx context.Context, _ Cmder) error { + var mu sync.Mutex err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { val, err := shard.ScriptExists(ctx, hashes...).Result() if err != nil { diff --git a/cluster_test.go b/cluster_test.go index 6ee7364e..2827d3fc 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -2,18 +2,22 @@ package redis_test import ( "context" + "crypto/tls" + "errors" "fmt" "net" "strconv" "strings" "sync" + "testing" "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + "github.com/stretchr/testify/assert" - "github.com/go-redis/redis/v8" - "github.com/go-redis/redis/v8/internal/hashtag" + "github.com/go-redis/redis/v9" + "github.com/go-redis/redis/v9/internal/hashtag" ) type clusterScenario struct { @@ -82,8 +86,10 @@ func (s *clusterScenario) newClusterClient( func (s *clusterScenario) Close() error { for _, port := range s.ports { - processes[port].Close() - delete(processes, port) + if process, ok := processes[port]; ok { + process.Close() + delete(processes, port) + } } return nil } @@ -237,14 +243,6 @@ var _ = Describe("ClusterClient", func() { var client *redis.ClusterClient assertClusterClient := func() { - It("supports WithContext", func() { - ctx, cancel := context.WithCancel(ctx) - cancel() - - err := client.Ping(ctx).Err() - Expect(err).To(MatchError("context canceled")) - }) - It("should GET/SET/DEL", func() { err := client.Get(ctx, "A").Err() Expect(err).To(Equal(redis.Nil)) @@ -515,9 +513,7 @@ var _ = Describe("ClusterClient", func() { pipe = client.Pipeline().(*redis.Pipeline) }) - AfterEach(func() { - Expect(pipe.Close()).NotTo(HaveOccurred()) - }) + AfterEach(func() {}) assertPipeline() }) @@ -527,9 +523,7 @@ var _ = Describe("ClusterClient", func() { pipe = client.TxPipeline().(*redis.Pipeline) }) - AfterEach(func() { - Expect(pipe.Close()).NotTo(HaveOccurred()) - }) + AfterEach(func() {}) assertPipeline() }) @@ -559,6 +553,30 @@ var _ = Describe("ClusterClient", func() { }, 30*time.Second).ShouldNot(HaveOccurred()) }) + It("supports sharded PubSub", func() { + pubsub := client.SSubscribe(ctx, "mychannel") + defer pubsub.Close() + + Eventually(func() error { + _, err := client.SPublish(ctx, "mychannel", "hello").Result() + if err != nil { + return err + } + + msg, err := 
pubsub.ReceiveTimeout(ctx, time.Second) + if err != nil { + return err + } + + _, ok := msg.(*redis.Message) + if !ok { + return fmt.Errorf("got %T, wanted *redis.Message", msg) + } + + return nil + }, 30*time.Second).ShouldNot(HaveOccurred()) + }) + It("supports PubSub.Ping without channels", func() { pubsub := client.Subscribe(ctx) defer pubsub.Close() @@ -571,6 +589,7 @@ var _ = Describe("ClusterClient", func() { Describe("ClusterClient", func() { BeforeEach(func() { opt = redisClusterOptions() + opt.ClientName = "cluster_hi" client = cluster.newClusterClient(ctx, opt) err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { @@ -661,6 +680,20 @@ var _ = Describe("ClusterClient", func() { Expect(assertSlotsEqual(res, wanted)).NotTo(HaveOccurred()) }) + It("should cluster client setname", func() { + err := client.ForEachShard(ctx, func(ctx context.Context, c *redis.Client) error { + return c.Ping(ctx).Err() + }) + Expect(err).NotTo(HaveOccurred()) + + _ = client.ForEachShard(ctx, func(ctx context.Context, c *redis.Client) error { + val, err := c.ClientList(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(ContainSubstring("name=cluster_hi")) + return nil + }) + }) + It("should CLUSTER NODES", func() { res, err := client.ClusterNodes(ctx).Result() Expect(err).NotTo(HaveOccurred()) @@ -737,6 +770,9 @@ var _ = Describe("ClusterClient", func() { }) It("supports Process hook", func() { + testCtx, cancel := context.WithCancel(ctx) + defer cancel() + err := client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) @@ -748,29 +784,47 @@ var _ = Describe("ClusterClient", func() { var stack []string clusterHook := &hook{ - beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { - Expect(cmd.String()).To(Equal("ping: ")) - stack = append(stack, "cluster.BeforeProcess") - return ctx, nil - }, - afterProcess: func(ctx context.Context, cmd redis.Cmder) error { - Expect(cmd.String()).To(Equal("ping: PONG")) - stack = append(stack, "cluster.AfterProcess") - return nil + processHook: func(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + select { + case <-testCtx.Done(): + return hook(ctx, cmd) + default: + } + + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcess") + + err := hook(ctx, cmd) + + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcess") + + return err + } }, } client.AddHook(clusterHook) nodeHook := &hook{ - beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { - Expect(cmd.String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcess") - return ctx, nil - }, - afterProcess: func(ctx context.Context, cmd redis.Cmder) error { - Expect(cmd.String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcess") - return nil + processHook: func(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + select { + case <-testCtx.Done(): + return hook(ctx, cmd) + default: + } + + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcess") + + err := hook(ctx, cmd) + + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcess") + + return err + } }, } @@ -787,11 +841,6 @@ var _ = Describe("ClusterClient", func() { "shard.AfterProcess", "cluster.AfterProcess", })) - - clusterHook.beforeProcess = nil - clusterHook.afterProcess = nil - 
nodeHook.beforeProcess = nil - nodeHook.afterProcess = nil }) It("supports Pipeline hook", func() { @@ -806,33 +855,39 @@ var _ = Describe("ClusterClient", func() { var stack []string client.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "cluster.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = append(stack, "cluster.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcessPipeline") + + return err + } }, }) _ = client.ForEachShard(ctx, func(ctx context.Context, node *redis.Client) error { node.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + + return err + } }, }) return nil @@ -863,33 +918,39 @@ var _ = Describe("ClusterClient", func() { var stack []string client.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: ")) - stack = append(stack, "cluster.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: PONG")) - stack = append(stack, "cluster.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcessPipeline") + + return err + } }, }) _ = client.ForEachShard(ctx, func(ctx context.Context, node *redis.Client) error { node.AddHook(&hook{ - beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(3)) - 
Expect(cmds[1].String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcessPipeline") - return ctx, nil - }, - afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(3)) - Expect(cmds[1].String()).To(Equal("ping: PONG")) - stack = append(stack, "shard.AfterProcessPipeline") - return nil + processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + + err := hook(ctx, cmds) + + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + + return err + } }, }) return nil @@ -1182,16 +1243,17 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() { var client *redis.ClusterClient BeforeEach(func() { - for _, node := range cluster.clients { - err := node.ClientPause(ctx, 5*time.Second).Err() - Expect(err).NotTo(HaveOccurred()) - } - opt := redisClusterOptions() opt.ReadTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond opt.MaxRedirects = 1 client = cluster.newClusterClientUnstable(opt) + Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred()) + + for _, node := range cluster.clients { + err := node.ClientPause(ctx, 5*time.Second).Err() + Expect(err).NotTo(HaveOccurred()) + } }) AfterEach(func() { @@ -1257,27 +1319,182 @@ var _ = Describe("ClusterClient timeout", func() { Context("read/write timeout", func() { BeforeEach(func() { opt := redisClusterOptions() - opt.ReadTimeout = 250 * time.Millisecond - opt.WriteTimeout = 250 * time.Millisecond - opt.MaxRedirects = 1 client = cluster.newClusterClient(ctx, opt) err := client.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error { - return client.ClientPause(ctx, pause).Err() + err := client.ClientPause(ctx, pause).Err() + + opt := client.Options() + opt.ReadTimeout = time.Nanosecond + opt.WriteTimeout = time.Nanosecond + + return err }) Expect(err).NotTo(HaveOccurred()) + + // Overwrite timeouts after the client is initialized. 
+ opt.ReadTimeout = time.Nanosecond + opt.WriteTimeout = time.Nanosecond + opt.MaxRedirects = 0 }) AfterEach(func() { _ = client.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error { defer GinkgoRecover() + + opt := client.Options() + opt.ReadTimeout = time.Second + opt.WriteTimeout = time.Second + Eventually(func() error { return client.Ping(ctx).Err() }, 2*pause).ShouldNot(HaveOccurred()) return nil }) + + err := client.Close() + Expect(err).NotTo(HaveOccurred()) }) testTimeout() }) }) + +func TestParseClusterURL(t *testing.T) { + cases := []struct { + test string + url string + o *redis.ClusterOptions // expected value + err error + }{ + { + test: "ParseRedisURL", + url: "redis://localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}}, + }, { + test: "ParseRedissURL", + url: "rediss://localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "MissingRedisPort", + url: "redis://localhost", + o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}}, + }, { + test: "MissingRedissPort", + url: "rediss://localhost", + o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "MultipleRedisURLs", + url: "redis://localhost:123?addr=localhost:1234&addr=localhost:12345", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}}, + }, { + test: "MultipleRedissURLs", + url: "rediss://localhost:123?addr=localhost:1234&addr=localhost:12345", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "OnlyPassword", + url: "redis://:bar@localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Password: "bar"}, + }, { + test: "OnlyUser", + url: "redis://foo@localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Username: "foo"}, + }, { + test: "RedisUsernamePassword", + url: "redis://foo:bar@localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Username: "foo", Password: "bar"}, + }, { + test: "RedissUsernamePassword", + url: "rediss://foo:bar@localhost:123?addr=localhost:1234", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, Username: "foo", Password: "bar", TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "QueryParameters", + url: "redis://localhost:123?read_timeout=2&pool_fifo=true&addr=localhost:1234", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, ReadTimeout: 2 * time.Second, PoolFIFO: true}, + }, { + test: "DisabledTimeout", + url: "redis://localhost:123?conn_max_idle_time=0", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: -1}, + }, { + test: "DisabledTimeoutNeg", + url: "redis://localhost:123?conn_max_idle_time=-1", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: -1}, + }, { + test: "UseDefault", + url: "redis://localhost:123?conn_max_idle_time=", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: 0}, + }, { + test: "ClientName", + url: "redis://localhost:123?client_name=cluster_hi", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ClientName: "cluster_hi"}, + }, { + test: "UseDefaultMissing=", + url: "redis://localhost:123?conn_max_idle_time", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: 
0}, + }, { + test: "InvalidQueryAddr", + url: "rediss://foo:bar@localhost:123?addr=rediss://foo:barr@localhost:1234", + err: errors.New(`redis: unable to parse addr param: rediss://foo:barr@localhost:1234`), + }, { + test: "InvalidInt", + url: "redis://localhost?pool_size=five", + err: errors.New(`redis: invalid pool_size number: strconv.Atoi: parsing "five": invalid syntax`), + }, { + test: "InvalidBool", + url: "redis://localhost?pool_fifo=yes", + err: errors.New(`redis: invalid pool_fifo boolean: expected true/false/1/0 or an empty string, got "yes"`), + }, { + test: "UnknownParam", + url: "redis://localhost?abc=123", + err: errors.New("redis: unexpected option: abc"), + }, { + test: "InvalidScheme", + url: "https://google.com", + err: errors.New("redis: invalid URL scheme: https"), + }, + } + + for i := range cases { + tc := cases[i] + t.Run(tc.test, func(t *testing.T) { + t.Parallel() + + actual, err := redis.ParseClusterURL(tc.url) + if tc.err == nil && err != nil { + t.Fatalf("unexpected error: %q", err) + return + } + if tc.err != nil && err == nil { + t.Fatalf("expected error: got %+v", actual) + return + } + if tc.err != nil && err != nil { + if tc.err.Error() != err.Error() { + t.Fatalf("got %q, expected %q", err, tc.err) + } + return + } + compareOptions(t, actual, tc.o) + }) + } +} + +func compareOptions(t *testing.T, actual, expected *redis.ClusterOptions) { + t.Helper() + assert.Equal(t, expected.Addrs, actual.Addrs) + assert.Equal(t, expected.TLSConfig, actual.TLSConfig) + assert.Equal(t, expected.Username, actual.Username) + assert.Equal(t, expected.Password, actual.Password) + assert.Equal(t, expected.MaxRetries, actual.MaxRetries) + assert.Equal(t, expected.MinRetryBackoff, actual.MinRetryBackoff) + assert.Equal(t, expected.MaxRetryBackoff, actual.MaxRetryBackoff) + assert.Equal(t, expected.DialTimeout, actual.DialTimeout) + assert.Equal(t, expected.ReadTimeout, actual.ReadTimeout) + assert.Equal(t, expected.WriteTimeout, actual.WriteTimeout) + assert.Equal(t, expected.PoolFIFO, actual.PoolFIFO) + assert.Equal(t, expected.PoolSize, actual.PoolSize) + assert.Equal(t, expected.MinIdleConns, actual.MinIdleConns) + assert.Equal(t, expected.ConnMaxLifetime, actual.ConnMaxLifetime) + assert.Equal(t, expected.ConnMaxIdleTime, actual.ConnMaxIdleTime) + assert.Equal(t, expected.PoolTimeout, actual.PoolTimeout) +} diff --git a/command.go b/command.go index 0079b596..8c77c24d 100644 --- a/command.go +++ b/command.go @@ -7,10 +7,10 @@ import ( "strconv" "time" - "github.com/go-redis/redis/v8/internal" - "github.com/go-redis/redis/v8/internal/hscan" - "github.com/go-redis/redis/v8/internal/proto" - "github.com/go-redis/redis/v8/internal/util" + "github.com/go-redis/redis/v9/internal" + "github.com/go-redis/redis/v9/internal/hscan" + "github.com/go-redis/redis/v9/internal/proto" + "github.com/go-redis/redis/v9/internal/util" ) type Cmder interface { @@ -20,7 +20,7 @@ type Cmder interface { String() string stringArg(int) string firstKeyPos() int8 - setFirstKeyPos(int8) + SetFirstKeyPos(int8) readTimeout() *time.Duration readReply(rd *proto.Reader) error @@ -65,7 +65,7 @@ func cmdFirstKeyPos(cmd Cmder, info *CommandInfo) int { } switch cmd.Name() { - case "eval", "evalsha": + case "eval", "evalsha", "eval_ro", "evalsha_ro": if cmd.stringArg(2) != "0" { return 3 } @@ -83,7 +83,7 @@ func cmdFirstKeyPos(cmd Cmder, info *CommandInfo) int { if info != nil { return int(info.FirstKeyPos) } - return 0 + return 1 } func cmdString(cmd Cmder, val interface{}) string { @@ -104,7 +104,7 @@ func
cmdString(cmd Cmder, val interface{}) string { b = internal.AppendArg(b, val) } - return internal.String(b) + return util.BytesToString(b) } //------------------------------------------------------------------------------ @@ -151,15 +151,21 @@ func (cmd *baseCmd) stringArg(pos int) string { if pos < 0 || pos >= len(cmd.args) { return "" } - s, _ := cmd.args[pos].(string) - return s + arg := cmd.args[pos] + switch v := arg.(type) { + case string: + return v + default: + // TODO: consider using appendArg + return fmt.Sprint(v) + } } func (cmd *baseCmd) firstKeyPos() int8 { return cmd.keyPos } -func (cmd *baseCmd) setFirstKeyPos(keyPos int8) { +func (cmd *baseCmd) SetFirstKeyPos(keyPos int8) { cmd.keyPos = keyPos } @@ -458,31 +464,10 @@ func (cmd *Cmd) BoolSlice() ([]bool, error) { } func (cmd *Cmd) readReply(rd *proto.Reader) (err error) { - cmd.val, err = rd.ReadReply(sliceParser) + cmd.val, err = rd.ReadReply() return err } -// sliceParser implements proto.MultiBulkParse. -func sliceParser(rd *proto.Reader, n int64) (interface{}, error) { - vals := make([]interface{}, n) - for i := 0; i < len(vals); i++ { - v, err := rd.ReadReply(sliceParser) - if err != nil { - if err == Nil { - vals[i] = nil - continue - } - if err, ok := err.(proto.RedisError); ok { - vals[i] = err - continue - } - return nil, err - } - vals[i] = v - } - return vals, nil -} - //------------------------------------------------------------------------------ type SliceCmd struct { @@ -538,13 +523,9 @@ func (cmd *SliceCmd) Scan(dst interface{}) error { return hscan.Scan(dst, args, cmd.val) } -func (cmd *SliceCmd) readReply(rd *proto.Reader) error { - v, err := rd.ReadArrayReply(sliceParser) - if err != nil { - return err - } - cmd.val = v.([]interface{}) - return nil +func (cmd *SliceCmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadSlice() + return err } //------------------------------------------------------------------------------ @@ -627,7 +608,7 @@ func (cmd *IntCmd) String() string { } func (cmd *IntCmd) readReply(rd *proto.Reader) (err error) { - cmd.val, err = rd.ReadIntReply() + cmd.val, err = rd.ReadInt() return err } @@ -667,18 +648,17 @@ func (cmd *IntSliceCmd) String() string { } func (cmd *IntSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]int64, n) - for i := 0; i < len(cmd.val); i++ { - num, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.val[i] = num + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]int64, n) + for i := 0; i < len(cmd.val); i++ { + if cmd.val[i], err = rd.ReadInt(); err != nil { + return err } - return nil, nil - }) - return err + } + return nil } //------------------------------------------------------------------------------ @@ -719,7 +699,7 @@ func (cmd *DurationCmd) String() string { } func (cmd *DurationCmd) readReply(rd *proto.Reader) error { - n, err := rd.ReadIntReply() + n, err := rd.ReadInt() if err != nil { return err } @@ -770,25 +750,19 @@ func (cmd *TimeCmd) String() string { } func (cmd *TimeCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 2 { - return nil, fmt.Errorf("got %d elements, expected 2", n) - } - - sec, err := rd.ReadInt() - if err != nil { - return nil, err - } - - microsec, err := rd.ReadInt() - if err != nil { - return nil, err - } - - cmd.val = time.Unix(sec, microsec*1000) - return nil, nil - }) - 
return err + if err := rd.ReadFixedArrayLen(2); err != nil { + return err + } + second, err := rd.ReadInt() + if err != nil { + return err + } + microsecond, err := rd.ReadInt() + if err != nil { + return err + } + cmd.val = time.Unix(second, microsecond*1000) + return nil } //------------------------------------------------------------------------------ @@ -826,27 +800,16 @@ func (cmd *BoolCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *BoolCmd) readReply(rd *proto.Reader) error { - v, err := rd.ReadReply(nil) +func (cmd *BoolCmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadBool() + // `SET key value NX` returns nil when key already exists. But // `SETNX key value` returns bool (0/1). So convert nil to bool. if err == Nil { cmd.val = false - return nil - } - if err != nil { - return err - } - switch v := v.(type) { - case int64: - cmd.val = v == 1 - return nil - case string: - cmd.val = v == "OK" - return nil - default: - return fmt.Errorf("got %T, wanted int64 or string", v) + err = nil } + return err } //------------------------------------------------------------------------------ @@ -989,7 +952,7 @@ func (cmd *FloatCmd) String() string { } func (cmd *FloatCmd) readReply(rd *proto.Reader) (err error) { - cmd.val, err = rd.ReadFloatReply() + cmd.val, err = rd.ReadFloat() return err } @@ -1029,21 +992,23 @@ func (cmd *FloatSliceCmd) String() string { } func (cmd *FloatSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]float64, n) - for i := 0; i < len(cmd.val); i++ { - switch num, err := rd.ReadFloatReply(); { - case err == Nil: - cmd.val[i] = 0 - case err != nil: - return nil, err - default: - cmd.val[i] = num - } + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + cmd.val = make([]float64, n) + for i := 0; i < len(cmd.val); i++ { + switch num, err := rd.ReadFloat(); { + case err == Nil: + cmd.val[i] = 0 + case err != nil: + return err + default: + cmd.val[i] = num } - return nil, nil - }) - return err + } + return nil } //------------------------------------------------------------------------------ @@ -1086,21 +1051,116 @@ func (cmd *StringSliceCmd) ScanSlice(container interface{}) error { } func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]string, n) - for i := 0; i < len(cmd.val); i++ { - switch s, err := rd.ReadString(); { - case err == Nil: - cmd.val[i] = "" - case err != nil: - return nil, err - default: - cmd.val[i] = s + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]string, n) + for i := 0; i < len(cmd.val); i++ { + switch s, err := rd.ReadString(); { + case err == Nil: + cmd.val[i] = "" + case err != nil: + return err + default: + cmd.val[i] = s + } + } + return nil +} + +//------------------------------------------------------------------------------ + +type KeyValue struct { + Key string + Value string +} + +type KeyValueSliceCmd struct { + baseCmd + + val []KeyValue +} + +var _ Cmder = (*KeyValueSliceCmd)(nil) + +func NewKeyValueSliceCmd(ctx context.Context, args ...interface{}) *KeyValueSliceCmd { + return &KeyValueSliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *KeyValueSliceCmd) SetVal(val []KeyValue) { + cmd.val = val +} + +func (cmd *KeyValueSliceCmd) Val() []KeyValue { + return cmd.val +} + +func (cmd *KeyValueSliceCmd) Result() 
([]KeyValue, error) { + return cmd.val, cmd.err +} + +func (cmd *KeyValueSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +// Many commands reply in one of two formats. RESP3 nests each +// key/value pair in its own two-element array: +// 1) 1) "one" +//    2) (double) 1 +// 2) 1) "two" +//    2) (double) 2 +// +// RESP2 returns a flat list of alternating keys and values: +// 1) "two" +// 2) (double) 2 +// 3) "one" +// 4) (double) 1 +func (cmd *KeyValueSliceCmd) readReply(rd *proto.Reader) error { // nolint:dupl + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + // If n is 0, there is nothing left to read. + if n == 0 { + cmd.val = make([]KeyValue, 0) + return nil + } + + typ, err := rd.PeekReplyType() + if err != nil { + return err + } + array := typ == proto.RespArray + + if array { + cmd.val = make([]KeyValue, n) + } else { + cmd.val = make([]KeyValue, n/2) + } + + for i := 0; i < len(cmd.val); i++ { + if array { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err } } - return nil, nil - }) - return err + + if cmd.val[i].Key, err = rd.ReadString(); err != nil { + return err + } + + if cmd.val[i].Value, err = rd.ReadString(); err != nil { + return err + } + } + + return nil } //------------------------------------------------------------------------------ @@ -1139,32 +1199,31 @@ func (cmd *BoolSliceCmd) String() string { } func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]bool, n) - for i := 0; i < len(cmd.val); i++ { - n, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - cmd.val[i] = n == 1 + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]bool, n) + for i := 0; i < len(cmd.val); i++ { + if cmd.val[i], err = rd.ReadBool(); err != nil { + return err } - return nil, nil - }) - return err + } + return nil } //------------------------------------------------------------------------------ -type StringStringMapCmd struct { +type MapStringStringCmd struct { baseCmd val map[string]string } -var _ Cmder = (*StringStringMapCmd)(nil) +var _ Cmder = (*MapStringStringCmd)(nil) -func NewStringStringMapCmd(ctx context.Context, args ...interface{}) *StringStringMapCmd { - return &StringStringMapCmd{ +func NewMapStringStringCmd(ctx context.Context, args ...interface{}) *MapStringStringCmd { + return &MapStringStringCmd{ baseCmd: baseCmd{ ctx: ctx, args: args, @@ -1172,25 +1231,25 @@ } } -func (cmd *StringStringMapCmd) SetVal(val map[string]string) { - cmd.val = val -} - -func (cmd *StringStringMapCmd) Val() map[string]string { +func (cmd *MapStringStringCmd) Val() map[string]string { return cmd.val } -func (cmd *StringStringMapCmd) Result() (map[string]string, error) { +func (cmd *MapStringStringCmd) SetVal(val map[string]string) { + cmd.val = val +} + +func (cmd *MapStringStringCmd) Result() (map[string]string, error) { return cmd.val, cmd.err } -func (cmd *StringStringMapCmd) String() string { +func (cmd *MapStringStringCmd) String() string { return cmdString(cmd, cmd.val) } // Scan scans the results from the map into a destination struct. The map keys // are matched in the Redis struct fields by the `redis:"field"` tag.
-func (cmd *StringStringMapCmd) Scan(dest interface{}) error { +func (cmd *MapStringStringCmd) Scan(dest interface{}) error { if cmd.err != nil { return cmd.err } @@ -1209,39 +1268,41 @@ func (cmd *StringStringMapCmd) Scan(dest interface{}) error { return nil } -func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make(map[string]string, n/2) - for i := int64(0); i < n; i += 2 { - key, err := rd.ReadString() - if err != nil { - return nil, err - } +func (cmd *MapStringStringCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadMapLen() + if err != nil { + return err + } - value, err := rd.ReadString() - if err != nil { - return nil, err - } - - cmd.val[key] = value + cmd.val = make(map[string]string, n) + for i := 0; i < n; i++ { + key, err := rd.ReadString() + if err != nil { + return err } - return nil, nil - }) - return err + + value, err := rd.ReadString() + if err != nil { + return err + } + + cmd.val[key] = value + } + return nil } //------------------------------------------------------------------------------ -type StringIntMapCmd struct { +type MapStringIntCmd struct { baseCmd val map[string]int64 } -var _ Cmder = (*StringIntMapCmd)(nil) +var _ Cmder = (*MapStringIntCmd)(nil) -func NewStringIntMapCmd(ctx context.Context, args ...interface{}) *StringIntMapCmd { - return &StringIntMapCmd{ +func NewMapStringIntCmd(ctx context.Context, args ...interface{}) *MapStringIntCmd { + return &MapStringIntCmd{ baseCmd: baseCmd{ ctx: ctx, args: args, @@ -1249,41 +1310,42 @@ func NewStringIntMapCmd(ctx context.Context, args ...interface{}) *StringIntMapC } } -func (cmd *StringIntMapCmd) SetVal(val map[string]int64) { +func (cmd *MapStringIntCmd) SetVal(val map[string]int64) { cmd.val = val } -func (cmd *StringIntMapCmd) Val() map[string]int64 { +func (cmd *MapStringIntCmd) Val() map[string]int64 { return cmd.val } -func (cmd *StringIntMapCmd) Result() (map[string]int64, error) { +func (cmd *MapStringIntCmd) Result() (map[string]int64, error) { return cmd.val, cmd.err } -func (cmd *StringIntMapCmd) String() string { +func (cmd *MapStringIntCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringIntMapCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make(map[string]int64, n/2) - for i := int64(0); i < n; i += 2 { - key, err := rd.ReadString() - if err != nil { - return nil, err - } +func (cmd *MapStringIntCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadMapLen() + if err != nil { + return err + } - n, err := rd.ReadIntReply() - if err != nil { - return nil, err - } - - cmd.val[key] = n + cmd.val = make(map[string]int64, n) + for i := 0; i < n; i++ { + key, err := rd.ReadString() + if err != nil { + return err } - return nil, nil - }) - return err + + nn, err := rd.ReadInt() + if err != nil { + return err + } + cmd.val[key] = nn + } + return nil } //------------------------------------------------------------------------------ @@ -1322,18 +1384,20 @@ func (cmd *StringStructMapCmd) String() string { } func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make(map[string]struct{}, n) - for i := int64(0); i < n; i++ { - key, err := rd.ReadString() - if err != nil { - return nil, err - } - cmd.val[key] = struct{}{} + n, err := rd.ReadArrayLen() + if err != nil { + 
return err + } + + cmd.val = make(map[string]struct{}, n) + for i := 0; i < n; i++ { + key, err := rd.ReadString() + if err != nil { + return err } - return nil, nil - }) - return err + cmd.val[key] = struct{}{} + } + return nil } //------------------------------------------------------------------------------ @@ -1376,8 +1440,7 @@ func (cmd *XMessageSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) error { - var err error +func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) (err error) { cmd.val, err = readXMessageSlice(rd) return err } @@ -1389,10 +1452,8 @@ func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) { } msgs := make([]XMessage, n) - for i := 0; i < n; i++ { - var err error - msgs[i], err = readXMessage(rd) - if err != nil { + for i := 0; i < len(msgs); i++ { + if msgs[i], err = readXMessage(rd); err != nil { return nil, err } } @@ -1400,40 +1461,36 @@ func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) { } func readXMessage(rd *proto.Reader) (XMessage, error) { - n, err := rd.ReadArrayLen() - if err != nil { + if err := rd.ReadFixedArrayLen(2); err != nil { return XMessage{}, err } - if n != 2 { - return XMessage{}, fmt.Errorf("got %d, wanted 2", n) - } id, err := rd.ReadString() if err != nil { return XMessage{}, err } - var values map[string]interface{} - - v, err := rd.ReadArrayReply(stringInterfaceMapParser) + v, err := stringInterfaceMapParser(rd) if err != nil { if err != proto.Nil { return XMessage{}, err } - } else { - values = v.(map[string]interface{}) } return XMessage{ ID: id, - Values: values, + Values: v, }, nil } -// stringInterfaceMapParser implements proto.MultiBulkParse. -func stringInterfaceMapParser(rd *proto.Reader, n int64) (interface{}, error) { - m := make(map[string]interface{}, n/2) - for i := int64(0); i < n; i += 2 { +func stringInterfaceMapParser(rd *proto.Reader) (map[string]interface{}, error) { + n, err := rd.ReadMapLen() + if err != nil { + return nil, err + } + + m := make(map[string]interface{}, n) + for i := 0; i < n; i++ { key, err := rd.ReadString() if err != nil { return nil, err @@ -1490,38 +1547,35 @@ func (cmd *XStreamSliceCmd) String() string { } func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]XStream, n) - for i := 0; i < len(cmd.val); i++ { - i := i - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 2 { - return nil, fmt.Errorf("got %d, wanted 2", n) - } + typ, err := rd.PeekReplyType() + if err != nil { + return err + } - stream, err := rd.ReadString() - if err != nil { - return nil, err - } - - msgs, err := readXMessageSlice(rd) - if err != nil { - return nil, err - } - - cmd.val[i] = XStream{ - Stream: stream, - Messages: msgs, - } - return nil, nil - }) - if err != nil { - return nil, err + var n int + if typ == proto.RespMap { + n, err = rd.ReadMapLen() + } else { + n, err = rd.ReadArrayLen() + } + if err != nil { + return err + } + cmd.val = make([]XStream, n) + for i := 0; i < len(cmd.val); i++ { + if typ != proto.RespMap { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err } } - return nil, nil - }) - return err + if cmd.val[i].Stream, err = rd.ReadString(); err != nil { + return err + } + if cmd.val[i].Messages, err = readXMessageSlice(rd); err != nil { + return err + } + } + return nil } 
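The readReply rewrites above all follow the same RESP3-aware shape: read the outer length, peek at the reply type, then branch on whether the server nested each pair in its own sub-array (RESP3) or flattened everything into one alternating list (RESP2). Below is a minimal sketch of that shape, assuming the proto.Reader methods introduced in this diff (ReadArrayLen, PeekReplyType, ReadFixedArrayLen, ReadString, ReadFloat); scorePair is a hypothetical local type, and since the proto package is internal, the sketch is illustrative only:

```go
package example

import "github.com/go-redis/redis/v9/internal/proto"

// scorePair is a hypothetical type for this sketch.
type scorePair struct {
	member string
	score  float64
}

// readPairs mirrors the flat-vs-nested branching used by
// KeyValueSliceCmd and ZSliceCmd in this diff.
func readPairs(rd *proto.Reader) ([]scorePair, error) {
	n, err := rd.ReadArrayLen()
	if err != nil {
		return nil, err
	}
	if n == 0 {
		return []scorePair{}, nil
	}

	typ, err := rd.PeekReplyType()
	if err != nil {
		return nil, err
	}
	nested := typ == proto.RespArray // RESP3 wraps each pair in a sub-array

	size := n
	if !nested {
		size = n / 2 // RESP2 interleaves members and scores
	}
	pairs := make([]scorePair, size)

	for i := range pairs {
		if nested {
			if err := rd.ReadFixedArrayLen(2); err != nil {
				return nil, err
			}
		}
		if pairs[i].member, err = rd.ReadString(); err != nil {
			return nil, err
		}
		if pairs[i].score, err = rd.ReadFloat(); err != nil {
			return nil, err
		}
	}
	return pairs, nil
}
```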
//------------------------------------------------------------------------------ @@ -1566,68 +1620,45 @@ func (cmd *XPendingCmd) String() string { } func (cmd *XPendingCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 4 { - return nil, fmt.Errorf("got %d, wanted 4", n) + var err error + if err = rd.ReadFixedArrayLen(4); err != nil { + return err + } + cmd.val = &XPending{} + + if cmd.val.Count, err = rd.ReadInt(); err != nil { + return err + } + + if cmd.val.Lower, err = rd.ReadString(); err != nil && err != Nil { + return err + } + + if cmd.val.Higher, err = rd.ReadString(); err != nil && err != Nil { + return err + } + + n, err := rd.ReadArrayLen() + if err != nil && err != Nil { + return err + } + cmd.val.Consumers = make(map[string]int64, n) + for i := 0; i < n; i++ { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err } - count, err := rd.ReadIntReply() + consumerName, err := rd.ReadString() if err != nil { - return nil, err + return err } - - lower, err := rd.ReadString() - if err != nil && err != Nil { - return nil, err + consumerPending, err := rd.ReadInt() + if err != nil { + return err } - - higher, err := rd.ReadString() - if err != nil && err != Nil { - return nil, err - } - - cmd.val = &XPending{ - Count: count, - Lower: lower, - Higher: higher, - } - _, err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - for i := int64(0); i < n; i++ { - _, err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 2 { - return nil, fmt.Errorf("got %d, wanted 2", n) - } - - consumerName, err := rd.ReadString() - if err != nil { - return nil, err - } - - consumerPending, err := rd.ReadInt() - if err != nil { - return nil, err - } - - if cmd.val.Consumers == nil { - cmd.val.Consumers = make(map[string]int64) - } - cmd.val.Consumers[consumerName] = consumerPending - - return nil, nil - }) - if err != nil { - return nil, err - } - } - return nil, nil - }) - if err != nil && err != Nil { - return nil, err - } - - return nil, nil - }) - return err + cmd.val.Consumers[consumerName] = consumerPending + } + return nil } //------------------------------------------------------------------------------ @@ -1672,49 +1703,37 @@ func (cmd *XPendingExtCmd) String() string { } func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]XPendingExt, 0, n) - for i := int64(0); i < n; i++ { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 4 { - return nil, fmt.Errorf("got %d, wanted 4", n) - } + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]XPendingExt, n) - id, err := rd.ReadString() - if err != nil { - return nil, err - } - - consumer, err := rd.ReadString() - if err != nil && err != Nil { - return nil, err - } - - idle, err := rd.ReadIntReply() - if err != nil && err != Nil { - return nil, err - } - - retryCount, err := rd.ReadIntReply() - if err != nil && err != Nil { - return nil, err - } - - cmd.val = append(cmd.val, XPendingExt{ - ID: id, - Consumer: consumer, - Idle: time.Duration(idle) * time.Millisecond, - RetryCount: retryCount, - }) - return nil, nil - }) - if err != nil { - return nil, err - } + for i := 0; i < len(cmd.val); i++ { + if err = rd.ReadFixedArrayLen(4); err != nil { + return err } - return nil, nil - }) - return err + + if cmd.val[i].ID, err = 
rd.ReadString(); err != nil { + return err + } + + if cmd.val[i].Consumer, err = rd.ReadString(); err != nil && err != Nil { + return err + } + + idle, err := rd.ReadInt() + if err != nil && err != Nil { + return err + } + cmd.val[i].Idle = time.Duration(idle) * time.Millisecond + + if cmd.val[i].RetryCount, err = rd.ReadInt(); err != nil && err != Nil { + return err + } + } + + return nil } //------------------------------------------------------------------------------ @@ -1755,25 +1774,36 @@ func (cmd *XAutoClaimCmd) String() string { } func (cmd *XAutoClaimCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 2 { - return nil, fmt.Errorf("got %d, wanted 2", n) - } - var err error + n, err := rd.ReadArrayLen() + if err != nil { + return err + } - cmd.start, err = rd.ReadString() - if err != nil { - return nil, err - } + switch n { + case 2, // Redis 6 + 3: // Redis 7 + // ok + default: + return fmt.Errorf("redis: got %d elements in XAutoClaim reply, wanted 2/3", n) + } - cmd.val, err = readXMessageSlice(rd) - if err != nil { - return nil, err - } + cmd.start, err = rd.ReadString() + if err != nil { + return err + } - return nil, nil - }) - return err + cmd.val, err = readXMessageSlice(rd) + if err != nil { + return err + } + + if n >= 3 { + if err := rd.DiscardNext(); err != nil { + return err + } + } + + return nil } //------------------------------------------------------------------------------ @@ -1814,33 +1844,44 @@ func (cmd *XAutoClaimJustIDCmd) String() string { } func (cmd *XAutoClaimJustIDCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 2 { - return nil, fmt.Errorf("got %d, wanted 2", n) - } - var err error + n, err := rd.ReadArrayLen() + if err != nil { + return err + } - cmd.start, err = rd.ReadString() + switch n { + case 2, // Redis 6 + 3: // Redis 7 + // ok + default: + return fmt.Errorf("redis: got %d elements in XAutoClaimJustID reply, wanted 2/3", n) + } + + cmd.start, err = rd.ReadString() + if err != nil { + return err + } + + nn, err := rd.ReadArrayLen() + if err != nil { + return err + } + + cmd.val = make([]string, nn) + for i := 0; i < nn; i++ { + cmd.val[i], err = rd.ReadString() if err != nil { - return nil, err + return err } + } - nn, err := rd.ReadArrayLen() - if err != nil { - return nil, err + if n >= 3 { + if err := rd.DiscardNext(); err != nil { + return err } + } - cmd.val = make([]string, nn) - for i := 0; i < nn; i++ { - cmd.val[i], err = rd.ReadString() - if err != nil { - return nil, err - } - } - - return nil, nil - }) - return err + return nil } //------------------------------------------------------------------------------ @@ -1853,7 +1894,7 @@ type XInfoConsumersCmd struct { } type XInfoConsumer struct { Name string Pending int64 - Idle int64 + Idle time.Duration } var _ Cmder = (*XInfoConsumersCmd)(nil) @@ -1888,62 +1929,41 @@ func (cmd *XInfoConsumersCmd) readReply(rd *proto.Reader) error { if err != nil { return err } cmd.val = make([]XInfoConsumer, n) - for i := 0; i < n; i++ { - cmd.val[i], err = readXConsumerInfo(rd) - if err != nil { + for i := 0; i < len(cmd.val); i++ { + if err = rd.ReadFixedMapLen(3); err != nil { return err } + + var key string + for f := 0; f < 3; f++ { + key, err = rd.ReadString() + if err != nil { + return err + } + + switch key { + case "name": + cmd.val[i].Name, err = rd.ReadString() + case "pending": + cmd.val[i].Pending, err = rd.ReadInt() +
case "idle": + var idle int64 + idle, err = rd.ReadInt() + cmd.val[i].Idle = time.Duration(idle) * time.Millisecond + default: + return fmt.Errorf("redis: unexpected content %s in XINFO CONSUMERS reply", key) + } + if err != nil { + return err + } + } } return nil } -func readXConsumerInfo(rd *proto.Reader) (XInfoConsumer, error) { - var consumer XInfoConsumer - - n, err := rd.ReadArrayLen() - if err != nil { - return consumer, err - } - if n != 6 { - return consumer, fmt.Errorf("redis: got %d elements in XINFO CONSUMERS reply, wanted 6", n) - } - - for i := 0; i < 3; i++ { - key, err := rd.ReadString() - if err != nil { - return consumer, err - } - - val, err := rd.ReadString() - if err != nil { - return consumer, err - } - - switch key { - case "name": - consumer.Name = val - case "pending": - consumer.Pending, err = strconv.ParseInt(val, 0, 64) - if err != nil { - return consumer, err - } - case "idle": - consumer.Idle, err = strconv.ParseInt(val, 0, 64) - if err != nil { - return consumer, err - } - default: - return consumer, fmt.Errorf("redis: unexpected content %s in XINFO CONSUMERS reply", key) - } - } - - return consumer, nil -} - //------------------------------------------------------------------------------ type XInfoGroupsCmd struct { @@ -1956,6 +1976,8 @@ type XInfoGroup struct { Consumers int64 Pending int64 LastDeliveredID string + EntriesRead int64 + Lag int64 } var _ Cmder = (*XInfoGroupsCmd)(nil) @@ -1990,64 +2012,63 @@ func (cmd *XInfoGroupsCmd) readReply(rd *proto.Reader) error { if err != nil { return err } - cmd.val = make([]XInfoGroup, n) - for i := 0; i < n; i++ { - cmd.val[i], err = readXGroupInfo(rd) + for i := 0; i < len(cmd.val); i++ { + group := &cmd.val[i] + + nn, err := rd.ReadMapLen() if err != nil { return err } + + var key string + for j := 0; j < nn; j++ { + key, err = rd.ReadString() + if err != nil { + return err + } + + switch key { + case "name": + group.Name, err = rd.ReadString() + if err != nil { + return err + } + case "consumers": + group.Consumers, err = rd.ReadInt() + if err != nil { + return err + } + case "pending": + group.Pending, err = rd.ReadInt() + if err != nil { + return err + } + case "last-delivered-id": + group.LastDeliveredID, err = rd.ReadString() + if err != nil { + return err + } + case "entries-read": + group.EntriesRead, err = rd.ReadInt() + if err != nil && err != Nil { + return err + } + case "lag": + group.Lag, err = rd.ReadInt() + if err != nil { + return err + } + default: + return fmt.Errorf("redis: unexpected key %q in XINFO GROUPS reply", key) + } + } } return nil } -func readXGroupInfo(rd *proto.Reader) (XInfoGroup, error) { - var group XInfoGroup - - n, err := rd.ReadArrayLen() - if err != nil { - return group, err - } - if n != 8 { - return group, fmt.Errorf("redis: got %d elements in XINFO GROUPS reply, wanted 8", n) - } - - for i := 0; i < 4; i++ { - key, err := rd.ReadString() - if err != nil { - return group, err - } - - val, err := rd.ReadString() - if err != nil { - return group, err - } - - switch key { - case "name": - group.Name = val - case "consumers": - group.Consumers, err = strconv.ParseInt(val, 0, 64) - if err != nil { - return group, err - } - case "pending": - group.Pending, err = strconv.ParseInt(val, 0, 64) - if err != nil { - return group, err - } - case "last-delivered-id": - group.LastDeliveredID = val - default: - return group, fmt.Errorf("redis: unexpected content %s in XINFO GROUPS reply", key) - } - } - - return group, nil -} - 
//------------------------------------------------------------------------------ type XInfoStreamCmd struct { @@ -2056,13 +2077,16 @@ type XInfoStreamCmd struct { } type XInfoStream struct { - Length int64 - RadixTreeKeys int64 - RadixTreeNodes int64 - Groups int64 - LastGeneratedID string - FirstEntry XMessage - LastEntry XMessage + Length int64 + RadixTreeKeys int64 + RadixTreeNodes int64 + Groups int64 + LastGeneratedID string + MaxDeletedEntryID string + EntriesAdded int64 + FirstEntry XMessage + LastEntry XMessage + RecordedFirstEntryID string } var _ Cmder = (*XInfoStreamCmd)(nil) @@ -2093,55 +2117,73 @@ func (cmd *XInfoStreamCmd) String() string { } func (cmd *XInfoStreamCmd) readReply(rd *proto.Reader) error { - v, err := rd.ReadReply(xStreamInfoParser) + n, err := rd.ReadMapLen() if err != nil { return err } - cmd.val = v.(*XInfoStream) - return nil -} + cmd.val = &XInfoStream{} -func xStreamInfoParser(rd *proto.Reader, n int64) (interface{}, error) { - if n != 14 { - return nil, fmt.Errorf("redis: got %d elements in XINFO STREAM reply,"+ - "wanted 14", n) - } - var info XInfoStream - for i := 0; i < 7; i++ { + for i := 0; i < n; i++ { key, err := rd.ReadString() if err != nil { - return nil, err + return err } switch key { case "length": - info.Length, err = rd.ReadIntReply() + cmd.val.Length, err = rd.ReadInt() + if err != nil { + return err + } case "radix-tree-keys": - info.RadixTreeKeys, err = rd.ReadIntReply() + cmd.val.RadixTreeKeys, err = rd.ReadInt() + if err != nil { + return err + } case "radix-tree-nodes": - info.RadixTreeNodes, err = rd.ReadIntReply() + cmd.val.RadixTreeNodes, err = rd.ReadInt() + if err != nil { + return err + } case "groups": - info.Groups, err = rd.ReadIntReply() + cmd.val.Groups, err = rd.ReadInt() + if err != nil { + return err + } case "last-generated-id": - info.LastGeneratedID, err = rd.ReadString() + cmd.val.LastGeneratedID, err = rd.ReadString() + if err != nil { + return err + } + case "max-deleted-entry-id": + cmd.val.MaxDeletedEntryID, err = rd.ReadString() + if err != nil { + return err + } + case "entries-added": + cmd.val.EntriesAdded, err = rd.ReadInt() + if err != nil { + return err + } case "first-entry": - info.FirstEntry, err = readXMessage(rd) - if err == Nil { - err = nil + cmd.val.FirstEntry, err = readXMessage(rd) + if err != nil && err != Nil { + return err } case "last-entry": - info.LastEntry, err = readXMessage(rd) - if err == Nil { - err = nil + cmd.val.LastEntry, err = readXMessage(rd) + if err != nil && err != Nil { + return err + } + case "recorded-first-entry-id": + cmd.val.RecordedFirstEntryID, err = rd.ReadString() + if err != nil { + return err } default: - return nil, fmt.Errorf("redis: unexpected content %s "+ - "in XINFO STREAM reply", key) - } - if err != nil { - return nil, err + return fmt.Errorf("redis: unexpected key %q in XINFO STREAM reply", key) } } - return &info, nil + return nil } //------------------------------------------------------------------------------ @@ -2152,17 +2194,22 @@ type XInfoStreamFullCmd struct { } type XInfoStreamFull struct { - Length int64 - RadixTreeKeys int64 - RadixTreeNodes int64 - LastGeneratedID string - Entries []XMessage - Groups []XInfoStreamGroup + Length int64 + RadixTreeKeys int64 + RadixTreeNodes int64 + LastGeneratedID string + MaxDeletedEntryID string + EntriesAdded int64 + Entries []XMessage + Groups []XInfoStreamGroup + RecordedFirstEntryID string } type XInfoStreamGroup struct { Name string LastDeliveredID string + EntriesRead int64 + Lag int64 PelCount int64 
Pending []XInfoStreamGroupPending Consumers []XInfoStreamConsumer @@ -2216,18 +2263,14 @@ func (cmd *XInfoStreamFullCmd) String() string { } func (cmd *XInfoStreamFullCmd) readReply(rd *proto.Reader) error { - n, err := rd.ReadArrayLen() + n, err := rd.ReadMapLen() if err != nil { return err } - if n != 12 { - return fmt.Errorf("redis: got %d elements in XINFO STREAM FULL reply,"+ - "wanted 12", n) - } cmd.val = &XInfoStreamFull{} - for i := 0; i < 6; i++ { + for i := 0; i < n; i++ { key, err := rd.ReadString() if err != nil { return err @@ -2235,23 +2278,52 @@ func (cmd *XInfoStreamFullCmd) readReply(rd *proto.Reader) error { switch key { case "length": - cmd.val.Length, err = rd.ReadIntReply() + cmd.val.Length, err = rd.ReadInt() + if err != nil { + return err + } case "radix-tree-keys": - cmd.val.RadixTreeKeys, err = rd.ReadIntReply() + cmd.val.RadixTreeKeys, err = rd.ReadInt() + if err != nil { + return err + } case "radix-tree-nodes": - cmd.val.RadixTreeNodes, err = rd.ReadIntReply() + cmd.val.RadixTreeNodes, err = rd.ReadInt() + if err != nil { + return err + } case "last-generated-id": cmd.val.LastGeneratedID, err = rd.ReadString() + if err != nil { + return err + } + case "entries-added": + cmd.val.EntriesAdded, err = rd.ReadInt() + if err != nil { + return err + } case "entries": cmd.val.Entries, err = readXMessageSlice(rd) + if err != nil { + return err + } case "groups": cmd.val.Groups, err = readStreamGroups(rd) + if err != nil { + return err + } + case "max-deleted-entry-id": + cmd.val.MaxDeletedEntryID, err = rd.ReadString() + if err != nil { + return err + } + case "recorded-first-entry-id": + cmd.val.RecordedFirstEntryID, err = rd.ReadString() + if err != nil { + return err + } default: - return fmt.Errorf("redis: unexpected content %s "+ - "in XINFO STREAM reply", key) - } - if err != nil { - return err + return fmt.Errorf("redis: unexpected key %q in XINFO STREAM FULL reply", key) } } return nil @@ -2264,18 +2336,14 @@ func readStreamGroups(rd *proto.Reader) ([]XInfoStreamGroup, error) { } groups := make([]XInfoStreamGroup, 0, n) for i := 0; i < n; i++ { - nn, err := rd.ReadArrayLen() + nn, err := rd.ReadMapLen() if err != nil { return nil, err } - if nn != 10 { - return nil, fmt.Errorf("redis: got %d elements in XINFO STREAM FULL reply,"+ - "wanted 10", nn) - } group := XInfoStreamGroup{} - for f := 0; f < 5; f++ { + for j := 0; j < nn; j++ { key, err := rd.ReadString() if err != nil { return nil, err @@ -2284,21 +2352,41 @@ func readStreamGroups(rd *proto.Reader) ([]XInfoStreamGroup, error) { switch key { case "name": group.Name, err = rd.ReadString() + if err != nil { + return nil, err + } case "last-delivered-id": group.LastDeliveredID, err = rd.ReadString() + if err != nil { + return nil, err + } + case "entries-read": + group.EntriesRead, err = rd.ReadInt() + if err != nil { + return nil, err + } + case "lag": + group.Lag, err = rd.ReadInt() + if err != nil { + return nil, err + } case "pel-count": - group.PelCount, err = rd.ReadIntReply() + group.PelCount, err = rd.ReadInt() + if err != nil { + return nil, err + } case "pending": group.Pending, err = readXInfoStreamGroupPending(rd) + if err != nil { + return nil, err + } case "consumers": group.Consumers, err = readXInfoStreamConsumers(rd) + if err != nil { + return nil, err + } default: - return nil, fmt.Errorf("redis: unexpected content %s "+ - "in XINFO STREAM reply", key) - } - - if err != nil { - return nil, err + return nil, fmt.Errorf("redis: unexpected key %q in XINFO STREAM FULL reply", key) } } @@ 
-2317,14 +2405,9 @@ func readXInfoStreamGroupPending(rd *proto.Reader) ([]XInfoStreamGroupPending, e pending := make([]XInfoStreamGroupPending, 0, n) for i := 0; i < n; i++ { - nn, err := rd.ReadArrayLen() - if err != nil { + if err = rd.ReadFixedArrayLen(4); err != nil { return nil, err } - if nn != 4 { - return nil, fmt.Errorf("redis: got %d elements in XINFO STREAM FULL reply,"+ - "wanted 4", nn) - } p := XInfoStreamGroupPending{} @@ -2338,13 +2421,13 @@ func readXInfoStreamGroupPending(rd *proto.Reader) ([]XInfoStreamGroupPending, e return nil, err } - delivery, err := rd.ReadIntReply() + delivery, err := rd.ReadInt() if err != nil { return nil, err } p.DeliveryTime = time.Unix(delivery/1000, delivery%1000*int64(time.Millisecond)) - p.DeliveryCount, err = rd.ReadIntReply() + p.DeliveryCount, err = rd.ReadInt() if err != nil { return nil, err } @@ -2364,14 +2447,9 @@ func readXInfoStreamConsumers(rd *proto.Reader) ([]XInfoStreamConsumer, error) { consumers := make([]XInfoStreamConsumer, 0, n) for i := 0; i < n; i++ { - nn, err := rd.ReadArrayLen() - if err != nil { + if err = rd.ReadFixedMapLen(4); err != nil { return nil, err } - if nn != 8 { - return nil, fmt.Errorf("redis: got %d elements in XINFO STREAM FULL reply,"+ - "wanted 8", nn) - } c := XInfoStreamConsumer{} @@ -2385,13 +2463,13 @@ func readXInfoStreamConsumers(rd *proto.Reader) ([]XInfoStreamConsumer, error) { case "name": c.Name, err = rd.ReadString() case "seen-time": - seen, err := rd.ReadIntReply() + seen, err := rd.ReadInt() if err != nil { return nil, err } c.SeenTime = time.Unix(seen/1000, seen%1000*int64(time.Millisecond)) case "pel-count": - c.PelCount, err = rd.ReadIntReply() + c.PelCount, err = rd.ReadInt() case "pending": pendingNumber, err := rd.ReadArrayLen() if err != nil { @@ -2401,14 +2479,9 @@ func readXInfoStreamConsumers(rd *proto.Reader) ([]XInfoStreamConsumer, error) { c.Pending = make([]XInfoStreamConsumerPending, 0, pendingNumber) for pn := 0; pn < pendingNumber; pn++ { - nn, err := rd.ReadArrayLen() - if err != nil { + if err = rd.ReadFixedArrayLen(3); err != nil { return nil, err } - if nn != 3 { - return nil, fmt.Errorf("redis: got %d elements in XINFO STREAM reply,"+ - "wanted 3", nn) - } p := XInfoStreamConsumerPending{} @@ -2417,13 +2490,13 @@ func readXInfoStreamConsumers(rd *proto.Reader) ([]XInfoStreamConsumer, error) { return nil, err } - delivery, err := rd.ReadIntReply() + delivery, err := rd.ReadInt() if err != nil { return nil, err } p.DeliveryTime = time.Unix(delivery/1000, delivery%1000*int64(time.Millisecond)) - p.DeliveryCount, err = rd.ReadIntReply() + p.DeliveryCount, err = rd.ReadInt() if err != nil { return nil, err } @@ -2432,7 +2505,7 @@ func readXInfoStreamConsumers(rd *proto.Reader) ([]XInfoStreamConsumer, error) { } default: return nil, fmt.Errorf("redis: unexpected content %s "+ - "in XINFO STREAM reply", cKey) + "in XINFO STREAM FULL reply", cKey) } if err != nil { return nil, err @@ -2479,28 +2552,47 @@ func (cmd *ZSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { +func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { // nolint:dupl + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + // If n is 0, there is nothing left to read.
+ if n == 0 { + cmd.val = make([]Z, 0) + return nil + } + + typ, err := rd.PeekReplyType() + if err != nil { + return err + } + array := typ == proto.RespArray + + if array { + cmd.val = make([]Z, n) + } else { cmd.val = make([]Z, n/2) - for i := 0; i < len(cmd.val); i++ { - member, err := rd.ReadString() - if err != nil { - return nil, err - } + } - score, err := rd.ReadFloatReply() - if err != nil { - return nil, err - } - - cmd.val[i] = Z{ - Member: member, - Score: score, + for i := 0; i < len(cmd.val); i++ { + if array { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err } } - return nil, nil - }) - return err + + if cmd.val[i].Member, err = rd.ReadString(); err != nil { + return err + } + + if cmd.val[i].Score, err = rd.ReadFloat(); err != nil { + return err + } + } + + return nil } //------------------------------------------------------------------------------ @@ -2538,33 +2630,23 @@ func (cmd *ZWithKeyCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - if n != 3 { - return nil, fmt.Errorf("got %d elements, expected 3", n) - } +func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) (err error) { + if err = rd.ReadFixedArrayLen(3); err != nil { + return err + } + cmd.val = &ZWithKey{} - cmd.val = &ZWithKey{} - var err error + if cmd.val.Key, err = rd.ReadString(); err != nil { + return err + } + if cmd.val.Member, err = rd.ReadString(); err != nil { + return err + } + if cmd.val.Score, err = rd.ReadFloat(); err != nil { + return err + } - cmd.val.Key, err = rd.ReadString() - if err != nil { - return nil, err - } - - cmd.val.Member, err = rd.ReadString() - if err != nil { - return nil, err - } - - cmd.val.Score, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - - return nil, nil - }) - return err + return nil } //------------------------------------------------------------------------------ @@ -2607,9 +2689,29 @@ func (cmd *ScanCmd) String() string { return cmdString(cmd, cmd.page) } -func (cmd *ScanCmd) readReply(rd *proto.Reader) (err error) { - cmd.page, cmd.cursor, err = rd.ReadScanReply() - return err +func (cmd *ScanCmd) readReply(rd *proto.Reader) error { + if err := rd.ReadFixedArrayLen(2); err != nil { + return err + } + + cursor, err := rd.ReadInt() + if err != nil { + return err + } + cmd.cursor = uint64(cursor) + + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.page = make([]string, n) + + for i := 0; i < len(cmd.page); i++ { + if cmd.page[i], err = rd.ReadString(); err != nil { + return err + } + } + return nil } // Iterator creates a new ScanIterator. 
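ScanCmd now decodes the cursor and key page inline rather than through the old ReadScanReply helper, but the public contract is unchanged: Iterator keeps reissuing SCAN with the parsed cursor until the server returns zero. A usage sketch, assuming an existing client rdb and a context ctx:

```go
// Walk all keys matching a pattern; each SCAN reply is parsed by the
// readReply shown above (cursor first, then the page of keys).
iter := rdb.Scan(ctx, 0, "user:*", 100).Iterator()
for iter.Next(ctx) {
	fmt.Println("key:", iter.Val())
}
if err := iter.Err(); err != nil {
	panic(err)
}
```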
@@ -2622,8 +2724,9 @@ func (cmd *ScanCmd) Iterator() *ScanIterator { //------------------------------------------------------------------------------ type ClusterNode struct { - ID string - Addr string + ID string + Addr string + NetworkingMetadata map[string]string } type ClusterSlot struct { @@ -2666,69 +2769,95 @@ func (cmd *ClusterSlotsCmd) String() string { } func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error { - _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { - cmd.val = make([]ClusterSlot, n) - for i := 0; i < len(cmd.val); i++ { - n, err := rd.ReadArrayLen() + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]ClusterSlot, n) + + for i := 0; i < len(cmd.val); i++ { + n, err = rd.ReadArrayLen() + if err != nil { + return err + } + if n < 2 { + return fmt.Errorf("redis: got %d elements in cluster info, expected at least 2", n) + } + + start, err := rd.ReadInt() + if err != nil { + return err + } + + end, err := rd.ReadInt() + if err != nil { + return err + } + + // The first two elements are the slot range (start, end); the rest are nodes. + nodes := make([]ClusterNode, n-2) + + for j := 0; j < len(nodes); j++ { + nn, err := rd.ReadArrayLen() if err != nil { - return nil, err + return err } - if n < 2 { - err := fmt.Errorf("redis: got %d elements in cluster info, expected at least 2", n) - return nil, err + if nn < 2 || nn > 4 { + return fmt.Errorf("got %d elements in cluster info address, expected 2, 3, or 4", nn) } - start, err := rd.ReadIntReply() + ip, err := rd.ReadString() if err != nil { - return nil, err + return err } - end, err := rd.ReadIntReply() + port, err := rd.ReadString() if err != nil { - return nil, err + return err } - nodes := make([]ClusterNode, n-2) - for j := 0; j < len(nodes); j++ { - n, err := rd.ReadArrayLen() + nodes[j].Addr = net.JoinHostPort(ip, port) + + if nn >= 3 { + id, err := rd.ReadString() if err != nil { - return nil, err - } - if n != 2 && n != 3 { - err := fmt.Errorf("got %d elements in cluster info address, expected 2 or 3", n) - return nil, err + return err } + nodes[j].ID = id + } - ip, err := rd.ReadString() + if nn >= 4 { + metadataLength, err := rd.ReadMapLen() if err != nil { - return nil, err + return err } - port, err := rd.ReadString() - if err != nil { - return nil, err - } + networkingMetadata := make(map[string]string, metadataLength) - nodes[j].Addr = net.JoinHostPort(ip, port) - - if n == 3 { - id, err := rd.ReadString() + for i := 0; i < metadataLength; i++ { + key, err := rd.ReadString() if err != nil { - return nil, err + return err } - nodes[j].ID = id + value, err := rd.ReadString() + if err != nil { + return err + } + networkingMetadata[key] = value } - } - cmd.val[i] = ClusterSlot{ - Start: int(start), - End: int(end), - Nodes: nodes, + nodes[j].NetworkingMetadata = networkingMetadata } } - return nil, nil - }) - return err + + cmd.val[i] = ClusterSlot{ + Start: int(start), + End: int(end), + Nodes: nodes, + } + } + + return nil } //------------------------------------------------------------------------------ @@ -2753,6 +2882,9 @@ type GeoRadiusQuery struct { Sort string Store string StoreDist string + + // withLen counts how many of WithCoord, WithDist and WithGeoHash are set. + withLen int } type GeoLocationCmd struct { @@ -2783,12 +2915,15 @@ func geoLocationArgs(q *GeoRadiusQuery, args ...interface{}) []interface{} { } if q.WithCoord { args = append(args, "withcoord") + q.withLen++ } if q.WithDist { args = append(args, "withdist") + q.withLen++ } if q.WithGeoHash { args = append(args, "withhash") + q.withLen++ } if q.Count > 0 { args
= append(args, "count", q.Count) @@ -2824,82 +2959,55 @@ func (cmd *GeoLocationCmd) String() string { } func (cmd *GeoLocationCmd) readReply(rd *proto.Reader) error { - v, err := rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) + n, err := rd.ReadArrayLen() if err != nil { return err } - cmd.locations = v.([]GeoLocation) + cmd.locations = make([]GeoLocation, n) + + for i := 0; i < len(cmd.locations); i++ { + // only name + if cmd.q.withLen == 0 { + if cmd.locations[i].Name, err = rd.ReadString(); err != nil { + return err + } + continue + } + + // +name + if err = rd.ReadFixedArrayLen(cmd.q.withLen + 1); err != nil { + return err + } + + if cmd.locations[i].Name, err = rd.ReadString(); err != nil { + return err + } + if cmd.q.WithDist { + if cmd.locations[i].Dist, err = rd.ReadFloat(); err != nil { + return err + } + } + if cmd.q.WithGeoHash { + if cmd.locations[i].GeoHash, err = rd.ReadInt(); err != nil { + return err + } + } + if cmd.q.WithCoord { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err + } + if cmd.locations[i].Longitude, err = rd.ReadFloat(); err != nil { + return err + } + if cmd.locations[i].Latitude, err = rd.ReadFloat(); err != nil { + return err + } + } + } + return nil } -func newGeoLocationSliceParser(q *GeoRadiusQuery) proto.MultiBulkParse { - return func(rd *proto.Reader, n int64) (interface{}, error) { - locs := make([]GeoLocation, 0, n) - for i := int64(0); i < n; i++ { - v, err := rd.ReadReply(newGeoLocationParser(q)) - if err != nil { - return nil, err - } - switch vv := v.(type) { - case string: - locs = append(locs, GeoLocation{ - Name: vv, - }) - case *GeoLocation: - // TODO: avoid copying - locs = append(locs, *vv) - default: - return nil, fmt.Errorf("got %T, expected string or *GeoLocation", v) - } - } - return locs, nil - } -} - -func newGeoLocationParser(q *GeoRadiusQuery) proto.MultiBulkParse { - return func(rd *proto.Reader, n int64) (interface{}, error) { - var loc GeoLocation - var err error - - loc.Name, err = rd.ReadString() - if err != nil { - return nil, err - } - if q.WithDist { - loc.Dist, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - } - if q.WithGeoHash { - loc.GeoHash, err = rd.ReadIntReply() - if err != nil { - return nil, err - } - } - if q.WithCoord { - n, err := rd.ReadArrayLen() - if err != nil { - return nil, err - } - if n != 2 { - return nil, fmt.Errorf("got %d coordinates, expected 2", n) - } - - loc.Longitude, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - loc.Latitude, err = rd.ReadFloatReply() - if err != nil { - return nil, err - } - } - - return &loc, nil - } -} - //------------------------------------------------------------------------------ // GeoSearchQuery is used for GEOSearch/GEOSearchStore command query. 
@@ -3050,31 +3158,26 @@ func (cmd *GeoSearchLocationCmd) readReply(rd *proto.Reader) error {
             return err
         }
         if cmd.opt.WithDist {
-            loc.Dist, err = rd.ReadFloatReply()
+            loc.Dist, err = rd.ReadFloat()
             if err != nil {
                 return err
             }
         }
         if cmd.opt.WithHash {
-            loc.GeoHash, err = rd.ReadIntReply()
+            loc.GeoHash, err = rd.ReadInt()
             if err != nil {
                 return err
             }
         }
         if cmd.opt.WithCoord {
-            nn, err := rd.ReadArrayLen()
+            if err = rd.ReadFixedArrayLen(2); err != nil {
+                return err
+            }
+            loc.Longitude, err = rd.ReadFloat()
             if err != nil {
                 return err
             }
-            if nn != 2 {
-                return fmt.Errorf("got %d coordinates, expected 2", nn)
-            }
-
-            loc.Longitude, err = rd.ReadFloatReply()
-            if err != nil {
-                return err
-            }
-            loc.Latitude, err = rd.ReadFloatReply()
+            loc.Latitude, err = rd.ReadFloat()
             if err != nil {
                 return err
             }
@@ -3126,38 +3229,38 @@ func (cmd *GeoPosCmd) String() string {
 }
 
 func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error {
-    _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
-        cmd.val = make([]*GeoPos, n)
-        for i := 0; i < len(cmd.val); i++ {
-            i := i
-            _, err := rd.ReadReply(func(rd *proto.Reader, n int64) (interface{}, error) {
-                longitude, err := rd.ReadFloatReply()
-                if err != nil {
-                    return nil, err
-                }
-
-                latitude, err := rd.ReadFloatReply()
-                if err != nil {
-                    return nil, err
-                }
-
-                cmd.val[i] = &GeoPos{
-                    Longitude: longitude,
-                    Latitude:  latitude,
-                }
-                return nil, nil
-            })
-            if err != nil {
-                if err == Nil {
-                    cmd.val[i] = nil
-                    continue
-                }
-                return nil, err
-            }
-        }
-        return nil, nil
-    })
-    return err
+    n, err := rd.ReadArrayLen()
+    if err != nil {
+        return err
+    }
+    cmd.val = make([]*GeoPos, n)
+
+    for i := 0; i < len(cmd.val); i++ {
+        err = rd.ReadFixedArrayLen(2)
+        if err != nil {
+            if err == Nil {
+                cmd.val[i] = nil
+                continue
+            }
+            return err
+        }
+
+        longitude, err := rd.ReadFloat()
+        if err != nil {
+            return err
+        }
+        latitude, err := rd.ReadFloat()
+        if err != nil {
+            return err
+        }
+
+        cmd.val[i] = &GeoPos{
+            Longitude: longitude,
+            Latitude:  latitude,
+        }
+    }
+
+    return nil
 }
 
 //------------------------------------------------------------------------------
@@ -3207,112 +3310,111 @@ func (cmd *CommandsInfoCmd) String() string {
 }
 
 func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error {
-    _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
-        cmd.val = make(map[string]*CommandInfo, n)
-        for i := int64(0); i < n; i++ {
-            v, err := rd.ReadReply(commandInfoParser)
-            if err != nil {
-                return nil, err
-            }
-            vv := v.(*CommandInfo)
-            cmd.val[vv.Name] = vv
-        }
-        return nil, nil
-    })
-    return err
-}
-
-func commandInfoParser(rd *proto.Reader, n int64) (interface{}, error) {
     const numArgRedis5 = 6
     const numArgRedis6 = 7
+    const numArgRedis7 = 10
 
-    switch n {
-    case numArgRedis5, numArgRedis6:
-        // continue
-    default:
-        return nil, fmt.Errorf("redis: got %d elements in COMMAND reply, wanted 7", n)
-    }
-
-    var cmd CommandInfo
-    var err error
-
-    cmd.Name, err = rd.ReadString()
+    n, err := rd.ReadArrayLen()
     if err != nil {
-        return nil, err
+        return err
     }
+    cmd.val = make(map[string]*CommandInfo, n)
 
-    arity, err := rd.ReadIntReply()
-    if err != nil {
-        return nil, err
-    }
-    cmd.Arity = int8(arity)
+    for i := 0; i < n; i++ {
+        nn, err := rd.ReadArrayLen()
+        if err != nil {
+            return err
+        }
 
-    _, err = rd.ReadReply(func(rd *proto.Reader, n int64) (interface{}, error) {
-        cmd.Flags = make([]string, n)
-        for i := 0; i < len(cmd.Flags); i++ {
+        switch nn {
+        case numArgRedis5, numArgRedis6, numArgRedis7:
+            // ok
+        default:
+            return fmt.Errorf("redis: got %d elements in COMMAND reply, wanted 6/7/10", nn)
+        }
+
+        cmdInfo := &CommandInfo{}
+        if cmdInfo.Name, err = rd.ReadString(); err != nil {
+            return err
+        }
+
+        arity, err := rd.ReadInt()
+        if err != nil {
+            return err
+        }
+        cmdInfo.Arity = int8(arity)
+
+        flagLen, err := rd.ReadArrayLen()
+        if err != nil {
+            return err
+        }
+        cmdInfo.Flags = make([]string, flagLen)
+        for f := 0; f < len(cmdInfo.Flags); f++ {
             switch s, err := rd.ReadString(); {
             case err == Nil:
-                cmd.Flags[i] = ""
+                cmdInfo.Flags[f] = ""
             case err != nil:
-                return nil, err
+                return err
             default:
-                cmd.Flags[i] = s
+                if !cmdInfo.ReadOnly && s == "readonly" {
+                    cmdInfo.ReadOnly = true
+                }
+                cmdInfo.Flags[f] = s
             }
         }
-        return nil, nil
-    })
-    if err != nil {
-        return nil, err
-    }
 
-    firstKeyPos, err := rd.ReadIntReply()
-    if err != nil {
-        return nil, err
-    }
-    cmd.FirstKeyPos = int8(firstKeyPos)
-
-    lastKeyPos, err := rd.ReadIntReply()
-    if err != nil {
-        return nil, err
-    }
-    cmd.LastKeyPos = int8(lastKeyPos)
-
-    stepCount, err := rd.ReadIntReply()
-    if err != nil {
-        return nil, err
-    }
-    cmd.StepCount = int8(stepCount)
-
-    for _, flag := range cmd.Flags {
-        if flag == "readonly" {
-            cmd.ReadOnly = true
-            break
+        firstKeyPos, err := rd.ReadInt()
+        if err != nil {
+            return err
         }
-    }
+        cmdInfo.FirstKeyPos = int8(firstKeyPos)
 
-    if n == numArgRedis5 {
-        return &cmd, nil
-    }
+        lastKeyPos, err := rd.ReadInt()
+        if err != nil {
+            return err
+        }
+        cmdInfo.LastKeyPos = int8(lastKeyPos)
 
-    _, err = rd.ReadReply(func(rd *proto.Reader, n int64) (interface{}, error) {
-        cmd.ACLFlags = make([]string, n)
-        for i := 0; i < len(cmd.ACLFlags); i++ {
-            switch s, err := rd.ReadString(); {
-            case err == Nil:
-                cmd.ACLFlags[i] = ""
-            case err != nil:
-                return nil, err
-            default:
-                cmd.ACLFlags[i] = s
+        stepCount, err := rd.ReadInt()
+        if err != nil {
+            return err
+        }
+        cmdInfo.StepCount = int8(stepCount)
+
+        if nn >= numArgRedis6 {
+            aclFlagLen, err := rd.ReadArrayLen()
+            if err != nil {
+                return err
+            }
+            cmdInfo.ACLFlags = make([]string, aclFlagLen)
+            for f := 0; f < len(cmdInfo.ACLFlags); f++ {
+                switch s, err := rd.ReadString(); {
+                case err == Nil:
+                    cmdInfo.ACLFlags[f] = ""
+                case err != nil:
+                    return err
+                default:
+                    cmdInfo.ACLFlags[f] = s
+                }
             }
         }
-        return nil, nil
-    })
-    if err != nil {
-        return nil, err
+
+        if nn >= numArgRedis7 {
+            // Discard the three extra fields a Redis 7 COMMAND reply carries.
+            if err := rd.DiscardNext(); err != nil {
+                return err
+            }
+            if err := rd.DiscardNext(); err != nil {
+                return err
+            }
+            if err := rd.DiscardNext(); err != nil {
+                return err
+            }
+        }
+
+        cmd.val[cmdInfo.Name] = cmdInfo
     }
 
-    return &cmd, nil
+    return nil
 }
 
 //------------------------------------------------------------------------------
@@ -3398,75 +3500,193 @@ func (cmd *SlowLogCmd) String() string {
 }
 
 func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error {
-    _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
-        cmd.val = make([]SlowLog, n)
-        for i := 0; i < len(cmd.val); i++ {
-            n, err := rd.ReadArrayLen()
-            if err != nil {
-                return nil, err
-            }
-            if n < 4 {
-                err := fmt.Errorf("redis: got %d elements in slowlog get, expected at least 4", n)
-                return nil, err
-            }
-
-            id, err := rd.ReadIntReply()
-            if err != nil {
-                return nil, err
-            }
-
-            createdAt, err := rd.ReadIntReply()
-            if err != nil {
-                return nil, err
-            }
-            createdAtTime := time.Unix(createdAt, 0)
-
-            costs, err := rd.ReadIntReply()
-            if err != nil {
-                return nil, err
-            }
-            costsDuration := time.Duration(costs) * time.Microsecond
-
-            cmdLen, err := rd.ReadArrayLen()
-            if err != nil {
-                return nil, err
-            }
-            if cmdLen < 1 {
-                err := fmt.Errorf("redis: got %d elements commands reply in slowlog get, expected at least 1", cmdLen)
-                return nil, err
-            }
-
-            cmdString := make([]string, cmdLen)
-            for i := 0; i < cmdLen; i++ {
-                cmdString[i], err = rd.ReadString()
-                if err != nil {
-                    return nil, err
-                }
-            }
-
-            var address, name string
-            for i := 4; i < n; i++ {
-                str, err := rd.ReadString()
-                if err != nil {
-                    return nil, err
-                }
-                if i == 4 {
-                    address = str
-                } else if i == 5 {
-                    name = str
-                }
-            }
-
-            cmd.val[i] = SlowLog{
-                ID:         id,
-                Time:       createdAtTime,
-                Duration:   costsDuration,
-                Args:       cmdString,
-                ClientAddr: address,
-                ClientName: name,
-            }
-        }
-        return nil, nil
-    })
-    return err
+    n, err := rd.ReadArrayLen()
+    if err != nil {
+        return err
+    }
+    cmd.val = make([]SlowLog, n)
+
+    for i := 0; i < len(cmd.val); i++ {
+        nn, err := rd.ReadArrayLen()
+        if err != nil {
+            return err
+        }
+        if nn < 4 {
+            return fmt.Errorf("redis: got %d elements in slowlog get, expected at least 4", nn)
+        }
+
+        if cmd.val[i].ID, err = rd.ReadInt(); err != nil {
+            return err
+        }
+
+        createdAt, err := rd.ReadInt()
+        if err != nil {
+            return err
+        }
+        cmd.val[i].Time = time.Unix(createdAt, 0)
+
+        costs, err := rd.ReadInt()
+        if err != nil {
+            return err
+        }
+        cmd.val[i].Duration = time.Duration(costs) * time.Microsecond
+
+        cmdLen, err := rd.ReadArrayLen()
+        if err != nil {
+            return err
+        }
+        if cmdLen < 1 {
+            return fmt.Errorf("redis: got %d elements in slowlog get command reply, expected at least 1", cmdLen)
+        }
+
+        cmd.val[i].Args = make([]string, cmdLen)
+        for f := 0; f < len(cmd.val[i].Args); f++ {
+            cmd.val[i].Args[f], err = rd.ReadString()
+            if err != nil {
+                return err
+            }
+        }
+
+        if nn >= 5 {
+            if cmd.val[i].ClientAddr, err = rd.ReadString(); err != nil {
+                return err
+            }
+        }
+
+        if nn >= 6 {
+            if cmd.val[i].ClientName, err = rd.ReadString(); err != nil {
+                return err
+            }
+        }
+    }
+
+    return nil
 }
+
+//-----------------------------------------------------------------------
+
+type MapStringInterfaceCmd struct {
+    baseCmd
+
+    val map[string]interface{}
+}
+
+var _ Cmder = (*MapStringInterfaceCmd)(nil)
+
+func NewMapStringInterfaceCmd(ctx context.Context, args ...interface{}) *MapStringInterfaceCmd {
+    return &MapStringInterfaceCmd{
+        baseCmd: baseCmd{
+            ctx:  ctx,
+            args: args,
+        },
+    }
+}
+
+func (cmd *MapStringInterfaceCmd) SetVal(val map[string]interface{}) {
+    cmd.val = val
+}
+
+func (cmd *MapStringInterfaceCmd) Val() map[string]interface{} {
+    return cmd.val
+}
+
+func (cmd *MapStringInterfaceCmd) Result() (map[string]interface{}, error) {
+    return cmd.Val(), cmd.Err()
+}
+
+func (cmd *MapStringInterfaceCmd) String() string {
+    return cmdString(cmd, cmd.val)
+}
+
+func (cmd *MapStringInterfaceCmd) readReply(rd *proto.Reader) error {
+    n, err := rd.ReadMapLen()
+    if err != nil {
+        return err
+    }
+
+    cmd.val = make(map[string]interface{}, n)
+    for i := 0; i < n; i++ {
+        k, err := rd.ReadString()
+        if err != nil {
+            return err
+        }
+        v, err := rd.ReadReply()
+        if err != nil {
+            if err == Nil {
+                cmd.val[k] = Nil
+                continue
+            }
+            if err, ok := err.(proto.RedisError); ok {
+                cmd.val[k] = err
+                continue
+            }
+            return err
+        }
+        cmd.val[k] = v
+    }
+    return nil
+}
+
+//-----------------------------------------------------------------------
+
+type MapStringStringSliceCmd struct {
+    baseCmd
+
+    val []map[string]string
+}
+
+var _ Cmder = (*MapStringStringSliceCmd)(nil)
+
+func NewMapStringStringSliceCmd(ctx context.Context, args ...interface{}) *MapStringStringSliceCmd {
+    return &MapStringStringSliceCmd{
+        baseCmd: baseCmd{
+            ctx:  ctx,
+            args: args,
+        },
+    }
+}
+
+func (cmd *MapStringStringSliceCmd) SetVal(val []map[string]string) {
+    cmd.val = val
+}
+
+func (cmd *MapStringStringSliceCmd) Val() []map[string]string {
+    return cmd.val
+}
+
+func (cmd *MapStringStringSliceCmd) Result() ([]map[string]string, error) {
+    return cmd.Val(), cmd.Err()
+}
+
+func (cmd *MapStringStringSliceCmd) String() string {
+    return cmdString(cmd, cmd.val)
+}
+
+func (cmd *MapStringStringSliceCmd) readReply(rd *proto.Reader) error {
+    n, err := rd.ReadArrayLen()
+    if err != nil {
+        return err
+    }
+
+    cmd.val = make([]map[string]string, n)
+    for i := 0; i < n; i++ {
+        nn, err := rd.ReadMapLen()
+        if err != nil {
+            return err
+        }
+        cmd.val[i] = make(map[string]string, nn)
+        for f := 0; f < nn; f++ {
+            k, err := rd.ReadString()
+            if err != nil {
+                return err
+            }
+
+            v, err := rd.ReadString()
+            if err != nil {
+                return err
+            }
+            cmd.val[i][k] = v
+        }
+    }
+    return nil
+}
diff --git a/command_test.go b/command_test.go
index 168f9f69..9af156c8 100644
--- a/command_test.go
+++ b/command_test.go
@@ -4,10 +4,10 @@ import (
     "errors"
     "time"
 
+    "github.com/go-redis/redis/v9"
+
     . "github.com/onsi/ginkgo"
     . "github.com/onsi/gomega"
-
-    redis "github.com/go-redis/redis/v8"
 )
 
 var _ = Describe("Cmd", func() {
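The new map-typed commands above replace the v8 pattern of returning flat `[]interface{}` slices: `MapStringInterfaceCmd` backs HELLO, and the `MapStringStringCmd` rename shows up in the interface changes just below (for example `ConfigGet`). A small sketch under the same `ctx`/`rdb` assumptions as earlier:

```go
// dumpConfig shows the map-typed result of the v9 ConfigGet.
func dumpConfig(ctx context.Context, rdb *redis.Client) error {
	// The result is a ready-made map[string]string instead of a flat
	// []interface{} of alternating parameter names and values.
	params, err := rdb.ConfigGet(ctx, "maxmemory").Result()
	if err != nil {
		return err
	}
	for k, v := range params {
		fmt.Println(k, "=", v)
	}
	return nil
}
```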
diff --git a/commands.go b/commands.go
index 2947d0ff..f58b9a36 100644
--- a/commands.go
+++ b/commands.go
@@ -7,14 +7,14 @@ import (
     "reflect"
     "time"
 
-    "github.com/go-redis/redis/v8/internal"
+    "github.com/go-redis/redis/v9/internal"
 )
 
 // KeepTTL is a Redis KEEPTTL option to keep existing TTL, it requires your redis-server version >= 6.0,
 // otherwise you will receive an error: (error) ERR syntax error.
 // For example:
 //
-// rdb.Set(ctx, key, value, redis.KeepTTL)
+//    rdb.Set(ctx, key, value, redis.KeepTTL)
 const KeepTTL = -1
 
 func usePrecise(dur time.Duration) bool {
@@ -120,6 +120,10 @@ type Cmdable interface {
     Exists(ctx context.Context, keys ...string) *IntCmd
     Expire(ctx context.Context, key string, expiration time.Duration) *BoolCmd
     ExpireAt(ctx context.Context, key string, tm time.Time) *BoolCmd
+    ExpireNX(ctx context.Context, key string, expiration time.Duration) *BoolCmd
+    ExpireXX(ctx context.Context, key string, expiration time.Duration) *BoolCmd
+    ExpireGT(ctx context.Context, key string, expiration time.Duration) *BoolCmd
+    ExpireLT(ctx context.Context, key string, expiration time.Duration) *BoolCmd
     Keys(ctx context.Context, pattern string) *StringSliceCmd
     Migrate(ctx context.Context, host, port, key string, db int, timeout time.Duration) *StatusCmd
     Move(ctx context.Context, key string, db int) *BoolCmd
@@ -136,6 +140,7 @@ type Cmdable interface {
     Restore(ctx context.Context, key string, ttl time.Duration, value string) *StatusCmd
     RestoreReplace(ctx context.Context, key string, ttl time.Duration, value string) *StatusCmd
     Sort(ctx context.Context, key string, sort *Sort) *StringSliceCmd
+    SortRO(ctx context.Context, key string, sort *Sort) *StringSliceCmd
    SortStore(ctx context.Context, key, store string, sort *Sort) *IntCmd
     SortInterfaces(ctx context.Context, key string, sort *Sort) *SliceCmd
     Touch(ctx context.Context, keys ...string) *IntCmd
@@ -157,12 +162,12 @@ type Cmdable interface {
     MSetNX(ctx context.Context, values ...interface{}) *BoolCmd
     Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd
     SetArgs(ctx context.Context, key string, value interface{}, a SetArgs) *StatusCmd
-    // TODO: rename to SetEx
-    SetEX(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd
+    SetEx(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd
     SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *BoolCmd
     SetXX(ctx context.Context, key string, value interface{}, expiration time.Duration) *BoolCmd
     SetRange(ctx context.Context, key string, offset int64, value string) *IntCmd
     StrLen(ctx context.Context, key string) *IntCmd
+    Copy(ctx context.Context, sourceKey string, destKey string, db int, replace bool) *IntCmd
 
     GetBit(ctx context.Context, key string, offset int64) *IntCmd
     SetBit(ctx context.Context, key string, offset int64, value int) *IntCmd
@@ -183,7 +188,7 @@ type Cmdable interface {
     HDel(ctx context.Context, key string, fields ...string) *IntCmd
     HExists(ctx context.Context, key, field string) *BoolCmd
     HGet(ctx context.Context, key, field string) *StringCmd
-    HGetAll(ctx context.Context, key string) *StringStringMapCmd
+    HGetAll(ctx context.Context, key string) *MapStringStringCmd
     HIncrBy(ctx context.Context, key, field string, incr int64) *IntCmd
     HIncrByFloat(ctx context.Context, key, field string, incr float64) *FloatCmd
     HKeys(ctx context.Context, key string) *StringSliceCmd
@@ -193,7 +198,8 @@ type Cmdable interface {
     HMSet(ctx context.Context, key string, values ...interface{}) *BoolCmd
     HSetNX(ctx context.Context, key, field string, value interface{}) *BoolCmd
     HVals(ctx context.Context, key string) *StringSliceCmd
-    HRandField(ctx context.Context, key string, count int, withValues bool) *StringSliceCmd
+    HRandField(ctx context.Context, key string, count int) *StringSliceCmd
+    HRandFieldWithValues(ctx context.Context, key string, count int) *KeyValueSliceCmd
 
     BLPop(ctx context.Context, timeout time.Duration, keys ...string) *StringSliceCmd
     BRPop(ctx context.Context, timeout time.Duration, keys ...string) *StringSliceCmd
@@ -226,6 +232,7 @@ type Cmdable interface {
     SDiff(ctx context.Context, keys ...string) *StringSliceCmd
     SDiffStore(ctx context.Context, destination string, keys ...string) *IntCmd
     SInter(ctx context.Context, keys ...string) *StringSliceCmd
+    SInterCard(ctx context.Context, limit int64, keys ...string) *IntCmd
     SInterStore(ctx context.Context, destination string, keys ...string) *IntCmd
     SIsMember(ctx context.Context, key string, member interface{}) *BoolCmd
     SMIsMember(ctx context.Context, key string, members ...interface{}) *BoolSliceCmd
@@ -263,10 +270,6 @@ type Cmdable interface {
     XClaimJustID(ctx context.Context, a *XClaimArgs) *StringSliceCmd
     XAutoClaim(ctx context.Context, a *XAutoClaimArgs) *XAutoClaimCmd
     XAutoClaimJustID(ctx context.Context, a *XAutoClaimArgs) *XAutoClaimJustIDCmd
-
-    // TODO: XTrim and XTrimApprox remove in v9.
-    XTrim(ctx context.Context, key string, maxLen int64) *IntCmd
-    XTrimApprox(ctx context.Context, key string, maxLen int64) *IntCmd
     XTrimMaxLen(ctx context.Context, key string, maxLen int64) *IntCmd
     XTrimMaxLenApprox(ctx context.Context, key string, maxLen, limit int64) *IntCmd
     XTrimMinID(ctx context.Context, key string, minID string) *IntCmd
@@ -279,33 +282,18 @@ type Cmdable interface {
     BZPopMax(ctx context.Context, timeout time.Duration, keys ...string) *ZWithKeyCmd
     BZPopMin(ctx context.Context, timeout time.Duration, keys ...string) *ZWithKeyCmd
 
-    // TODO: remove
-    //        ZAddCh
-    //        ZIncr
-    //        ZAddNXCh
-    //        ZAddXXCh
-    //        ZIncrNX
-    //        ZIncrXX
-    // in v9.
-    // use ZAddArgs and ZAddArgsIncr.
-
-    ZAdd(ctx context.Context, key string, members ...*Z) *IntCmd
-    ZAddNX(ctx context.Context, key string, members ...*Z) *IntCmd
-    ZAddXX(ctx context.Context, key string, members ...*Z) *IntCmd
-    ZAddCh(ctx context.Context, key string, members ...*Z) *IntCmd
-    ZAddNXCh(ctx context.Context, key string, members ...*Z) *IntCmd
-    ZAddXXCh(ctx context.Context, key string, members ...*Z) *IntCmd
+    ZAdd(ctx context.Context, key string, members ...Z) *IntCmd
+    ZAddNX(ctx context.Context, key string, members ...Z) *IntCmd
+    ZAddXX(ctx context.Context, key string, members ...Z) *IntCmd
     ZAddArgs(ctx context.Context, key string, args ZAddArgs) *IntCmd
     ZAddArgsIncr(ctx context.Context, key string, args ZAddArgs) *FloatCmd
-    ZIncr(ctx context.Context, key string, member *Z) *FloatCmd
-    ZIncrNX(ctx context.Context, key string, member *Z) *FloatCmd
-    ZIncrXX(ctx context.Context, key string, member *Z) *FloatCmd
     ZCard(ctx context.Context, key string) *IntCmd
     ZCount(ctx context.Context, key, min, max string) *IntCmd
     ZLexCount(ctx context.Context, key, min, max string) *IntCmd
     ZIncrBy(ctx context.Context, key string, increment float64, member string) *FloatCmd
     ZInter(ctx context.Context, store *ZStore) *StringSliceCmd
     ZInterWithScores(ctx context.Context, store *ZStore) *ZSliceCmd
+    ZInterCard(ctx context.Context, limit int64, keys ...string) *IntCmd
     ZInterStore(ctx context.Context, destination string, store *ZStore) *IntCmd
     ZMScore(ctx context.Context, key string, members ...string) *FloatSliceCmd
     ZPopMax(ctx context.Context, key string, count ...int64) *ZSliceCmd
@@ -331,9 +319,10 @@ type Cmdable interface {
     ZRevRank(ctx context.Context, key, member string) *IntCmd
     ZScore(ctx context.Context, key, member string) *FloatCmd
     ZUnionStore(ctx context.Context, dest string, store *ZStore) *IntCmd
+    ZRandMember(ctx context.Context, key string, count int) *StringSliceCmd
+    ZRandMemberWithScores(ctx context.Context, key string, count int) *ZSliceCmd
     ZUnion(ctx context.Context, store ZStore) *StringSliceCmd
     ZUnionWithScores(ctx context.Context, store ZStore) *ZSliceCmd
-    ZRandMember(ctx context.Context, key string, count int, withScores bool) *StringSliceCmd
     ZDiff(ctx context.Context, keys ...string) *StringSliceCmd
     ZDiffWithScores(ctx context.Context, keys ...string) *ZSliceCmd
     ZDiffStore(ctx context.Context, destination string, keys ...string) *IntCmd
@@ -348,8 +337,11 @@ type Cmdable interface {
     ClientKillByFilter(ctx context.Context, keys ...string) *IntCmd
     ClientList(ctx context.Context) *StringCmd
     ClientPause(ctx context.Context, dur time.Duration) *BoolCmd
+    ClientUnpause(ctx context.Context) *BoolCmd
     ClientID(ctx context.Context) *IntCmd
-    ConfigGet(ctx context.Context, parameter string) *SliceCmd
+    ClientUnblock(ctx context.Context, id int64) *IntCmd
+    ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
+    ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
     ConfigResetStat(ctx context.Context) *StatusCmd
     ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
     ConfigRewrite(ctx context.Context) *StatusCmd
@@ -365,6 +357,7 @@ type Cmdable interface {
     ShutdownSave(ctx context.Context) *StatusCmd
     ShutdownNoSave(ctx context.Context) *StatusCmd
     SlaveOf(ctx context.Context, host, port string) *StatusCmd
+    SlowLogGet(ctx context.Context, num int64) *SlowLogCmd
     Time(ctx context.Context) *TimeCmd
     DebugObject(ctx context.Context, key string) *StringCmd
     ReadOnly(ctx context.Context) *StatusCmd
@@ -373,15 +366,20 @@ type Cmdable interface {
     Eval(ctx context.Context, script string, keys []string, args ...interface{}) *Cmd
     EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *Cmd
+    EvalRO(ctx context.Context, script string, keys []string, args ...interface{}) *Cmd
+    EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...interface{}) *Cmd
     ScriptExists(ctx context.Context, hashes ...string) *BoolSliceCmd
     ScriptFlush(ctx context.Context) *StatusCmd
     ScriptKill(ctx context.Context) *StatusCmd
     ScriptLoad(ctx context.Context, script string) *StringCmd
 
     Publish(ctx context.Context, channel string, message interface{}) *IntCmd
+    SPublish(ctx context.Context, channel string, message interface{}) *IntCmd
     PubSubChannels(ctx context.Context, pattern string) *StringSliceCmd
-    PubSubNumSub(ctx context.Context, channels ...string) *StringIntMapCmd
+    PubSubNumSub(ctx context.Context, channels ...string) *MapStringIntCmd
     PubSubNumPat(ctx context.Context) *IntCmd
+    PubSubShardChannels(ctx context.Context, pattern string) *StringSliceCmd
+    PubSubShardNumSub(ctx context.Context, channels ...string) *MapStringIntCmd
 
     ClusterSlots(ctx context.Context) *ClusterSlotsCmd
     ClusterNodes(ctx context.Context) *StringCmd
@@ -423,6 +421,7 @@ type StatefulCmdable interface {
     Select(ctx context.Context, index int) *StatusCmd
     SwapDB(ctx context.Context, index1, index2 int) *StatusCmd
     ClientSetName(ctx context.Context, name string) *BoolCmd
+    Hello(ctx context.Context, ver int, username, password, clientName string) *MapStringInterfaceCmd
 }
 
 var (
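The new conditional-expiration methods in the interface above map directly onto the NX/XX/GT/LT options of EXPIRE (server-side they require Redis >= 7.0). A sketch under the same `ctx`/`rdb` assumptions as earlier, plus the `time` package; the key name is hypothetical:

```go
// refreshSession extends a TTL without ever shortening it.
func refreshSession(ctx context.Context, rdb *redis.Client) error {
	// ExpireNX sets a TTL only when the key has none yet.
	if _, err := rdb.ExpireNX(ctx, "session:42", time.Hour).Result(); err != nil {
		return err
	}
	// ExpireGT applies the new TTL only when it is greater than the current one.
	_, err := rdb.ExpireGT(ctx, "session:42", 2*time.Hour).Result()
	return err
}
```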
@@ -455,6 +454,7 @@ func (c statefulCmdable) AuthACL(ctx context.Context, username, password string)
 
 func (c cmdable) Wait(ctx context.Context, numSlaves int, timeout time.Duration) *IntCmd {
     cmd := NewIntCmd(ctx, "wait", numSlaves, int(timeout/time.Millisecond))
+    cmd.setReadTimeout(timeout)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -478,6 +478,26 @@ func (c statefulCmdable) ClientSetName(ctx context.Context, name string) *BoolCmd
     return cmd
 }
 
+// Hello sets the RESP protocol version used by the connection and can
+// optionally authenticate it and set the client name.
+func (c statefulCmdable) Hello(ctx context.Context,
+    ver int, username, password, clientName string) *MapStringInterfaceCmd {
+    args := make([]interface{}, 0, 7)
+    args = append(args, "hello", ver)
+    if password != "" {
+        if username != "" {
+            args = append(args, "auth", username, password)
+        } else {
+            args = append(args, "auth", "default", password)
+        }
+    }
+    if clientName != "" {
+        args = append(args, "setname", clientName)
+    }
+    cmd := NewMapStringInterfaceCmd(ctx, args...)
+    _ = c(ctx, cmd)
+    return cmd
+}
+
 //------------------------------------------------------------------------------
 
 func (c cmdable) Command(ctx context.Context) *CommandsInfoCmd {
@@ -715,8 +735,9 @@ type Sort struct {
     Alpha bool
 }
 
-func (sort *Sort) args(key string) []interface{} {
-    args := []interface{}{"sort", key}
+func (sort *Sort) args(command, key string) []interface{} {
+    args := []interface{}{command, key}
+
     if sort.By != "" {
         args = append(args, "by", sort.By)
     }
@@ -735,14 +756,20 @@ func (sort *Sort) args(key string) []interface{} {
     return args
 }
 
+func (c cmdable) SortRO(ctx context.Context, key string, sort *Sort) *StringSliceCmd {
+    cmd := NewStringSliceCmd(ctx, sort.args("sort_ro", key)...)
+    _ = c(ctx, cmd)
+    return cmd
+}
+
 func (c cmdable) Sort(ctx context.Context, key string, sort *Sort) *StringSliceCmd {
-    cmd := NewStringSliceCmd(ctx, sort.args(key)...)
+    cmd := NewStringSliceCmd(ctx, sort.args("sort", key)...)
     _ = c(ctx, cmd)
     return cmd
 }
 
 func (c cmdable) SortStore(ctx context.Context, key, store string, sort *Sort) *IntCmd {
-    args := sort.args(key)
+    args := sort.args("sort", key)
     if store != "" {
         args = append(args, "store", store)
     }
@@ -752,7 +779,7 @@ func (c cmdable) SortStore(ctx context.Context, key, store string, sort *Sort) *
 }
 
 func (c cmdable) SortInterfaces(ctx context.Context, key string, sort *Sort) *SliceCmd {
-    cmd := NewSliceCmd(ctx, sort.args(key)...)
+    cmd := NewSliceCmd(ctx, sort.args("sort", key)...)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -900,7 +927,7 @@ func (c cmdable) MSetNX(ctx context.Context, values ...interface{}) *BoolCmd {
 }
 
 // Set Redis `SET key value [expiration]` command.
-// Use expiration for `SETEX`-like behavior.
+// Use expiration for `SETEx`-like behavior.
 //
 // Zero expiration means the key has no expiration time.
 // KeepTTL is a Redis KEEPTTL option to keep existing TTL, it requires your redis-server version >= 6.0,
@@ -976,8 +1003,8 @@ func (c cmdable) SetArgs(ctx context.Context, key string, value interface{}, a S
     return cmd
 }
 
-// SetEX Redis `SETEX key expiration value` command.
-func (c cmdable) SetEX(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd {
+// SetEx Redis `SETEx key expiration value` command.
+func (c cmdable) SetEx(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd {
     cmd := NewStatusCmd(ctx, "setex", key, formatSec(ctx, expiration), value)
     _ = c(ctx, cmd)
     return cmd
@@ -1044,6 +1071,16 @@ func (c cmdable) StrLen(ctx context.Context, key string) *IntCmd {
     return cmd
 }
 
+func (c cmdable) Copy(ctx context.Context, sourceKey, destKey string, db int, replace bool) *IntCmd {
+    args := []interface{}{"copy", sourceKey, destKey, "DB", db}
+    if replace {
+        args = append(args, "REPLACE")
+    }
+    cmd := NewIntCmd(ctx, args...)
+    _ = c(ctx, cmd)
+    return cmd
+}
+
 //------------------------------------------------------------------------------
 
 func (c cmdable) GetBit(ctx context.Context, key string, offset int64) *IntCmd {
@@ -1237,8 +1274,8 @@ func (c cmdable) HGet(ctx context.Context, key, field string) *StringCmd {
     return cmd
 }
 
-func (c cmdable) HGetAll(ctx context.Context, key string) *StringStringMapCmd {
-    cmd := NewStringStringMapCmd(ctx, "hgetall", key)
+func (c cmdable) HGetAll(ctx context.Context, key string) *MapStringStringCmd {
+    cmd := NewMapStringStringCmd(ctx, "hgetall", key)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -1330,16 +1367,15 @@ func (c cmdable) HVals(ctx context.Context, key string) *StringSliceCmd {
 }
 
 // HRandField redis-server version >= 6.2.0.
-func (c cmdable) HRandField(ctx context.Context, key string, count int, withValues bool) *StringSliceCmd {
-    args := make([]interface{}, 0, 4)
+func (c cmdable) HRandField(ctx context.Context, key string, count int) *StringSliceCmd {
+    cmd := NewStringSliceCmd(ctx, "hrandfield", key, count)
+    _ = c(ctx, cmd)
+    return cmd
+}
 
-    // Although count=0 is meaningless, redis accepts count=0.
-    args = append(args, "hrandfield", key, count)
-    if withValues {
-        args = append(args, "withvalues")
-    }
-
-    cmd := NewStringSliceCmd(ctx, args...)
+// HRandFieldWithValues redis-server version >= 6.2.0.
+func (c cmdable) HRandFieldWithValues(ctx context.Context, key string, count int) *KeyValueSliceCmd {
+    cmd := NewKeyValueSliceCmd(ctx, "hrandfield", key, count, "withvalues")
     _ = c(ctx, cmd)
     return cmd
 }
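The `withValues` flag split above means the WITHVALUES variant now returns typed pairs instead of a flat string slice. A sketch under the same `ctx`/`rdb` assumptions; the key name is hypothetical, and the `Key`/`Value` field names come from the `KeyValueSliceCmd` result type:

```go
// sampleHash shows the two HRANDFIELD variants side by side.
func sampleHash(ctx context.Context, rdb *redis.Client) error {
	// Field names only.
	fields, err := rdb.HRandField(ctx, "user:1", 2).Result()
	if err != nil {
		return err
	}
	fmt.Println(fields)

	// Field/value pairs via the dedicated WITHVALUES method.
	pairs, err := rdb.HRandFieldWithValues(ctx, "user:1", 2).Result()
	if err != nil {
		return err
	}
	for _, kv := range pairs {
		fmt.Println(kv.Key, "=", kv.Value)
	}
	return nil
}
```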
@@ -1619,6 +1655,22 @@ func (c cmdable) SInter(ctx context.Context, keys ...string) *StringSliceCmd {
     return cmd
 }
 
+func (c cmdable) SInterCard(ctx context.Context, limit int64, keys ...string) *IntCmd {
+    args := make([]interface{}, 4+len(keys))
+    args[0] = "sintercard"
+    numkeys := int64(0)
+    for i, key := range keys {
+        args[2+i] = key
+        numkeys++
+    }
+    args[1] = numkeys
+    args[2+numkeys] = "limit"
+    args[3+numkeys] = limit
+    cmd := NewIntCmd(ctx, args...)
+    _ = c(ctx, cmd)
+    return cmd
+}
+
 func (c cmdable) SInterStore(ctx context.Context, destination string, keys ...string) *IntCmd {
     args := make([]interface{}, 2+len(keys))
     args[0] = "sinterstore"
@@ -1742,11 +1794,7 @@ type XAddArgs struct {
     Stream     string
     NoMkStream bool
     MaxLen     int64 // MAXLEN N
-
-    // Deprecated: use MaxLen+Approx, remove in v9.
-    MaxLenApprox int64 // MAXLEN ~ N
-
-    MinID string
+    MinID      string
     // Approx causes MaxLen and MinID to use "~" matcher (instead of "=").
     Approx bool
     Limit  int64
@@ -1754,8 +1802,6 @@ type XAddArgs struct {
     Values interface{}
 }
 
-// XAdd a.Limit has a bug, please confirm it and use it.
-// issue: https://github.com/redis/redis/issues/9046
 func (c cmdable) XAdd(ctx context.Context, a *XAddArgs) *StringCmd {
     args := make([]interface{}, 0, 11)
     args = append(args, "xadd", a.Stream)
@@ -1769,9 +1815,6 @@ func (c cmdable) XAdd(ctx context.Context, a *XAddArgs) *StringCmd {
         } else {
             args = append(args, "maxlen", a.MaxLen)
         }
-    case a.MaxLenApprox > 0:
-        // TODO remove in v9.
-        args = append(args, "maxlen", "~", a.MaxLenApprox)
     case a.MinID != "":
         if a.Approx {
             args = append(args, "minid", "~", a.MinID)
@@ -1865,7 +1908,7 @@ func (c cmdable) XRead(ctx context.Context, a *XReadArgs) *XStreamSliceCmd {
     if a.Block >= 0 {
         cmd.setReadTimeout(a.Block)
     }
-    cmd.setFirstKeyPos(keyPos)
+    cmd.SetFirstKeyPos(keyPos)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -1949,7 +1992,7 @@ func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSlic
     if a.Block >= 0 {
         cmd.setReadTimeout(a.Block)
     }
-    cmd.setFirstKeyPos(keyPos)
+    cmd.SetFirstKeyPos(keyPos)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -2066,8 +2109,10 @@ func xClaimArgs(a *XClaimArgs) []interface{} {
 
 // xTrim If approx is true, add the "~" parameter, otherwise it is the default "=" (redis default).
 // example:
-// XTRIM key MAXLEN/MINID threshold LIMIT limit.
-// XTRIM key MAXLEN/MINID ~ threshold LIMIT limit.
+//
+//    XTRIM key MAXLEN/MINID threshold LIMIT limit.
+//    XTRIM key MAXLEN/MINID ~ threshold LIMIT limit.
+//
 // The redis-server version is lower than 6.2, please set limit to 0.
 func (c cmdable) xTrim(
     ctx context.Context, key, strategy string,
@@ -2087,38 +2132,20 @@ func (c cmdable) xTrim(
     return cmd
 }
 
-// Deprecated: use XTrimMaxLen, remove in v9.
-func (c cmdable) XTrim(ctx context.Context, key string, maxLen int64) *IntCmd {
-    return c.xTrim(ctx, key, "maxlen", false, maxLen, 0)
-}
-
-// Deprecated: use XTrimMaxLenApprox, remove in v9.
-func (c cmdable) XTrimApprox(ctx context.Context, key string, maxLen int64) *IntCmd {
-    return c.xTrim(ctx, key, "maxlen", true, maxLen, 0)
-}
-
 // XTrimMaxLen No `~` rules are used, `limit` cannot be used.
 // cmd: XTRIM key MAXLEN maxLen
 func (c cmdable) XTrimMaxLen(ctx context.Context, key string, maxLen int64) *IntCmd {
     return c.xTrim(ctx, key, "maxlen", false, maxLen, 0)
 }
 
-// XTrimMaxLenApprox LIMIT has a bug, please confirm it and use it.
-// issue: https://github.com/redis/redis/issues/9046
-// cmd: XTRIM key MAXLEN ~ maxLen LIMIT limit
 func (c cmdable) XTrimMaxLenApprox(ctx context.Context, key string, maxLen, limit int64) *IntCmd {
     return c.xTrim(ctx, key, "maxlen", true, maxLen, limit)
 }
 
-// XTrimMinID No `~` rules are used, `limit` cannot be used.
-// cmd: XTRIM key MINID minID
 func (c cmdable) XTrimMinID(ctx context.Context, key string, minID string) *IntCmd {
     return c.xTrim(ctx, key, "minid", false, minID, 0)
 }
 
-// XTrimMinIDApprox LIMIT has a bug, please confirm it and use it.
-// issue: https://github.com/redis/redis/issues/9046
-// cmd: XTRIM key MINID ~ minID LIMIT limit
 func (c cmdable) XTrimMinIDApprox(ctx context.Context, key string, minID string, limit int64) *IntCmd {
     return c.xTrim(ctx, key, "minid", true, minID, limit)
 }
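With `MaxLenApprox` removed above, approximate trimming is expressed through the single `Approx` flag. A sketch under the same `ctx`/`rdb` assumptions; the stream name and field values are hypothetical:

```go
// logClick appends an entry and lets the server trim the stream lazily.
func logClick(ctx context.Context, rdb *redis.Client) error {
	// Approx turns MAXLEN into "MAXLEN ~", the v9 replacement for MaxLenApprox.
	id, err := rdb.XAdd(ctx, &redis.XAddArgs{
		Stream: "events",
		MaxLen: 10000,
		Approx: true,
		Values: map[string]interface{}{"type": "click"},
	}).Result()
	if err != nil {
		return err
	}
	fmt.Println("added", id)

	// Explicit approximate trim with a LIMIT (Redis >= 6.2).
	_, err = rdb.XTrimMaxLenApprox(ctx, "events", 10000, 100).Result()
	return err
}
```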
@@ -2283,116 +2310,26 @@ func (c cmdable) ZAddArgsIncr(ctx context.Context, key string, args ZAddArgs) *FloatCmd {
     return cmd
 }
 
-// TODO: Compatible with v8 api, will be removed in v9.
-func (c cmdable) zAdd(ctx context.Context, key string, args ZAddArgs, members ...*Z) *IntCmd {
-    args.Members = make([]Z, len(members))
-    for i, m := range members {
-        args.Members[i] = *m
-    }
-    cmd := NewIntCmd(ctx, c.zAddArgs(key, args, false)...)
-    _ = c(ctx, cmd)
-    return cmd
-}
-
 // ZAdd Redis `ZADD key score member [score member ...]` command.
-func (c cmdable) ZAdd(ctx context.Context, key string, members ...*Z) *IntCmd {
-    return c.zAdd(ctx, key, ZAddArgs{}, members...)
+func (c cmdable) ZAdd(ctx context.Context, key string, members ...Z) *IntCmd {
+    return c.ZAddArgs(ctx, key, ZAddArgs{
+        Members: members,
+    })
 }
 
 // ZAddNX Redis `ZADD key NX score member [score member ...]` command.
-func (c cmdable) ZAddNX(ctx context.Context, key string, members ...*Z) *IntCmd {
-    return c.zAdd(ctx, key, ZAddArgs{
-        NX: true,
-    }, members...)
+func (c cmdable) ZAddNX(ctx context.Context, key string, members ...Z) *IntCmd {
+    return c.ZAddArgs(ctx, key, ZAddArgs{
+        NX:      true,
+        Members: members,
+    })
 }
 
 // ZAddXX Redis `ZADD key XX score member [score member ...]` command.
-func (c cmdable) ZAddXX(ctx context.Context, key string, members ...*Z) *IntCmd {
-    return c.zAdd(ctx, key, ZAddArgs{
-        XX: true,
-    }, members...)
-}
-
-// ZAddCh Redis `ZADD key CH score member [score member ...]` command.
-// Deprecated: Use
-//        client.ZAddArgs(ctx, ZAddArgs{
-//            Ch: true,
-//            Members: []Z,
-//        })
-// remove in v9.
-func (c cmdable) ZAddCh(ctx context.Context, key string, members ...*Z) *IntCmd {
-    return c.zAdd(ctx, key, ZAddArgs{
-        Ch: true,
-    }, members...)
-}
-
-// ZAddNXCh Redis `ZADD key NX CH score member [score member ...]` command.
-// Deprecated: Use
-//        client.ZAddArgs(ctx, ZAddArgs{
-//            NX: true,
-//            Ch: true,
-//            Members: []Z,
-//        })
-// remove in v9.
-func (c cmdable) ZAddNXCh(ctx context.Context, key string, members ...*Z) *IntCmd {
-    return c.zAdd(ctx, key, ZAddArgs{
-        NX: true,
-        Ch: true,
-    }, members...)
-}
-
-// ZAddXXCh Redis `ZADD key XX CH score member [score member ...]` command.
-// Deprecated: Use
-//        client.ZAddArgs(ctx, ZAddArgs{
-//            XX: true,
-//            Ch: true,
-//            Members: []Z,
-//        })
-// remove in v9.
-func (c cmdable) ZAddXXCh(ctx context.Context, key string, members ...*Z) *IntCmd {
-    return c.zAdd(ctx, key, ZAddArgs{
-        XX: true,
-        Ch: true,
-    }, members...)
-}
-
-// ZIncr Redis `ZADD key INCR score member` command.
-// Deprecated: Use
-//        client.ZAddArgsIncr(ctx, ZAddArgs{
-//            Members: []Z,
-//        })
-// remove in v9.
-func (c cmdable) ZIncr(ctx context.Context, key string, member *Z) *FloatCmd {
-    return c.ZAddArgsIncr(ctx, key, ZAddArgs{
-        Members: []Z{*member},
-    })
-}
-
-// ZIncrNX Redis `ZADD key NX INCR score member` command.
-// Deprecated: Use
-//        client.ZAddArgsIncr(ctx, ZAddArgs{
-//            NX: true,
-//            Members: []Z,
-//        })
-// remove in v9.
-func (c cmdable) ZIncrNX(ctx context.Context, key string, member *Z) *FloatCmd {
-    return c.ZAddArgsIncr(ctx, key, ZAddArgs{
-        NX:      true,
-        Members: []Z{*member},
-    })
-}
-
-// ZIncrXX Redis `ZADD key XX INCR score member` command.
-// Deprecated: Use
-//        client.ZAddArgsIncr(ctx, ZAddArgs{
-//            XX: true,
-//            Members: []Z,
-//        })
-// remove in v9.
-func (c cmdable) ZIncrXX(ctx context.Context, key string, member *Z) *FloatCmd {
-    return c.ZAddArgsIncr(ctx, key, ZAddArgs{
+func (c cmdable) ZAddXX(ctx context.Context, key string, members ...Z) *IntCmd {
+    return c.ZAddArgs(ctx, key, ZAddArgs{
         XX:      true,
-        Members: []Z{*member},
+        Members: members,
     })
 }
 
@@ -2425,7 +2362,7 @@ func (c cmdable) ZInterStore(ctx context.Context, destination string, store *ZStore) *IntCmd {
     args = append(args, "zinterstore", destination, len(store.Keys))
     args = store.appendArgs(args)
     cmd := NewIntCmd(ctx, args...)
-    cmd.setFirstKeyPos(3)
+    cmd.SetFirstKeyPos(3)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -2435,7 +2372,7 @@ func (c cmdable) ZInter(ctx context.Context, store *ZStore) *StringSliceCmd {
     args = append(args, "zinter", len(store.Keys))
     args = store.appendArgs(args)
     cmd := NewStringSliceCmd(ctx, args...)
-    cmd.setFirstKeyPos(2)
+    cmd.SetFirstKeyPos(2)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -2446,7 +2383,23 @@ func (c cmdable) ZInterWithScores(ctx context.Context, store *ZStore) *ZSliceCmd {
     args = store.appendArgs(args)
     args = append(args, "withscores")
     cmd := NewZSliceCmd(ctx, args...)
-    cmd.setFirstKeyPos(2)
+    cmd.SetFirstKeyPos(2)
+    _ = c(ctx, cmd)
+    return cmd
+}
+
+func (c cmdable) ZInterCard(ctx context.Context, limit int64, keys ...string) *IntCmd {
+    args := make([]interface{}, 4+len(keys))
+    args[0] = "zintercard"
+    numkeys := int64(0)
+    for i, key := range keys {
+        args[2+i] = key
+        numkeys++
+    }
+    args[1] = numkeys
+    args[2+numkeys] = "limit"
+    args[3+numkeys] = limit
+    cmd := NewIntCmd(ctx, args...)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -2505,11 +2458,13 @@ func (c cmdable) ZPopMin(ctx context.Context, key string, count ...int64) *ZSliceCmd {
 
 // ZRangeArgs is all the options of the ZRange command.
 // In version> 6.2.0, you can replace the(cmd):
-// ZREVRANGE,
-// ZRANGEBYSCORE,
-// ZREVRANGEBYSCORE,
-// ZRANGEBYLEX,
-// ZREVRANGEBYLEX.
+//
+//    ZREVRANGE,
+//    ZRANGEBYSCORE,
+//    ZREVRANGEBYSCORE,
+//    ZRANGEBYLEX,
+//    ZREVRANGEBYLEX.
+//
 // Please pay attention to your redis-server version.
 //
 // Rev, ByScore, ByLex and Offset+Count options require redis-server 6.2.0 and higher.
@@ -2773,7 +2728,7 @@ func (c cmdable) ZUnion(ctx context.Context, store ZStore) *StringSliceCmd {
     args = append(args, "zunion", len(store.Keys))
     args = store.appendArgs(args)
     cmd := NewStringSliceCmd(ctx, args...)
-    cmd.setFirstKeyPos(2)
+    cmd.SetFirstKeyPos(2)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -2784,7 +2739,7 @@ func (c cmdable) ZUnionWithScores(ctx context.Context, store ZStore) *ZSliceCmd {
     args = store.appendArgs(args)
     args = append(args, "withscores")
     cmd := NewZSliceCmd(ctx, args...)
-    cmd.setFirstKeyPos(2)
+    cmd.SetFirstKeyPos(2)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -2794,22 +2749,21 @@ func (c cmdable) ZUnionStore(ctx context.Context, dest string, store *ZStore) *IntCmd {
     args = append(args, "zunionstore", dest, len(store.Keys))
     args = store.appendArgs(args)
     cmd := NewIntCmd(ctx, args...)
-    cmd.setFirstKeyPos(3)
+    cmd.SetFirstKeyPos(3)
     _ = c(ctx, cmd)
     return cmd
 }
 
 // ZRandMember redis-server version >= 6.2.0.
-func (c cmdable) ZRandMember(ctx context.Context, key string, count int, withScores bool) *StringSliceCmd {
-    args := make([]interface{}, 0, 4)
+func (c cmdable) ZRandMember(ctx context.Context, key string, count int) *StringSliceCmd {
+    cmd := NewStringSliceCmd(ctx, "zrandmember", key, count)
+    _ = c(ctx, cmd)
+    return cmd
+}
 
-    // Although count=0 is meaningless, redis accepts count=0.
-    args = append(args, "zrandmember", key, count)
-    if withScores {
-        args = append(args, "withscores")
-    }
-
-    cmd := NewStringSliceCmd(ctx, args...)
+// ZRandMemberWithScores redis-server version >= 6.2.0.
+func (c cmdable) ZRandMemberWithScores(ctx context.Context, key string, count int) *ZSliceCmd {
+    cmd := NewZSliceCmd(ctx, "zrandmember", key, count, "withscores")
     _ = c(ctx, cmd)
     return cmd
 }
@@ -2824,7 +2778,7 @@ func (c cmdable) ZDiff(ctx context.Context, keys ...string) *StringSliceCmd {
     }
 
     cmd := NewStringSliceCmd(ctx, args...)
-    cmd.setFirstKeyPos(2)
+    cmd.SetFirstKeyPos(2)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -2840,7 +2794,7 @@ func (c cmdable) ZDiffWithScores(ctx context.Context, keys ...string) *ZSliceCmd {
     args[len(keys)+2] = "withscores"
 
     cmd := NewZSliceCmd(ctx, args...)
-    cmd.setFirstKeyPos(2)
+    cmd.SetFirstKeyPos(2)
     _ = c(ctx, cmd)
     return cmd
 }
@@ -2914,7 +2868,7 @@ func (c cmdable) ClientKill(ctx context.Context, ipPort string) *StatusCmd {
 
 // ClientKillByFilter is new style syntax, while the ClientKill is old
 //
//   CLIENT KILL