Better CLIENT tests

This commit is contained in:
tidwall 2022-09-24 13:41:36 -07:00
parent e6cced4c4a
commit d8ecbba0be
4 changed files with 131 additions and 217 deletions

View File

@ -2,7 +2,6 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"sort" "sort"
@ -32,6 +31,8 @@ type Client struct {
name string // optional defined name name string // optional defined name
opened time.Time // when the client was created/opened, unix nano opened time.Time // when the client was created/opened, unix nano
last time.Time // last client request/response, unix nano last time.Time // last client request/response, unix nano
closer io.Closer // used to close the connection
} }
// Write ... // Write ...
@ -40,33 +41,19 @@ func (client *Client) Write(b []byte) (n int, err error) {
return len(b), nil return len(b), nil
} }
type byID []*Client
func (arr byID) Len() int {
return len(arr)
}
func (arr byID) Less(a, b int) bool {
return arr[a].id < arr[b].id
}
func (arr byID) Swap(a, b int) {
arr[a], arr[b] = arr[b], arr[a]
}
// CLIENT (LIST | KILL | GETNAME | SETNAME) // CLIENT (LIST | KILL | GETNAME | SETNAME)
func (s *Server) cmdCLIENT(msg *Message, client *Client) (resp.Value, error) { func (s *Server) cmdCLIENT(_msg *Message, client *Client) (resp.Value, error) {
start := time.Now() start := time.Now()
if len(msg.Args) == 1 { args := _msg.Args
return NOMessage, errInvalidNumberOfArguments if len(args) == 1 {
return retrerr(errInvalidNumberOfArguments)
} }
switch strings.ToLower(msg.Args[1]) {
default: switch strings.ToLower(args[1]) {
return NOMessage, clientErrorf(
"Syntax error, try CLIENT (LIST | KILL | GETNAME | SETNAME)",
)
case "list": case "list":
if len(msg.Args) != 2 { if len(args) != 2 {
return NOMessage, errInvalidNumberOfArguments return retrerr(errInvalidNumberOfArguments)
} }
var list []*Client var list []*Client
s.connsmu.RLock() s.connsmu.RLock()
@ -74,7 +61,9 @@ func (s *Server) cmdCLIENT(msg *Message, client *Client) (resp.Value, error) {
list = append(list, cc) list = append(list, cc)
} }
s.connsmu.RUnlock() s.connsmu.RUnlock()
sort.Sort(byID(list)) sort.Slice(list, func(i, j int) bool {
return list[i].id < list[j].id
})
now := time.Now() now := time.Now()
var buf []byte var buf []byte
for _, client := range list { for _, client := range list {
@ -90,7 +79,7 @@ func (s *Server) cmdCLIENT(msg *Message, client *Client) (resp.Value, error) {
) )
client.mu.Unlock() client.mu.Unlock()
} }
if msg.OutputType == JSON { if _msg.OutputType == JSON {
// Create a map of all key/value info fields // Create a map of all key/value info fields
var cmap []map[string]interface{} var cmap []map[string]interface{}
clients := strings.Split(string(buf), "\n") clients := strings.Split(string(buf), "\n")
@ -116,109 +105,107 @@ func (s *Server) cmdCLIENT(msg *Message, client *Client) (resp.Value, error) {
} }
return resp.BytesValue(buf), nil return resp.BytesValue(buf), nil
case "getname": case "getname":
if len(msg.Args) != 2 { if len(args) != 2 {
return NOMessage, errInvalidNumberOfArguments return retrerr(errInvalidNumberOfArguments)
} }
client.mu.Lock() client.mu.Lock()
name := client.name name := client.name
client.mu.Unlock() client.mu.Unlock()
if msg.OutputType == JSON { if _msg.OutputType == JSON {
return resp.StringValue(`{"ok":true,"name":` + jsonString(name) + return resp.StringValue(`{"ok":true,"name":` + jsonString(name) +
`,"elapsed":"` + time.Since(start).String() + "\"}"), nil `,"elapsed":"` + time.Since(start).String() + "\"}"), nil
} }
return resp.StringValue(name), nil return resp.StringValue(name), nil
case "setname": case "setname":
if len(msg.Args) != 3 { if len(args) != 3 {
return NOMessage, errInvalidNumberOfArguments return retrerr(errInvalidNumberOfArguments)
} }
name := msg.Args[2] name := _msg.Args[2]
for i := 0; i < len(name); i++ { for i := 0; i < len(name); i++ {
if name[i] < '!' || name[i] > '~' { if name[i] < '!' || name[i] > '~' {
return NOMessage, clientErrorf( return retrerr(clientErrorf(
"Client names cannot contain spaces, newlines or special characters.", "Client names cannot contain spaces, newlines or special characters.",
) ))
} }
} }
client.mu.Lock() client.mu.Lock()
client.name = name client.name = name
client.mu.Unlock() client.mu.Unlock()
switch msg.OutputType { if _msg.OutputType == JSON {
case JSON: return resp.StringValue(`{"ok":true,"elapsed":"` +
return resp.StringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}"), nil time.Since(start).String() + "\"}"), nil
case RESP:
return resp.SimpleStringValue("OK"), nil
} }
return resp.SimpleStringValue("OK"), nil
case "kill": case "kill":
if len(msg.Args) < 3 { if len(args) < 3 {
return NOMessage, errInvalidNumberOfArguments return retrerr(errInvalidNumberOfArguments)
} }
var useAddr bool var useAddr bool
var addr string var addr string
var useID bool var useID bool
var id string var id string
for i := 2; i < len(msg.Args); i++ { for i := 2; i < len(args); i++ {
arg := msg.Args[i] if useAddr || useID {
return retrerr(errInvalidNumberOfArguments)
}
arg := args[i]
if strings.Contains(arg, ":") { if strings.Contains(arg, ":") {
addr = arg addr = arg
useAddr = true useAddr = true
break } else {
} switch strings.ToLower(arg) {
switch strings.ToLower(arg) { case "addr":
default: i++
return NOMessage, clientErrorf("No such client") if i == len(args) {
case "addr": return retrerr(errInvalidNumberOfArguments)
i++ }
if i == len(msg.Args) { addr = args[i]
return NOMessage, errors.New("syntax error") useAddr = true
case "id":
i++
if i == len(args) {
return retrerr(errInvalidNumberOfArguments)
}
id = args[i]
useID = true
default:
return retrerr(clientErrorf("No such client"))
} }
addr = msg.Args[i]
useAddr = true
case "id":
i++
if i == len(msg.Args) {
return NOMessage, errors.New("syntax error")
}
id = msg.Args[i]
useID = true
} }
} }
var cclose *Client var closing []io.Closer
s.connsmu.RLock() s.connsmu.RLock()
for _, cc := range s.conns { for _, cc := range s.conns {
if useID && fmt.Sprintf("%d", cc.id) == id { if useID && fmt.Sprintf("%d", cc.id) == id {
cclose = cc if cc.closer != nil {
break closing = append(closing, cc.closer)
} else if useAddr && client.remoteAddr == addr { }
cclose = cc } else if useAddr {
break if cc.remoteAddr == addr {
if cc.closer != nil {
closing = append(closing, cc.closer)
}
}
} }
} }
s.connsmu.RUnlock() s.connsmu.RUnlock()
if cclose == nil { if len(closing) == 0 {
return NOMessage, clientErrorf("No such client") return retrerr(clientErrorf("No such client"))
} }
// go func() {
var res resp.Value // close the connections behind the scene
switch msg.OutputType { for _, closer := range closing {
case JSON: closer.Close()
res = resp.StringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}")
case RESP:
res = resp.SimpleStringValue("OK")
} }
// }()
client.conn.Close() if _msg.OutputType == JSON {
// closing self, return response now return resp.StringValue(`{"ok":true,"elapsed":"` +
// NOTE: This is the only exception where we do convert response to a string time.Since(start).String() + "\"}"), nil
var outBytes []byte
switch msg.OutputType {
case JSON:
outBytes = res.Bytes()
case RESP:
outBytes, _ = res.MarshalRESP()
} }
cclose.conn.Write(outBytes) return resp.SimpleStringValue("OK"), nil
cclose.conn.Close() default:
return res, nil return retrerr(clientErrorf(
"Syntax error, try CLIENT (LIST | KILL | GETNAME | SETNAME)",
))
} }
return NOMessage, errors.New("invalid output type")
} }

