From dcd1da8247afb1f785251f70a6f2bec7e34413cd Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Tue, 5 Sep 2017 09:14:16 +0300 Subject: [PATCH] WIP mocking proof of concept --- cluster.go | 4 +-- command.go | 10 +++++--- example_test.go | 23 +++++++++++++++++ redis.go | 4 +-- replay.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++++ result.go | 34 ++++++++++++------------- ring.go | 2 +- 7 files changed, 120 insertions(+), 25 deletions(-) create mode 100644 replay.go diff --git a/cluster.go b/cluster.go index 72bace7..cc3124f 100644 --- a/cluster.go +++ b/cluster.go @@ -621,13 +621,13 @@ func (c *ClusterClient) Close() error { func (c *ClusterClient) Process(cmd Cmder) error { state, err := c.state() if err != nil { - cmd.setErr(err) + cmd.SetErr(err) return err } _, node, err := c.cmdSlotAndNode(state, cmd) if err != nil { - cmd.setErr(err) + cmd.SetErr(err) return err } diff --git a/command.go b/command.go index a796a93..d982a5c 100644 --- a/command.go +++ b/command.go @@ -36,18 +36,18 @@ type Cmder interface { Name() string readReply(*pool.Conn) error - setErr(error) readTimeout() *time.Duration Err() error + SetErr(error) fmt.Stringer } func setCmdsErr(cmds []Cmder, e error) { for _, cmd := range cmds { if cmd.Err() == nil { - cmd.setErr(e) + cmd.SetErr(e) } } } @@ -154,7 +154,7 @@ func (cmd *baseCmd) setReadTimeout(d time.Duration) { cmd._readTimeout = &d } -func (cmd *baseCmd) setErr(e error) { +func (cmd *baseCmd) SetErr(e error) { cmd.err = e } @@ -437,6 +437,10 @@ func NewStringCmd(args ...interface{}) *StringCmd { } } +func (cmd *StringCmd) SetVal(val string) { + cmd.val = []byte(val) +} + func (cmd *StringCmd) Val() string { return internal.BytesToString(cmd.val) } diff --git a/example_test.go b/example_test.go index 7e04cd4..548fd64 100644 --- a/example_test.go +++ b/example_test.go @@ -412,3 +412,26 @@ func ExampleNewUniversalClient_cluster() { client.Ping() } + +func ExampleReplay() { + replay := redis.NewReplay() + + cmd := redis.NewStringCmd("get", "foo") + cmd.SetVal("bar") + replay.Add(cmd) + + cmd = redis.NewStringCmd("get", "hello") + cmd.SetErr(redis.Nil) + replay.Add(cmd) + + client := redis.NewClient(&redis.Options{}) + replay.WrapClient(client) + + foo := client.Get("foo") + fmt.Println(foo) + + hello := client.Get("hello") + fmt.Println(hello) + // Output: get foo: bar + // get hello: redis: nil +} diff --git a/redis.go b/redis.go index db1f39c..e58f01b 100644 --- a/redis.go +++ b/redis.go @@ -130,7 +130,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { cn, _, err := c.getConn() if err != nil { - cmd.setErr(err) + cmd.SetErr(err) if internal.IsRetryableError(err) { continue } @@ -140,7 +140,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { cn.SetWriteTimeout(c.opt.WriteTimeout) if err := writeCmd(cn, cmd); err != nil { c.releaseConn(cn, err) - cmd.setErr(err) + cmd.SetErr(err) if internal.IsRetryableError(err) { continue } diff --git a/replay.go b/replay.go new file mode 100644 index 0000000..dfc8584 --- /dev/null +++ b/replay.go @@ -0,0 +1,68 @@ +package redis + +import ( + "fmt" + "reflect" +) + +type replayAction struct { + cmd Cmder +} + +type Replay struct { + actions []*replayAction +} + +func NewReplay() *Replay { + return &Replay{} +} + +func (r *Replay) Add(cmd Cmder) *Replay { + action := &replayAction{ + cmd: cmd, + } + r.actions = append(r.actions, action) + return r +} + +func (r *Replay) WrapClient(c *Client) { + c.WrapProcess(func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error { + return r.process + }) +} + +func (r *Replay) process(cmd Cmder) error { + for _, a := range r.actions { + if argsEqual(cmd.args(), a.cmd.args()) { + if err := setCmd(cmd, a.cmd); err != nil { + return err + } + return cmd.Err() + } + } + + cmd.SetErr(fmt.Errorf("unexpected cmd: %s", cmd)) + return cmd.Err() +} + +func argsEqual(a []interface{}, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +func setCmd(dst, src interface{}) error { + dstv := reflect.ValueOf(dst).Elem() + srcv := reflect.ValueOf(src).Elem() + if dstv.Type() != srcv.Type() { + return fmt.Errorf("dst and src commands have different types: %T and %T", dst, src) + } + dstv.Set(srcv) + return nil +} diff --git a/result.go b/result.go index 28cea5c..c414e4c 100644 --- a/result.go +++ b/result.go @@ -6,7 +6,7 @@ import "time" func NewCmdResult(val interface{}, err error) *Cmd { var cmd Cmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -14,7 +14,7 @@ func NewCmdResult(val interface{}, err error) *Cmd { func NewSliceResult(val []interface{}, err error) *SliceCmd { var cmd SliceCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -22,7 +22,7 @@ func NewSliceResult(val []interface{}, err error) *SliceCmd { func NewStatusResult(val string, err error) *StatusCmd { var cmd StatusCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -30,7 +30,7 @@ func NewStatusResult(val string, err error) *StatusCmd { func NewIntResult(val int64, err error) *IntCmd { var cmd IntCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -38,7 +38,7 @@ func NewIntResult(val int64, err error) *IntCmd { func NewDurationResult(val time.Duration, err error) *DurationCmd { var cmd DurationCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -46,7 +46,7 @@ func NewDurationResult(val time.Duration, err error) *DurationCmd { func NewBoolResult(val bool, err error) *BoolCmd { var cmd BoolCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -54,7 +54,7 @@ func NewBoolResult(val bool, err error) *BoolCmd { func NewStringResult(val string, err error) *StringCmd { var cmd StringCmd cmd.val = []byte(val) - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -62,7 +62,7 @@ func NewStringResult(val string, err error) *StringCmd { func NewFloatResult(val float64, err error) *FloatCmd { var cmd FloatCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -70,7 +70,7 @@ func NewFloatResult(val float64, err error) *FloatCmd { func NewStringSliceResult(val []string, err error) *StringSliceCmd { var cmd StringSliceCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -78,7 +78,7 @@ func NewStringSliceResult(val []string, err error) *StringSliceCmd { func NewBoolSliceResult(val []bool, err error) *BoolSliceCmd { var cmd BoolSliceCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -86,7 +86,7 @@ func NewBoolSliceResult(val []bool, err error) *BoolSliceCmd { func NewStringStringMapResult(val map[string]string, err error) *StringStringMapCmd { var cmd StringStringMapCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -94,7 +94,7 @@ func NewStringStringMapResult(val map[string]string, err error) *StringStringMap func NewStringIntMapCmdResult(val map[string]int64, err error) *StringIntMapCmd { var cmd StringIntMapCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -102,7 +102,7 @@ func NewStringIntMapCmdResult(val map[string]int64, err error) *StringIntMapCmd func NewZSliceCmdResult(val []Z, err error) *ZSliceCmd { var cmd ZSliceCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -111,7 +111,7 @@ func NewScanCmdResult(keys []string, cursor uint64, err error) *ScanCmd { var cmd ScanCmd cmd.page = keys cmd.cursor = cursor - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -119,7 +119,7 @@ func NewScanCmdResult(keys []string, cursor uint64, err error) *ScanCmd { func NewClusterSlotsCmdResult(val []ClusterSlot, err error) *ClusterSlotsCmd { var cmd ClusterSlotsCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -127,7 +127,7 @@ func NewClusterSlotsCmdResult(val []ClusterSlot, err error) *ClusterSlotsCmd { func NewGeoLocationCmdResult(val []GeoLocation, err error) *GeoLocationCmd { var cmd GeoLocationCmd cmd.locations = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } @@ -135,6 +135,6 @@ func NewGeoLocationCmdResult(val []GeoLocation, err error) *GeoLocationCmd { func NewCommandsInfoCmdResult(val map[string]*CommandInfo, err error) *CommandsInfoCmd { var cmd CommandsInfoCmd cmd.val = val - cmd.setErr(err) + cmd.SetErr(err) return &cmd } diff --git a/ring.go b/ring.go index a9314fb..1a7415c 100644 --- a/ring.go +++ b/ring.go @@ -338,7 +338,7 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) { func (c *Ring) Process(cmd Cmder) error { shard, err := c.cmdShard(cmd) if err != nil { - cmd.setErr(err) + cmd.SetErr(err) return err } return shard.Client.Process(cmd)