Better CLIENT tests

This commit is contained in:
tidwall 2022-09-24 13:41:36 -07:00
parent e6cced4c4a
commit d8ecbba0be
4 changed files with 131 additions and 217 deletions

View File

@ -2,7 +2,6 @@ package server
import (
"encoding/json"
"errors"
"fmt"
"io"
"sort"
@ -32,6 +31,8 @@ type Client struct {
name string // optional defined name
opened time.Time // when the client was created/opened, unix nano
last time.Time // last client request/response, unix nano
closer io.Closer // used to close the connection
}
// Write ...
@ -40,33 +41,19 @@ func (client *Client) Write(b []byte) (n int, err error) {
return len(b), nil
}
type byID []*Client
func (arr byID) Len() int {
return len(arr)
}
func (arr byID) Less(a, b int) bool {
return arr[a].id < arr[b].id
}
func (arr byID) Swap(a, b int) {
arr[a], arr[b] = arr[b], arr[a]
}
// CLIENT (LIST | KILL | GETNAME | SETNAME)
func (s *Server) cmdCLIENT(msg *Message, client *Client) (resp.Value, error) {
func (s *Server) cmdCLIENT(_msg *Message, client *Client) (resp.Value, error) {
start := time.Now()
if len(msg.Args) == 1 {
return NOMessage, errInvalidNumberOfArguments
args := _msg.Args
if len(args) == 1 {
return retrerr(errInvalidNumberOfArguments)
}
switch strings.ToLower(msg.Args[1]) {
default:
return NOMessage, clientErrorf(
"Syntax error, try CLIENT (LIST | KILL | GETNAME | SETNAME)",
)
switch strings.ToLower(args[1]) {
case "list":
if len(msg.Args) != 2 {
return NOMessage, errInvalidNumberOfArguments
if len(args) != 2 {
return retrerr(errInvalidNumberOfArguments)
}
var list []*Client
s.connsmu.RLock()
@ -74,7 +61,9 @@ func (s *Server) cmdCLIENT(msg *Message, client *Client) (resp.Value, error) {
list = append(list, cc)
}
s.connsmu.RUnlock()
sort.Sort(byID(list))
sort.Slice(list, func(i, j int) bool {
return list[i].id < list[j].id
})
now := time.Now()
var buf []byte
for _, client := range list {
@ -90,7 +79,7 @@ func (s *Server) cmdCLIENT(msg *Message, client *Client) (resp.Value, error) {
)
client.mu.Unlock()
}
if msg.OutputType == JSON {
if _msg.OutputType == JSON {
// Create a map of all key/value info fields
var cmap []map[string]interface{}
clients := strings.Split(string(buf), "\n")
@ -116,109 +105,107 @@ func (s *Server) cmdCLIENT(msg *Message, client *Client) (resp.Value, error) {
}
return resp.BytesValue(buf), nil
case "getname":
if len(msg.Args) != 2 {
return NOMessage, errInvalidNumberOfArguments
if len(args) != 2 {
return retrerr(errInvalidNumberOfArguments)
}
client.mu.Lock()
name := client.name
client.mu.Unlock()
if msg.OutputType == JSON {
if _msg.OutputType == JSON {
return resp.StringValue(`{"ok":true,"name":` + jsonString(name) +
`,"elapsed":"` + time.Since(start).String() + "\"}"), nil
}
return resp.StringValue(name), nil
case "setname":
if len(msg.Args) != 3 {
return NOMessage, errInvalidNumberOfArguments
if len(args) != 3 {
return retrerr(errInvalidNumberOfArguments)
}
name := msg.Args[2]
name := _msg.Args[2]
for i := 0; i < len(name); i++ {
if name[i] < '!' || name[i] > '~' {
return NOMessage, clientErrorf(
return retrerr(clientErrorf(
"Client names cannot contain spaces, newlines or special characters.",
)
))
}
}
client.mu.Lock()
client.name = name
client.mu.Unlock()
switch msg.OutputType {
case JSON:
return resp.StringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}"), nil
case RESP:
return resp.SimpleStringValue("OK"), nil
if _msg.OutputType == JSON {
return resp.StringValue(`{"ok":true,"elapsed":"` +
time.Since(start).String() + "\"}"), nil
}
return resp.SimpleStringValue("OK"), nil
case "kill":
if len(msg.Args) < 3 {
return NOMessage, errInvalidNumberOfArguments
if len(args) < 3 {
return retrerr(errInvalidNumberOfArguments)
}
var useAddr bool
var addr string
var useID bool
var id string
for i := 2; i < len(msg.Args); i++ {
arg := msg.Args[i]
for i := 2; i < len(args); i++ {
if useAddr || useID {
return retrerr(errInvalidNumberOfArguments)
}
arg := args[i]
if strings.Contains(arg, ":") {
addr = arg
useAddr = true
break
}
} else {
switch strings.ToLower(arg) {
default:
return NOMessage, clientErrorf("No such client")
case "addr":
i++
if i == len(msg.Args) {
return NOMessage, errors.New("syntax error")
if i == len(args) {
return retrerr(errInvalidNumberOfArguments)
}
addr = msg.Args[i]
addr = args[i]
useAddr = true
case "id":
i++
if i == len(msg.Args) {
return NOMessage, errors.New("syntax error")
if i == len(args) {
return retrerr(errInvalidNumberOfArguments)
}
id = msg.Args[i]
id = args[i]
useID = true
default:
return retrerr(clientErrorf("No such client"))
}
}
var cclose *Client
}
var closing []io.Closer
s.connsmu.RLock()
for _, cc := range s.conns {
if useID && fmt.Sprintf("%d", cc.id) == id {
cclose = cc
break
} else if useAddr && client.remoteAddr == addr {
cclose = cc
break
if cc.closer != nil {
closing = append(closing, cc.closer)
}
} else if useAddr {
if cc.remoteAddr == addr {
if cc.closer != nil {
closing = append(closing, cc.closer)
}
}
}
}
s.connsmu.RUnlock()
if cclose == nil {
return NOMessage, clientErrorf("No such client")
if len(closing) == 0 {
return retrerr(clientErrorf("No such client"))
}
var res resp.Value
switch msg.OutputType {
case JSON:
res = resp.StringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}")
case RESP:
res = resp.SimpleStringValue("OK")
// go func() {
// close the connections behind the scene
for _, closer := range closing {
closer.Close()
}
client.conn.Close()
// closing self, return response now
// NOTE: This is the only exception where we do convert response to a string
var outBytes []byte
switch msg.OutputType {
case JSON:
outBytes = res.Bytes()
case RESP:
outBytes, _ = res.MarshalRESP()
// }()
if _msg.OutputType == JSON {
return resp.StringValue(`{"ok":true,"elapsed":"` +
time.Since(start).String() + "\"}"), nil
}
cclose.conn.Write(outBytes)
cclose.conn.Close()
return res, nil
return resp.SimpleStringValue("OK"), nil
default:
return retrerr(clientErrorf(
"Syntax error, try CLIENT (LIST | KILL | GETNAME | SETNAME)",
))
}
return NOMessage, errors.New("invalid output type")
}

View File

@ -376,6 +376,7 @@ func (s *Server) netServe() error {
client.id = int(atomic.AddInt64(&clientID, 1))
client.opened = time.Now()
client.remoteAddr = conn.RemoteAddr().String()
client.closer = conn
// add client to server map
s.connsmu.Lock()

View File

@ -8,6 +8,7 @@ import (
"github.com/gomodule/redigo/redis"
"github.com/tidwall/gjson"
"github.com/tidwall/pretty"
)
func subTestClient(t *testing.T, mc *mockServer) {
@ -62,6 +63,16 @@ func client_CLIENT_test(mc *mockServer) error {
}
conns = append(conns, conn)
}
_, err := conns[1].Do("CLIENT", "setname", "cl1")
if err != nil {
return err
}
_, err = conns[2].Do("CLIENT", "setname", "cl2")
if err != nil {
return err
}
if _, err := mc.Do("OUTPUT", "JSON"); err != nil {
return err
}
@ -73,11 +84,15 @@ func client_CLIENT_test(mc *mockServer) error {
if !ok {
return errors.New("Failed to type assert CLIENT response")
}
sres := string(bres)
if len(gjson.Get(sres, "list").Array()) < numConns {
sres := string(pretty.Pretty(bres))
if int(gjson.Get(sres, "list.#").Int()) < numConns {
return errors.New("Invalid number of connections")
}
client13ID := gjson.Get(sres, "list.13.id").String()
client14Addr := gjson.Get(sres, "list.14.addr").String()
client15Addr := gjson.Get(sres, "list.15.addr").String()
return mc.DoBatch(
Do("CLIENT", "list").JSON().Func(func(s string) error {
if int(gjson.Get(s, "list.#").Int()) < numConns {
@ -97,6 +112,25 @@ func client_CLIENT_test(mc *mockServer) error {
Do("CLIENT", "getname", "arg3").Err(`wrong number of arguments for 'client' command`),
Do("CLIENT", "getname").JSON().Str(`{"ok":true,"name":""}`),
Do("CLIENT", "getname").Str(``),
Do("CLIENT", "setname", "abc").OK(),
Do("CLIENT", "getname").Str(`abc`),
Do("CLIENT", "getname").JSON().Str(`{"ok":true,"name":"abc"}`),
Do("CLIENT", "setname", "abc", "efg").Err(`wrong number of arguments for 'client' command`),
Do("CLIENT", "setname", " abc ").Err(`Client names cannot contain spaces, newlines or special characters.`),
Do("CLIENT", "setname", "abcd").JSON().OK(),
Do("CLIENT", "kill", "name", "abcd").Err("No such client"),
Do("CLIENT", "getname").Str(`abcd`),
Do("CLIENT", "kill").Err(`wrong number of arguments for 'client' command`),
Do("CLIENT", "kill", "").Err(`No such client`),
Do("CLIENT", "kill", "abcd").Err(`No such client`),
Do("CLIENT", "kill", "id", client13ID).OK(),
Do("CLIENT", "kill", "id").Err("wrong number of arguments for 'client' command"),
Do("CLIENT", "kill", client14Addr).OK(),
Do("CLIENT", "kill", client14Addr, "yikes").Err("wrong number of arguments for 'client' command"),
Do("CLIENT", "kill", "addr").Err("wrong number of arguments for 'client' command"),
Do("CLIENT", "kill", "addr", client15Addr).JSON().OK(),
Do("CLIENT", "kill", "addr", client14Addr, "yikes").Err("wrong number of arguments for 'client' command"),
Do("CLIENT", "kill", "id", "1000").Err("No such client"),
)
}

View File

@ -3,9 +3,6 @@ package tests
import (
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"runtime"
@ -21,10 +18,13 @@ type IO struct {
out any
sleep bool
dur time.Duration
cfile string
cln int
}
func Do(args ...any) *IO {
return &IO{args: args}
_, cfile, cln, _ := runtime.Caller(1)
return &IO{args: args, cfile: cfile, cln: cln}
}
func (cmd *IO) JSON() *IO {
cmd.json = true
@ -83,136 +83,29 @@ func Sleep(duration time.Duration) *IO {
return &IO{sleep: true, dur: duration}
}
type ioVisitor struct {
fset *token.FileSet
ln int
pos int
got bool
data string
end int
done bool
index int
nidx int
frag string
fpos int
}
func (v *ioVisitor) Visit(n ast.Node) ast.Visitor {
if n == nil || v.done {
return nil
}
if v.got {
if int(n.Pos()) > v.end {
v.done = true
return v
}
if n, ok := n.(*ast.CallExpr); ok {
frag := strings.TrimSpace(v.data[int(n.Pos())-1 : int(n.End())])
if _, ok := n.Fun.(*ast.Ident); ok {
if v.index == v.nidx {
frag = strings.TrimSpace(strings.TrimSuffix(frag, "."))
idx := strings.IndexByte(frag, '(')
if idx != -1 {
frag = frag[idx:]
}
v.frag = frag
v.done = true
v.fpos = int(n.Pos())
return v
}
v.nidx++
}
}
return v
}
if int(n.Pos()) == v.pos {
if n, ok := n.(*ast.CallExpr); ok {
v.end = int(n.Rparen)
v.got = true
return v
}
}
return v
}
func (cmd *IO) deepError(index int, err error) error {
oerr := err
werr := func(err error) error {
return fmt.Errorf("batch[%d]: %v: %v", index, oerr, err)
}
// analyse stack
_, file, ln, ok := runtime.Caller(3)
if !ok {
return werr(errors.New("runtime.Caller failed"))
}
// get the character position from line
bdata, err := os.ReadFile(file)
if err != nil {
return werr(err)
}
frag := "(?)"
bdata, _ := os.ReadFile(cmd.cfile)
data := string(bdata)
var pos int
var iln int
var pln int
ln := 1
for i := 0; i < len(data); i++ {
if data[i] == '\n' {
j := pln
line := data[pln:i]
pln = i + 1
iln++
if iln == ln {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "return mc.DoBatch(") {
return oerr
}
for ; j < len(data); j++ {
if data[j] == 'm' {
break
ln++
if ln == cmd.cln {
data = data[i+1:]
i = strings.IndexByte(data, '(')
if i != -1 {
j := strings.IndexByte(data[i:], ')')
if j != -1 {
frag = string(data[i : j+i+1])
}
}
pos = j + 1
break
}
}
}
if pos == 0 {
return oerr
}
fset := token.NewFileSet()
pfile, err := parser.ParseFile(fset, file, nil, 0)
if err != nil {
return werr(err)
}
v := &ioVisitor{
fset: fset,
ln: ln,
pos: pos,
data: string(data),
index: index,
}
ast.Walk(v, pfile)
if v.fpos == 0 {
return oerr
}
pln = 1
for i := 0; i < len(data); i++ {
if data[i] == '\n' {
if i > v.fpos {
break
}
pln++
}
}
fsig := fmt.Sprintf("%s:%d", filepath.Base(file), pln)
emsg := oerr.Error()
fsig := fmt.Sprintf("%s:%d", filepath.Base(cmd.cfile), cmd.cln)
emsg := err.Error()
if strings.HasPrefix(emsg, "expected ") &&
strings.Contains(emsg, ", got ") {
emsg = "" +
@ -223,9 +116,8 @@ func (cmd *IO) deepError(index int, err error) error {
emsg = "" +
" ERROR: " + emsg
}
return fmt.Errorf("\n%s: entry[%d]\n COMMAND: %s\n%s",
fsig, index+1, v.frag, emsg)
fsig, index+1, frag, emsg)
}
func (mc *mockServer) doIOTest(index int, cmd *IO) error {