WIP mocking proof of concept

This commit is contained in:
Vladimir Mihailenco 2017-09-05 09:14:16 +03:00
parent ddbd81ea6c
commit dcd1da8247
7 changed files with 120 additions and 25 deletions

View File

@ -621,13 +621,13 @@ func (c *ClusterClient) Close() error {
func (c *ClusterClient) Process(cmd Cmder) error { func (c *ClusterClient) Process(cmd Cmder) error {
state, err := c.state() state, err := c.state()
if err != nil { if err != nil {
cmd.setErr(err) cmd.SetErr(err)
return err return err
} }
_, node, err := c.cmdSlotAndNode(state, cmd) _, node, err := c.cmdSlotAndNode(state, cmd)
if err != nil { if err != nil {
cmd.setErr(err) cmd.SetErr(err)
return err return err
} }

View File

@ -36,18 +36,18 @@ type Cmder interface {
Name() string Name() string
readReply(*pool.Conn) error readReply(*pool.Conn) error
setErr(error)
readTimeout() *time.Duration readTimeout() *time.Duration
Err() error Err() error
SetErr(error)
fmt.Stringer fmt.Stringer
} }
func setCmdsErr(cmds []Cmder, e error) { func setCmdsErr(cmds []Cmder, e error) {
for _, cmd := range cmds { for _, cmd := range cmds {
if cmd.Err() == nil { if cmd.Err() == nil {
cmd.setErr(e) cmd.SetErr(e)
} }
} }
} }
@ -154,7 +154,7 @@ func (cmd *baseCmd) setReadTimeout(d time.Duration) {
cmd._readTimeout = &d cmd._readTimeout = &d
} }
func (cmd *baseCmd) setErr(e error) { func (cmd *baseCmd) SetErr(e error) {
cmd.err = e cmd.err = e
} }
@ -437,6 +437,10 @@ func NewStringCmd(args ...interface{}) *StringCmd {
} }
} }
func (cmd *StringCmd) SetVal(val string) {
cmd.val = []byte(val)
}
func (cmd *StringCmd) Val() string { func (cmd *StringCmd) Val() string {
return internal.BytesToString(cmd.val) return internal.BytesToString(cmd.val)
} }

View File

@ -412,3 +412,26 @@ func ExampleNewUniversalClient_cluster() {
client.Ping() client.Ping()
} }
func ExampleReplay() {
replay := redis.NewReplay()
cmd := redis.NewStringCmd("get", "foo")
cmd.SetVal("bar")
replay.Add(cmd)
cmd = redis.NewStringCmd("get", "hello")
cmd.SetErr(redis.Nil)
replay.Add(cmd)
client := redis.NewClient(&redis.Options{})
replay.WrapClient(client)
foo := client.Get("foo")
fmt.Println(foo)
hello := client.Get("hello")
fmt.Println(hello)
// Output: get foo: bar
// get hello: redis: nil
}

View File

