From 092698ecd376299b99184d4636abffb78e55d482 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Mon, 2 May 2016 15:54:15 +0300 Subject: [PATCH] Tweak transaction API. --- cluster.go | 8 +-- cluster_test.go | 23 ++++---- example_test.go | 23 ++++---- pool_test.go | 17 +++--- race_test.go | 30 +++++----- redis_test.go | 24 +++----- tx.go | 27 +++++---- tx_test.go | 147 +++++++++++++++++++++++------------------------- 8 files changed, 138 insertions(+), 161 deletions(-) diff --git a/cluster.go b/cluster.go index 2fb0491..5d56a4f 100644 --- a/cluster.go +++ b/cluster.go @@ -59,15 +59,13 @@ func (c *ClusterClient) getClients() map[string]*Client { return clients } -// Watch creates new transaction and marks the keys to be watched -// for conditional execution of a transaction. -func (c *ClusterClient) Watch(keys ...string) (*Tx, error) { +func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { addr := c.slotMasterAddr(hashtag.Slot(keys[0])) client, err := c.getClient(addr) if err != nil { - return nil, err + return err } - return client.Watch(keys...) + return client.Watch(fn, keys...) } // PoolStats returns accumulated connection pool stats. diff --git a/cluster_test.go b/cluster_test.go index 3a4ddb5..69ffa7a 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -383,21 +383,18 @@ var _ = Describe("Cluster", func() { // Transactionally increments key using GET and SET commands. incr = func(key string) error { - tx, err := client.Watch(key) - if err != nil { - return err - } - defer tx.Close() + err := client.Watch(func(tx *redis.Tx) error { + n, err := tx.Get(key).Int64() + if err != nil && err != redis.Nil { + return err + } - n, err := tx.Get(key).Int64() - if err != nil && err != redis.Nil { + _, err = tx.MultiExec(func() error { + tx.Set(key, strconv.FormatInt(n+1, 10), 0) + return nil + }) return err - } - - _, err = tx.Exec(func() error { - tx.Set(key, strconv.FormatInt(n+1, 10), 0) - return nil - }) + }, key) if err == redis.TxFailedErr { return incr(key) } diff --git a/example_test.go b/example_test.go index 4ab22f1..71d06c4 100644 --- a/example_test.go +++ b/example_test.go @@ -184,21 +184,18 @@ func ExampleClient_Watch() { // Transactionally increments key using GET and SET commands. incr = func(key string) error { - tx, err := client.Watch(key) - if err != nil { - return err - } - defer tx.Close() + err := client.Watch(func(tx *redis.Tx) error { + n, err := tx.Get(key).Int64() + if err != nil && err != redis.Nil { + return err + } - n, err := tx.Get(key).Int64() - if err != nil && err != redis.Nil { + _, err = tx.MultiExec(func() error { + tx.Set(key, strconv.FormatInt(n+1, 10), 0) + return nil + }) return err - } - - _, err = tx.Exec(func() error { - tx.Set(key, strconv.FormatInt(n+1, 10), 0) - return nil - }) + }, key) if err == redis.TxFailedErr { return incr(key) } diff --git a/pool_test.go b/pool_test.go index fcaf243..31bf968 100644 --- a/pool_test.go +++ b/pool_test.go @@ -37,18 +37,19 @@ var _ = Describe("pool", func() { perform(1000, func(id int) { var ping *redis.StatusCmd - tx, err := client.Watch() - Expect(err).NotTo(HaveOccurred()) - - cmds, err := tx.Exec(func() error { - ping = tx.Ping() - return nil + err := client.Watch(func(tx *redis.Tx) error { + cmds, err := tx.MultiExec(func() error { + ping = tx.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(1)) + return err }) Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(1)) + Expect(ping.Err()).NotTo(HaveOccurred()) Expect(ping.Val()).To(Equal("PONG")) - Expect(tx.Close()).NotTo(HaveOccurred()) }) pool := client.Pool() diff --git a/race_test.go b/race_test.go index c2d3dff..f654f87 100644 --- a/race_test.go +++ b/race_test.go @@ -215,30 +215,26 @@ var _ = Describe("races", func() { perform(C, func(id int) { for i := 0; i < N; i++ { - tx, err := client.Watch("key") - Expect(err).NotTo(HaveOccurred()) + err := client.Watch(func(tx *redis.Tx) error { + val, err := tx.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).NotTo(Equal(redis.Nil)) - val, err := tx.Get("key").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(val).NotTo(Equal(redis.Nil)) + num, err := strconv.ParseInt(val, 10, 64) + Expect(err).NotTo(HaveOccurred()) - num, err := strconv.ParseInt(val, 10, 64) - Expect(err).NotTo(HaveOccurred()) - - cmds, err := tx.Exec(func() error { - tx.Set("key", strconv.FormatInt(num+1, 10), 0) - return nil - }) + cmds, err := tx.MultiExec(func() error { + tx.Set("key", strconv.FormatInt(num+1, 10), 0) + return nil + }) + Expect(cmds).To(HaveLen(1)) + return err + }, "key") if err == redis.TxFailedErr { i-- continue } Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].Err()).NotTo(HaveOccurred()) - - err = tx.Close() - Expect(err).NotTo(HaveOccurred()) } }) diff --git a/redis_test.go b/redis_test.go index a11b103..bbb00e0 100644 --- a/redis_test.go +++ b/redis_test.go @@ -65,16 +65,15 @@ var _ = Describe("Client", func() { Expect(client.Ping().Err()).NotTo(HaveOccurred()) }) - It("should close multi without closing the client", func() { - tx, err := client.Watch() - Expect(err).NotTo(HaveOccurred()) - Expect(tx.Close()).NotTo(HaveOccurred()) - - _, err = tx.Exec(func() error { - tx.Ping() - return nil + It("should close Tx without closing the client", func() { + err := client.Watch(func(tx *redis.Tx) error { + _, err := tx.MultiExec(func() error { + tx.Ping() + return nil + }) + return err }) - Expect(err).To(MatchError("redis: client is closed")) + Expect(err).NotTo(HaveOccurred()) Expect(client.Ping().Err()).NotTo(HaveOccurred()) }) @@ -96,13 +95,6 @@ var _ = Describe("Client", func() { Expect(pubsub.Close()).NotTo(HaveOccurred()) }) - It("should close multi when client is closed", func() { - tx, err := client.Watch() - Expect(err).NotTo(HaveOccurred()) - Expect(client.Close()).NotTo(HaveOccurred()) - Expect(tx.Close()).NotTo(HaveOccurred()) - }) - It("should close pipeline when client is closed", func() { pipeline := client.Pipeline() Expect(client.Close()).NotTo(HaveOccurred()) diff --git a/tx.go b/tx.go index 5274af4..34e828b 100644 --- a/tx.go +++ b/tx.go @@ -34,17 +34,19 @@ func (c *Client) newTx() *Tx { return tx } -// Watch creates new transaction and marks the keys to be watched -// for conditional execution of a transaction. -func (c *Client) Watch(keys ...string) (*Tx, error) { +func (c *Client) Watch(fn func(*Tx) error, keys ...string) error { tx := c.newTx() if len(keys) > 0 { if err := tx.Watch(keys...).Err(); err != nil { - tx.Close() - return nil, err + tx.close() + return err } } - return tx, nil + retErr := fn(tx) + if err := tx.close(); err != nil && retErr == nil { + retErr = err + } + return retErr } func (tx *Tx) process(cmd Cmder) { @@ -55,8 +57,11 @@ func (tx *Tx) process(cmd Cmder) { } } -// Close closes the transaction, releasing any open resources. -func (tx *Tx) Close() error { +// close closes the transaction, releasing any open resources. +func (tx *Tx) close() error { + if tx.closed { + return nil + } tx.closed = true if err := tx.Unwatch().Err(); err != nil { internal.Logf("Unwatch failed: %s", err) @@ -98,7 +103,7 @@ func (tx *Tx) Discard() error { return nil } -// Exec executes all previously queued commands in a transaction +// MultiExec executes all previously queued commands in a transaction // and restores the connection state to normal. // // When using WATCH, EXEC will execute commands only if the watched keys @@ -107,13 +112,13 @@ func (tx *Tx) Discard() error { // Exec always returns list of commands. If transaction fails // TxFailedErr is returned. Otherwise Exec returns error of the first // failed command or nil. -func (tx *Tx) Exec(f func() error) ([]Cmder, error) { +func (tx *Tx) MultiExec(fn func() error) ([]Cmder, error) { if tx.closed { return nil, pool.ErrClosed } tx.cmds = []Cmder{NewStatusCmd("MULTI")} - if err := f(); err != nil { + if err := fn(); err != nil { return nil, err } tx.cmds = append(tx.cmds, NewSliceCmd("EXEC")) diff --git a/tx_test.go b/tx_test.go index 5df4b6a..7ff84dd 100644 --- a/tx_test.go +++ b/tx_test.go @@ -27,21 +27,18 @@ var _ = Describe("Tx", func() { // Transactionally increments key using GET and SET commands. incr = func(key string) error { - tx, err := client.Watch(key) - if err != nil { - return err - } - defer tx.Close() + err := client.Watch(func(tx *redis.Tx) error { + n, err := tx.Get(key).Int64() + if err != nil && err != redis.Nil { + return err + } - n, err := tx.Get(key).Int64() - if err != nil && err != redis.Nil { + _, err = tx.MultiExec(func() error { + tx.Set(key, strconv.FormatInt(n+1, 10), 0) + return nil + }) return err - } - - _, err = tx.Exec(func() error { - tx.Set(key, strconv.FormatInt(n+1, 10), 0) - return nil - }) + }, key) if err == redis.TxFailedErr { return incr(key) } @@ -67,20 +64,18 @@ var _ = Describe("Tx", func() { }) It("should discard", func() { - tx, err := client.Watch("key1", "key2") + err := client.Watch(func(tx *redis.Tx) error { + cmds, err := tx.MultiExec(func() error { + tx.Set("key1", "hello1", 0) + tx.Discard() + tx.Set("key2", "hello2", 0) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(1)) + return err + }, "key1", "key2") Expect(err).NotTo(HaveOccurred()) - defer func() { - Expect(tx.Close()).NotTo(HaveOccurred()) - }() - - cmds, err := tx.Exec(func() error { - tx.Set("key1", "hello1", 0) - tx.Discard() - tx.Set("key2", "hello2", 0) - return nil - }) - Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(1)) get := client.Get("key1") Expect(get.Err()).To(Equal(redis.Nil)) @@ -92,43 +87,41 @@ var _ = Describe("Tx", func() { }) It("should exec empty", func() { - tx, err := client.Watch() + err := client.Watch(func(tx *redis.Tx) error { + cmds, err := tx.MultiExec(func() error { return nil }) + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(0)) + return err + }) Expect(err).NotTo(HaveOccurred()) - defer func() { - Expect(tx.Close()).NotTo(HaveOccurred()) - }() - cmds, err := tx.Exec(func() error { return nil }) + v, err := client.Ping().Result() Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(0)) - - ping := tx.Ping() - Expect(ping.Err()).NotTo(HaveOccurred()) - Expect(ping.Val()).To(Equal("PONG")) + Expect(v).To(Equal("PONG")) }) It("should exec bulks", func() { - tx, err := client.Watch() - Expect(err).NotTo(HaveOccurred()) - defer func() { - Expect(tx.Close()).NotTo(HaveOccurred()) - }() + const N = 20000 - cmds, err := tx.Exec(func() error { - for i := int64(0); i < 20000; i++ { - tx.Incr("key") + err := client.Watch(func(tx *redis.Tx) error { + cmds, err := tx.MultiExec(func() error { + for i := 0; i < N; i++ { + tx.Incr("key") + } + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(cmds)).To(Equal(N)) + for _, cmd := range cmds { + Expect(cmd.Err()).NotTo(HaveOccurred()) } - return nil + return err }) Expect(err).NotTo(HaveOccurred()) - Expect(len(cmds)).To(Equal(20000)) - for _, cmd := range cmds { - Expect(cmd.Err()).NotTo(HaveOccurred()) - } - get := client.Get("key") - Expect(get.Err()).NotTo(HaveOccurred()) - Expect(get.Val()).To(Equal("20000")) + num, err := client.Get("key").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(num).To(Equal(int64(N))) }) It("should recover from bad connection", func() { @@ -140,22 +133,21 @@ var _ = Describe("Tx", func() { err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) - tx, err := client.Watch() - Expect(err).NotTo(HaveOccurred()) - defer func() { - Expect(tx.Close()).NotTo(HaveOccurred()) - }() + do := func() error { + err := client.Watch(func(tx *redis.Tx) error { + _, err := tx.MultiExec(func() error { + tx.Ping() + return nil + }) + return err + }) + return err + } - _, err = tx.Exec(func() error { - tx.Ping() - return nil - }) + err = do() Expect(err).To(MatchError("bad connection")) - _, err = tx.Exec(func() error { - tx.Ping() - return nil - }) + err = do() Expect(err).NotTo(HaveOccurred()) }) @@ -168,21 +160,20 @@ var _ = Describe("Tx", func() { err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) - { - tx, err := client.Watch("key") - Expect(err).To(MatchError("bad connection")) - Expect(tx).To(BeNil()) + do := func() error { + err := client.Watch(func(tx *redis.Tx) error { + _, err := tx.MultiExec(func() error { + return nil + }) + return err + }, "key") + return err } - { - tx, err := client.Watch("key") - Expect(err).NotTo(HaveOccurred()) + err = do() + Expect(err).To(MatchError("bad connection")) - err = tx.Ping().Err() - Expect(err).NotTo(HaveOccurred()) - - err = tx.Close() - Expect(err).NotTo(HaveOccurred()) - } + err = do() + Expect(err).NotTo(HaveOccurred()) }) })