diff --git a/bench_test.go b/bench_test.go index 021dba96..cba9ff5b 100644 --- a/bench_test.go +++ b/bench_test.go @@ -188,7 +188,7 @@ func BenchmarkPipeline(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + _, err := client.Pipelined(func(pipe redis.Pipelineable) error { pipe.Set("key", "hello", 0) pipe.Expire("key", time.Second) return nil diff --git a/cluster.go b/cluster.go index 43cbf643..5bb1b5ad 100644 --- a/cluster.go +++ b/cluster.go @@ -674,7 +674,7 @@ func (c *ClusterClient) reaper(idleCheckFrequency time.Duration) { } } -func (c *ClusterClient) Pipeline() *Pipeline { +func (c *ClusterClient) Pipeline() Pipelineable { pipe := Pipeline{ exec: c.pipelineExec, } @@ -683,7 +683,7 @@ func (c *ClusterClient) Pipeline() *Pipeline { return &pipe } -func (c *ClusterClient) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { +func (c *ClusterClient) Pipelined(fn func(Pipelineable) error) ([]Cmder, error) { return c.Pipeline().pipelined(fn) } @@ -797,7 +797,7 @@ func (c *ClusterClient) checkMovedErr(cmd Cmder, failedCmds map[*clusterNode][]C } // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. -func (c *ClusterClient) TxPipeline() *Pipeline { +func (c *ClusterClient) TxPipeline() Pipelineable { pipe := Pipeline{ exec: c.txPipelineExec, } @@ -806,7 +806,7 @@ func (c *ClusterClient) TxPipeline() *Pipeline { return &pipe } -func (c *ClusterClient) TxPipelined(fn func(*Pipeline) error) ([]Cmder, error) { +func (c *ClusterClient) TxPipelined(fn func(Pipelineable) error) ([]Cmder, error) { return c.TxPipeline().pipelined(fn) } diff --git a/cluster_test.go b/cluster_test.go index d9db8975..c68ff92e 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -347,7 +347,7 @@ var _ = Describe("ClusterClient", func() { return err } - _, err = tx.Pipelined(func(pipe *redis.Pipeline) error { + _, err = tx.Pipelined(func(pipe redis.Pipelineable) error { pipe.Set(key, strconv.FormatInt(n+1, 10), 0) return nil }) @@ -449,7 +449,7 @@ var _ = Describe("ClusterClient", func() { Describe("Pipeline", func() { BeforeEach(func() { - pipe = client.Pipeline() + pipe = client.Pipeline().(*redis.Pipeline) }) AfterEach(func() { @@ -461,7 +461,7 @@ var _ = Describe("ClusterClient", func() { Describe("TxPipeline", func() { BeforeEach(func() { - pipe = client.TxPipeline() + pipe = client.TxPipeline().(*redis.Pipeline) }) AfterEach(func() { @@ -544,7 +544,7 @@ var _ = Describe("ClusterClient without nodes", func() { }) It("pipeline returns an error", func() { - _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + _, err := client.Pipelined(func(pipe redis.Pipelineable) error { pipe.Ping() return nil }) @@ -571,7 +571,7 @@ var _ = Describe("ClusterClient without valid nodes", func() { }) It("pipeline returns an error", func() { - _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + _, err := client.Pipelined(func(pipe redis.Pipelineable) error { pipe.Ping() return nil }) @@ -594,7 +594,7 @@ var _ = Describe("ClusterClient timeout", func() { }) It("Pipeline timeouts", func() { - _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + _, err := client.Pipelined(func(pipe redis.Pipelineable) error { pipe.Ping() return nil }) @@ -612,7 +612,7 @@ var _ = Describe("ClusterClient timeout", func() { It("Tx Pipeline timeouts", func() { err := client.Watch(func(tx *redis.Tx) error { - _, err := tx.Pipelined(func(pipe *redis.Pipeline) error { + _, err := tx.Pipelined(func(pipe redis.Pipelineable) error { pipe.Ping() return nil }) diff --git a/commands.go b/commands.go index fb0e9105..d1efd214 100644 --- a/commands.go +++ b/commands.go @@ -39,8 +39,8 @@ func formatSec(dur time.Duration) int64 { } type Cmdable interface { - Pipeline() *Pipeline - Pipelined(fn func(*Pipeline) error) ([]Cmder, error) + Pipeline() Pipelineable + Pipelined(fn func(Pipelineable) error) ([]Cmder, error) Echo(message interface{}) *StringCmd Ping() *StatusCmd @@ -237,6 +237,15 @@ type Cmdable interface { Command() *CommandsInfoCmd } +type StatefulCmdable interface { + Auth(password string) *StatusCmd + Select(index int) *StatusCmd + ClientSetName(name string) *BoolCmd + ClientGetName() *StringCmd + ReadOnly() *StatusCmd + ReadWrite() *StatusCmd +} + var _ Cmdable = (*Client)(nil) var _ Cmdable = (*Tx)(nil) var _ Cmdable = (*Ring)(nil) diff --git a/commands_test.go b/commands_test.go index a1231334..b9b601ab 100644 --- a/commands_test.go +++ b/commands_test.go @@ -27,7 +27,7 @@ var _ = Describe("Commands", func() { Describe("server", func() { It("should Auth", func() { - _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + _, err := client.Pipelined(func(pipe redis.Pipelineable) error { pipe.Auth("password") return nil }) diff --git a/example_test.go b/example_test.go index 6c3d6bdb..098f0aec 100644 --- a/example_test.go +++ b/example_test.go @@ -159,7 +159,7 @@ func ExampleClient_Scan() { func ExampleClient_Pipelined() { var incr *redis.IntCmd - _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + _, err := client.Pipelined(func(pipe redis.Pipelineable) error { incr = pipe.Incr("pipelined_counter") pipe.Expire("pipelined_counter", time.Hour) return nil @@ -187,7 +187,7 @@ func ExampleClient_Pipeline() { func ExampleClient_TxPipelined() { var incr *redis.IntCmd - _, err := client.TxPipelined(func(pipe *redis.Pipeline) error { + _, err := client.TxPipelined(func(pipe redis.Pipelineable) error { incr = pipe.Incr("tx_pipelined_counter") pipe.Expire("tx_pipelined_counter", time.Hour) return nil @@ -226,7 +226,7 @@ func ExampleClient_Watch() { return err } - _, err = tx.Pipelined(func(pipe *redis.Pipeline) error { + _, err = tx.Pipelined(func(pipe redis.Pipelineable) error { pipe.Set(key, strconv.FormatInt(n+1, 10), 0) return nil }) diff --git a/pipeline.go b/pipeline.go index 9e3ba0e6..cb441dc3 100644 --- a/pipeline.go +++ b/pipeline.go @@ -9,6 +9,17 @@ import ( type pipelineExecer func([]Cmder) error +type Pipelineable interface { + Cmdable + StatefulCmdable + Process(cmd Cmder) error + Close() error + Discard() error + discard() error + Exec() ([]Cmder, error) + pipelined(fn func(Pipelineable) error) ([]Cmder, error) +} + // Pipeline implements pipelining as described in // http://redis.io/topics/pipelining. It's safe for concurrent use // by multiple goroutines. @@ -78,7 +89,7 @@ func (c *Pipeline) Exec() ([]Cmder, error) { return cmds, c.exec(cmds) } -func (c *Pipeline) pipelined(fn func(*Pipeline) error) ([]Cmder, error) { +func (c *Pipeline) pipelined(fn func(Pipelineable) error) ([]Cmder, error) { if err := fn(c); err != nil { return nil, err } @@ -86,3 +97,11 @@ func (c *Pipeline) pipelined(fn func(*Pipeline) error) ([]Cmder, error) { _ = c.Close() return cmds, err } + +func (c *Pipeline) Pipelined(fn func(Pipelineable) error) ([]Cmder, error) { + return c.pipelined(fn) +} + +func (c *Pipeline) Pipeline() Pipelineable { + return c +} diff --git a/pipeline_test.go b/pipeline_test.go index 563bb66b..7b79dfe8 100644 --- a/pipeline_test.go +++ b/pipeline_test.go @@ -22,7 +22,7 @@ var _ = Describe("pipelining", func() { It("supports block style", func() { var get *redis.StringCmd - cmds, err := client.Pipelined(func(pipe *redis.Pipeline) error { + cmds, err := client.Pipelined(func(pipe redis.Pipelineable) error { get = pipe.Get("foo") return nil }) @@ -63,7 +63,7 @@ var _ = Describe("pipelining", func() { Describe("Pipeline", func() { BeforeEach(func() { - pipe = client.Pipeline() + pipe = client.Pipeline().(*redis.Pipeline) }) assertPipeline() diff --git a/pool_test.go b/pool_test.go index 5363c400..274b90b0 100644 --- a/pool_test.go +++ b/pool_test.go @@ -39,7 +39,7 @@ var _ = Describe("pool", func() { var ping *redis.StatusCmd err := client.Watch(func(tx *redis.Tx) error { - cmds, err := tx.Pipelined(func(pipe *redis.Pipeline) error { + cmds, err := tx.Pipelined(func(pipe redis.Pipelineable) error { ping = pipe.Ping() return nil }) diff --git a/race_test.go b/race_test.go index 0ec6a140..2198a34b 100644 --- a/race_test.go +++ b/race_test.go @@ -193,7 +193,7 @@ var _ = Describe("races", func() { num, err := strconv.ParseInt(val, 10, 64) Expect(err).NotTo(HaveOccurred()) - cmds, err := tx.Pipelined(func(pipe *redis.Pipeline) error { + cmds, err := tx.Pipelined(func(pipe redis.Pipelineable) error { pipe.Set("key", strconv.FormatInt(num+1, 10), 0) return nil }) diff --git a/redis.go b/redis.go index b71b9fc6..a84a98f3 100644 --- a/redis.go +++ b/redis.go @@ -61,7 +61,7 @@ func (c *baseClient) initConn(cn *pool.Conn) error { // Temp client for Auth and Select. client := newClient(c.opt, pool.NewSingleConnPool(cn)) - _, err := client.Pipelined(func(pipe *Pipeline) error { + _, err := client.Pipelined(func(pipe Pipelineable) error { if c.opt.Password != "" { pipe.Auth(c.opt.Password) } @@ -324,11 +324,11 @@ func (c *Client) PoolStats() *PoolStats { } } -func (c *Client) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { +func (c *Client) Pipelined(fn func(Pipelineable) error) ([]Cmder, error) { return c.Pipeline().pipelined(fn) } -func (c *Client) Pipeline() *Pipeline { +func (c *Client) Pipeline() Pipelineable { pipe := Pipeline{ exec: c.pipelineExecer(c.pipelineProcessCmds), } @@ -337,7 +337,7 @@ func (c *Client) Pipeline() *Pipeline { return &pipe } -func (c *Client) TxPipelined(fn func(*Pipeline) error) ([]Cmder, error) { +func (c *Client) TxPipelined(fn func(Pipelineable) error) ([]Cmder, error) { return c.TxPipeline().pipelined(fn) } diff --git a/redis_test.go b/redis_test.go index 2847963f..f3d2e7ce 100644 --- a/redis_test.go +++ b/redis_test.go @@ -68,7 +68,7 @@ var _ = Describe("Client", func() { It("should close Tx without closing the client", func() { err := client.Watch(func(tx *redis.Tx) error { - _, err := tx.Pipelined(func(pipe *redis.Pipeline) error { + _, err := tx.Pipelined(func(pipe redis.Pipelineable) error { pipe.Ping() return nil }) @@ -232,7 +232,7 @@ var _ = Describe("Client timeout", func() { }) It("Pipeline timeouts", func() { - _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + _, err := client.Pipelined(func(pipe redis.Pipelineable) error { pipe.Ping() return nil }) @@ -263,7 +263,7 @@ var _ = Describe("Client timeout", func() { It("Tx Pipeline timeouts", func() { err := client.Watch(func(tx *redis.Tx) error { - _, err := tx.Pipelined(func(pipe *redis.Pipeline) error { + _, err := tx.Pipelined(func(pipe redis.Pipelineable) error { pipe.Ping() return nil }) diff --git a/ring.go b/ring.go index d13a33b7..df6649ef 100644 --- a/ring.go +++ b/ring.go @@ -381,7 +381,7 @@ func (c *Ring) Close() error { return firstErr } -func (c *Ring) Pipeline() *Pipeline { +func (c *Ring) Pipeline() Pipelineable { pipe := Pipeline{ exec: c.pipelineExec, } @@ -390,7 +390,7 @@ func (c *Ring) Pipeline() *Pipeline { return &pipe } -func (c *Ring) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { +func (c *Ring) Pipelined(fn func(Pipelineable) error) ([]Cmder, error) { return c.Pipeline().pipelined(fn) } diff --git a/ring_test.go b/ring_test.go index 21adab2b..deb6e35d 100644 --- a/ring_test.go +++ b/ring_test.go @@ -137,7 +137,7 @@ var _ = Describe("Redis Ring", func() { keys = append(keys, string(key)) } - _, err := ring.Pipelined(func(pipe *redis.Pipeline) error { + _, err := ring.Pipelined(func(pipe redis.Pipelineable) error { for _, key := range keys { pipe.Set(key, "value", 0).Err() } @@ -153,7 +153,7 @@ var _ = Describe("Redis Ring", func() { }) It("supports hash tags", func() { - _, err := ring.Pipelined(func(pipe *redis.Pipeline) error { + _, err := ring.Pipelined(func(pipe redis.Pipelineable) error { for i := 0; i < 100; i++ { pipe.Set(fmt.Sprintf("key%d{tag}", i), "value", 0).Err() } @@ -184,7 +184,7 @@ var _ = Describe("empty Redis Ring", func() { }) It("pipeline returns an error", func() { - _, err := ring.Pipelined(func(pipe *redis.Pipeline) error { + _, err := ring.Pipelined(func(pipe redis.Pipelineable) error { pipe.Ping() return nil }) diff --git a/tx.go b/tx.go index 762bddc9..a0b642c0 100644 --- a/tx.go +++ b/tx.go @@ -76,7 +76,7 @@ func (c *Tx) Unwatch(keys ...string) *StatusCmd { return cmd } -func (c *Tx) Pipeline() *Pipeline { +func (c *Tx) Pipeline() Pipelineable { pipe := Pipeline{ exec: c.pipelineExecer(c.txPipelineProcessCmds), } @@ -94,6 +94,6 @@ func (c *Tx) Pipeline() *Pipeline { // Exec always returns list of commands. If transaction fails // TxFailedErr is returned. Otherwise Exec returns error of the first // failed command or nil. -func (c *Tx) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { +func (c *Tx) Pipelined(fn func(Pipelineable) error) ([]Cmder, error) { return c.Pipeline().pipelined(fn) } diff --git a/tx_test.go b/tx_test.go index 583ea057..127ceeea 100644 --- a/tx_test.go +++ b/tx_test.go @@ -33,7 +33,7 @@ var _ = Describe("Tx", func() { return err } - _, err = tx.Pipelined(func(pipe *redis.Pipeline) error { + _, err = tx.Pipelined(func(pipe redis.Pipelineable) error { pipe.Set(key, strconv.FormatInt(n+1, 10), 0) return nil }) @@ -65,7 +65,7 @@ var _ = Describe("Tx", func() { It("should discard", func() { err := client.Watch(func(tx *redis.Tx) error { - cmds, err := tx.Pipelined(func(pipe *redis.Pipeline) error { + cmds, err := tx.Pipelined(func(pipe redis.Pipelineable) error { pipe.Set("key1", "hello1", 0) pipe.Discard() pipe.Set("key2", "hello2", 0) @@ -88,7 +88,7 @@ var _ = Describe("Tx", func() { It("returns an error when there are no commands", func() { err := client.Watch(func(tx *redis.Tx) error { - _, err := tx.Pipelined(func(*redis.Pipeline) error { return nil }) + _, err := tx.Pipelined(func(redis.Pipelineable) error { return nil }) return err }) Expect(err).To(MatchError("redis: pipeline is empty")) @@ -102,7 +102,7 @@ var _ = Describe("Tx", func() { const N = 20000 err := client.Watch(func(tx *redis.Tx) error { - cmds, err := tx.Pipelined(func(pipe *redis.Pipeline) error { + cmds, err := tx.Pipelined(func(pipe redis.Pipelineable) error { for i := 0; i < N; i++ { pipe.Incr("key") } @@ -133,7 +133,7 @@ var _ = Describe("Tx", func() { do := func() error { err := client.Watch(func(tx *redis.Tx) error { - _, err := tx.Pipelined(func(pipe *redis.Pipeline) error { + _, err := tx.Pipelined(func(pipe redis.Pipelineable) error { pipe.Ping() return nil })