@ -130,7 +130,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
cn, _, err := c.getConn() cn, _, err := c.getConn()
if err != nil { if err != nil {
cmd.setErr(err) cmd.SetErr(err)
if internal.IsRetryableError(err) { if internal.IsRetryableError(err) {
continue continue
} }
@ -140,7 +140,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
cn.SetWriteTimeout(c.opt.WriteTimeout) cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmd); err != nil { if err := writeCmd(cn, cmd); err != nil {
c.releaseConn(cn, err) c.releaseConn(cn, err)
cmd.setErr(err) cmd.SetErr(err)
if internal.IsRetryableError(err) { if internal.IsRetryableError(err) {
continue continue
} }

68
replay.go Normal file
View File

@ -0,0 +1,68 @@
package redis
import (
"fmt"
"reflect"
)
type replayAction struct {
cmd Cmder
}
type Replay struct {
actions []*replayAction
}
func NewReplay() *Replay {
return &Replay{}
}
func (r *Replay) Add(cmd Cmder) *Replay {
action := &replayAction{
cmd: cmd,
}
r.actions = append(r.actions, action)
return r
}
func (r *Replay) WrapClient(c *Client) {
c.WrapProcess(func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error {
return r.process
})
}
func (r *Replay) process(cmd Cmder) error {
for _, a := range r.actions {
if argsEqual(cmd.args(), a.cmd.args()) {
if err := setCmd(cmd, a.cmd); err != nil {
return err
}
return cmd.Err()
}
}
cmd.SetErr(fmt.Errorf("unexpected cmd: %s", cmd))
return cmd.Err()
}
func argsEqual(a []interface{}, b []interface{}) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func setCmd(dst, src interface{}) error {
dstv := reflect.ValueOf(dst).Elem()
srcv := reflect.ValueOf(src).Elem()
if dstv.Type() != srcv.Type() {
return fmt.Errorf("dst and src commands have different types: %T and %T", dst, src)
}
dstv.Set(srcv)
return nil
}

View File

@ -6,7 +6,7 @@ import "time"
func NewCmdResult(val interface{}, err error) *Cmd { func NewCmdResult(val interface{}, err error) *Cmd {
var cmd Cmd var cmd Cmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -14,7 +14,7 @@ func NewCmdResult(val interface{}, err error) *Cmd {
func NewSliceResult(val []interface{}, err error) *SliceCmd { func NewSliceResult(val []interface{}, err error) *SliceCmd {
var cmd SliceCmd var cmd SliceCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -22,7 +22,7 @@ func NewSliceResult(val []interface{}, err error) *SliceCmd {
func NewStatusResult(val string, err error) *StatusCmd { func NewStatusResult(val string, err error) *StatusCmd {
var cmd StatusCmd var cmd StatusCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -30,7 +30,7 @@ func NewStatusResult(val string, err error) *StatusCmd {
func NewIntResult(val int64, err error) *IntCmd { func NewIntResult(val int64, err error) *IntCmd {
var cmd IntCmd var cmd IntCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -38,7 +38,7 @@ func NewIntResult(val int64, err error) *IntCmd {
func NewDurationResult(val time.Duration, err error) *DurationCmd { func NewDurationResult(val time.Duration, err error) *DurationCmd {
var cmd DurationCmd var cmd DurationCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -46,7 +46,7 @@ func NewDurationResult(val time.Duration, err error) *DurationCmd {
func NewBoolResult(val bool, err error) *BoolCmd { func NewBoolResult(val bool, err error) *BoolCmd {
var cmd BoolCmd var cmd BoolCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -54,7 +54,7 @@ func NewBoolResult(val bool, err error) *BoolCmd {
func NewStringResult(val string, err error) *StringCmd { func NewStringResult(val string, err error) *StringCmd {
var cmd StringCmd var cmd StringCmd
cmd.val = []byte(val) cmd.val = []byte(val)
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -62,7 +62,7 @@ func NewStringResult(val string, err error) *StringCmd {
func NewFloatResult(val float64, err error) *FloatCmd { func NewFloatResult(val float64, err error) *FloatCmd {
var cmd FloatCmd var cmd FloatCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -70,7 +70,7 @@ func NewFloatResult(val float64, err error) *FloatCmd {
func NewStringSliceResult(val []string, err error) *StringSliceCmd { func NewStringSliceResult(val []string, err error) *StringSliceCmd {
var cmd StringSliceCmd var cmd StringSliceCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -78,7 +78,7 @@ func NewStringSliceResult(val []string, err error) *StringSliceCmd {
func NewBoolSliceResult(val []bool, err error) *BoolSliceCmd { func NewBoolSliceResult(val []bool, err error) *BoolSliceCmd {
var cmd BoolSliceCmd var cmd BoolSliceCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -86,7 +86,7 @@ func NewBoolSliceResult(val []bool, err error) *BoolSliceCmd {
func NewStringStringMapResult(val map[string]string, err error) *StringStringMapCmd { func NewStringStringMapResult(val map[string]string, err error) *StringStringMapCmd {
var cmd StringStringMapCmd var cmd StringStringMapCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -94,7 +94,7 @@ func NewStringStringMapResult(val map[string]string, err error) *StringStringMap
func NewStringIntMapCmdResult(val map[string]int64, err error) *StringIntMapCmd { func NewStringIntMapCmdResult(val map[string]int64, err error) *StringIntMapCmd {
var cmd StringIntMapCmd var cmd StringIntMapCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -102,7 +102,7 @@ func NewStringIntMapCmdResult(val map[string]int64, err error) *StringIntMapCmd
func NewZSliceCmdResult(val []Z, err error) *ZSliceCmd { func NewZSliceCmdResult(val []Z, err error) *ZSliceCmd {
var cmd ZSliceCmd var cmd ZSliceCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -111,7 +111,7 @@ func NewScanCmdResult(keys []string, cursor uint64, err error) *ScanCmd {
var cmd ScanCmd var cmd ScanCmd
cmd.page = keys cmd.page = keys
cmd.cursor = cursor cmd.cursor = cursor
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -119,7 +119,7 @@ func NewScanCmdResult(keys []string, cursor uint64, err error) *ScanCmd {
func NewClusterSlotsCmdResult(val []ClusterSlot, err error) *ClusterSlotsCmd { func NewClusterSlotsCmdResult(val []ClusterSlot, err error) *ClusterSlotsCmd {
var cmd ClusterSlotsCmd var cmd ClusterSlotsCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -127,7 +127,7 @@ func NewClusterSlotsCmdResult(val []ClusterSlot, err error) *ClusterSlotsCmd {
func NewGeoLocationCmdResult(val []GeoLocation, err error) *GeoLocationCmd { func NewGeoLocationCmdResult(val []GeoLocation, err error) *GeoLocationCmd {
var cmd GeoLocationCmd var cmd GeoLocationCmd
cmd.locations = val cmd.locations = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }
@ -135,6 +135,6 @@ func NewGeoLocationCmdResult(val []GeoLocation, err error) *GeoLocationCmd {
func NewCommandsInfoCmdResult(val map[string]*CommandInfo, err error) *CommandsInfoCmd { func NewCommandsInfoCmdResult(val map[string]*CommandInfo, err error) *CommandsInfoCmd {
var cmd CommandsInfoCmd var cmd CommandsInfoCmd
cmd.val = val cmd.val = val
cmd.setErr(err) cmd.SetErr(err)
return &cmd return &cmd
} }

View File

@ -338,7 +338,7 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
func (c *Ring) Process(cmd Cmder) error { func (c *Ring) Process(cmd Cmder) error {
shard, err := c.cmdShard(cmd) shard, err := c.cmdShard(cmd)
if err != nil { if err != nil {
cmd.setErr(err) cmd.SetErr(err)
return err return err
} }
return shard.Client.Process(cmd) return shard.Client.Process(cmd)