View File

@ -376,6 +376,7 @@ func (s *Server) netServe() error {
client.id = int(atomic.AddInt64(&clientID, 1)) client.id = int(atomic.AddInt64(&clientID, 1))
client.opened = time.Now() client.opened = time.Now()
client.remoteAddr = conn.RemoteAddr().String() client.remoteAddr = conn.RemoteAddr().String()
client.closer = conn
// add client to server map // add client to server map
s.connsmu.Lock() s.connsmu.Lock()

View File

@ -8,6 +8,7 @@ import (
"github.com/gomodule/redigo/redis" "github.com/gomodule/redigo/redis"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/pretty"
) )
func subTestClient(t *testing.T, mc *mockServer) { func subTestClient(t *testing.T, mc *mockServer) {
@ -62,6 +63,16 @@ func client_CLIENT_test(mc *mockServer) error {
} }
conns = append(conns, conn) conns = append(conns, conn)
} }
_, err := conns[1].Do("CLIENT", "setname", "cl1")
if err != nil {
return err
}
_, err = conns[2].Do("CLIENT", "setname", "cl2")
if err != nil {
return err
}
if _, err := mc.Do("OUTPUT", "JSON"); err != nil { if _, err := mc.Do("OUTPUT", "JSON"); err != nil {
return err return err
} }
@ -73,11 +84,15 @@ func client_CLIENT_test(mc *mockServer) error {
if !ok { if !ok {
return errors.New("Failed to type assert CLIENT response") return errors.New("Failed to type assert CLIENT response")
} }
sres := string(bres) sres := string(pretty.Pretty(bres))
if len(gjson.Get(sres, "list").Array()) < numConns { if int(gjson.Get(sres, "list.#").Int()) < numConns {
return errors.New("Invalid number of connections") return errors.New("Invalid number of connections")
} }
client13ID := gjson.Get(sres, "list.13.id").String()
client14Addr := gjson.Get(sres, "list.14.addr").String()
client15Addr := gjson.Get(sres, "list.15.addr").String()
return mc.DoBatch( return mc.DoBatch(
Do("CLIENT", "list").JSON().Func(func(s string) error { Do("CLIENT", "list").JSON().Func(func(s string) error {
if int(gjson.Get(s, "list.#").Int()) < numConns { if int(gjson.Get(s, "list.#").Int()) < numConns {
@ -97,6 +112,25 @@ func client_CLIENT_test(mc *mockServer) error {
Do("CLIENT", "getname", "arg3").Err(`wrong number of arguments for 'client' command`), Do("CLIENT", "getname", "arg3").Err(`wrong number of arguments for 'client' command`),
Do("CLIENT", "getname").JSON().Str(`{"ok":true,"name":""}`), Do("CLIENT", "getname").JSON().Str(`{"ok":true,"name":""}`),
Do("CLIENT", "getname").Str(``), Do("CLIENT", "getname").Str(``),
Do("CLIENT", "setname", "abc").OK(),
Do("CLIENT", "getname").Str(`abc`),
Do("CLIENT", "getname").JSON().Str(`{"ok":true,"name":"abc"}`),
Do("CLIENT", "setname", "abc", "efg").Err(`wrong number of arguments for 'client' command`),
Do("CLIENT", "setname", " abc ").Err(`Client names cannot contain spaces, newlines or special characters.`),
Do("CLIENT", "setname", "abcd").JSON().OK(),
Do("CLIENT", "kill", "name", "abcd").Err("No such client"),
Do("CLIENT", "getname").Str(`abcd`),
Do("CLIENT", "kill").Err(`wrong number of arguments for 'client' command`),
Do("CLIENT", "kill", "").Err(`No such client`),
Do("CLIENT", "kill", "abcd").Err(`No such client`),
Do("CLIENT", "kill", "id", client13ID).OK(),
Do("CLIENT", "kill", "id").Err("wrong number of arguments for 'client' command"),
Do("CLIENT", "kill", client14Addr).OK(),
Do("CLIENT", "kill", client14Addr, "yikes").Err("wrong number of arguments for 'client' command"),
Do("CLIENT", "kill", "addr").Err("wrong number of arguments for 'client' command"),
Do("CLIENT", "kill", "addr", client15Addr).JSON().OK(),
Do("CLIENT", "kill", "addr", client14Addr, "yikes").Err("wrong number of arguments for 'client' command"),
Do("CLIENT", "kill", "id", "1000").Err("No such client"),
) )
} }

View File

@ -3,9 +3,6 @@ package tests
import ( import (
"errors" "errors"
"fmt" "fmt"
"go/ast"
"go/parser"
"go/token"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -21,10 +18,13 @@ type IO struct {
out any out any
sleep bool sleep bool
dur time.Duration dur time.Duration
cfile string
cln int
} }
func Do(args ...any) *IO { func Do(args ...any) *IO {
return &IO{args: args} _, cfile, cln, _ := runtime.Caller(1)
return &IO{args: args, cfile: cfile, cln: cln}
} }
func (cmd *IO) JSON() *IO { func (cmd *IO) JSON() *IO {
cmd.json = true cmd.json = true
@ -83,136 +83,29 @@ func Sleep(duration time.Duration) *IO {
return &IO{sleep: true, dur: duration} return &IO{sleep: true, dur: duration}
} }
type ioVisitor struct {
fset *token.FileSet
ln int
pos int
got bool
data string
end int
done bool
index int
nidx int
frag string
fpos int
}
func (v *ioVisitor) Visit(n ast.Node) ast.Visitor {
if n == nil || v.done {
return nil
}
if v.got {
if int(n.Pos()) > v.end {
v.done = true
return v
}
if n, ok := n.(*ast.CallExpr); ok {
frag := strings.TrimSpace(v.data[int(n.Pos())-1 : int(n.End())])
if _, ok := n.Fun.(*ast.Ident); ok {
if v.index == v.nidx {
frag = strings.TrimSpace(strings.TrimSuffix(frag, "."))
idx := strings.IndexByte(frag, '(')
if idx != -1 {
frag = frag[idx:]
}
v.frag = frag
v.done = true
v.fpos = int(n.Pos())
return v
}
v.nidx++
}
}
return v
}
if int(n.Pos()) == v.pos {
if n, ok := n.(*ast.CallExpr); ok {
v.end = int(n.Rparen)
v.got = true
return v
}
}
return v
}
func (cmd *IO) deepError(index int, err error) error { func (cmd *IO) deepError(index int, err error) error {
oerr := err frag := "(?)"
werr := func(err error) error { bdata, _ := os.ReadFile(cmd.cfile)
return fmt.Errorf("batch[%d]: %v: %v", index, oerr, err)
}
// analyse stack
_, file, ln, ok := runtime.Caller(3)
if !ok {
return werr(errors.New("runtime.Caller failed"))
}
// get the character position from line
bdata, err := os.ReadFile(file)
if err != nil {
return werr(err)
}
data := string(bdata) data := string(bdata)
ln := 1
var pos int
var iln int
var pln int
for i := 0; i < len(data); i++ { for i := 0; i < len(data); i++ {
if data[i] == '\n' { if data[i] == '\n' {
j := pln ln++
line := data[pln:i] if ln == cmd.cln {
pln = i + 1 data = data[i+1:]
iln++ i = strings.IndexByte(data, '(')
if iln == ln { if i != -1 {
line = strings.TrimSpace(line) j := strings.IndexByte(data[i:], ')')
if !strings.HasPrefix(line, "return mc.DoBatch(") { if j != -1 {
return oerr frag = string(data[i : j+i+1])
}
for ; j < len(data); j++ {
if data[j] == 'm' {
break
} }
} }
pos = j + 1
break break
} }
} }
} }
if pos == 0 { fsig := fmt.Sprintf("%s:%d", filepath.Base(cmd.cfile), cmd.cln)
return oerr emsg := err.Error()
}
fset := token.NewFileSet()
pfile, err := parser.ParseFile(fset, file, nil, 0)
if err != nil {
return werr(err)
}
v := &ioVisitor{
fset: fset,
ln: ln,
pos: pos,
data: string(data),
index: index,
}
ast.Walk(v, pfile)
if v.fpos == 0 {
return oerr
}
pln = 1
for i := 0; i < len(data); i++ {
if data[i] == '\n' {
if i > v.fpos {
break
}
pln++
}
}
fsig := fmt.Sprintf("%s:%d", filepath.Base(file), pln)
emsg := oerr.Error()
if strings.HasPrefix(emsg, "expected ") && if strings.HasPrefix(emsg, "expected ") &&
strings.Contains(emsg, ", got ") { strings.Contains(emsg, ", got ") {
emsg = "" + emsg = "" +
@ -223,9 +116,8 @@ func (cmd *IO) deepError(index int, err error) error {
emsg = "" + emsg = "" +
" ERROR: " + emsg " ERROR: " + emsg
} }
return fmt.Errorf("\n%s: entry[%d]\n COMMAND: %s\n%s", return fmt.Errorf("\n%s: entry[%d]\n COMMAND: %s\n%s",
fsig, index+1, v.frag, emsg) fsig, index+1, frag, emsg)
} }
func (mc *mockServer) doIOTest(index int, cmd *IO) error { func (mc *mockServer) doIOTest(index int, cmd *IO) error {