diff --git a/internal/server/client.go b/internal/server/client.go index 71881326..67fedf75 100644 --- a/internal/server/client.go +++ b/internal/server/client.go @@ -42,10 +42,10 @@ func (client *Client) Write(b []byte) (n int, err error) { } // CLIENT (LIST | KILL | GETNAME | SETNAME) -func (s *Server) cmdCLIENT(_msg *Message, client *Client) (resp.Value, error) { +func (s *Server) cmdCLIENT(msg *Message, client *Client) (resp.Value, error) { start := time.Now() - args := _msg.Args + args := msg.Args if len(args) == 1 { return retrerr(errInvalidNumberOfArguments) } @@ -79,7 +79,7 @@ func (s *Server) cmdCLIENT(_msg *Message, client *Client) (resp.Value, error) { ) client.mu.Unlock() } - if _msg.OutputType == JSON { + if msg.OutputType == JSON { // Create a map of all key/value info fields var cmap []map[string]interface{} clients := strings.Split(string(buf), "\n") @@ -111,7 +111,7 @@ func (s *Server) cmdCLIENT(_msg *Message, client *Client) (resp.Value, error) { client.mu.Lock() name := client.name client.mu.Unlock() - if _msg.OutputType == JSON { + if msg.OutputType == JSON { return resp.StringValue(`{"ok":true,"name":` + jsonString(name) + `,"elapsed":"` + time.Since(start).String() + "\"}"), nil } @@ -120,7 +120,7 @@ func (s *Server) cmdCLIENT(_msg *Message, client *Client) (resp.Value, error) { if len(args) != 3 { return retrerr(errInvalidNumberOfArguments) } - name := _msg.Args[2] + name := msg.Args[2] for i := 0; i < len(name); i++ { if name[i] < '!' || name[i] > '~' { return retrerr(clientErrorf( @@ -131,7 +131,7 @@ func (s *Server) cmdCLIENT(_msg *Message, client *Client) (resp.Value, error) { client.mu.Lock() client.name = name client.mu.Unlock() - if _msg.OutputType == JSON { + if msg.OutputType == JSON { return resp.StringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}"), nil } @@ -198,7 +198,7 @@ func (s *Server) cmdCLIENT(_msg *Message, client *Client) (resp.Value, error) { closer.Close() } // }() - if _msg.OutputType == JSON { + if msg.OutputType == JSON { return resp.StringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}"), nil } diff --git a/internal/server/server.go b/internal/server/server.go index b51ded5e..f3093320 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1081,7 +1081,7 @@ func (s *Server) command(msg *Message, client *Client) ( case "healthz": res, err = s.cmdHEALTHZ(msg) case "info": - res, err = s.cmdInfo(msg) + res, err = s.cmdINFO(msg) case "scan": res, err = s.cmdScan(msg) case "nearby": diff --git a/internal/server/stats.go b/internal/server/stats.go index 12187016..c0ab0bc4 100644 --- a/internal/server/stats.go +++ b/internal/server/stats.go @@ -457,27 +457,52 @@ func (s *Server) writeInfoCluster(w *bytes.Buffer) { fmt.Fprintf(w, "cluster_enabled:0\r\n") } -func (s *Server) cmdInfo(msg *Message) (res resp.Value, err error) { +// INFO [section ...] +func (s *Server) cmdINFO(msg *Message) (res resp.Value, err error) { start := time.Now() - sections := []string{"server", "clients", "memory", "persistence", "stats", "replication", "cpu", "cluster", "keyspace"} - switch len(msg.Args) { - default: - return NOMessage, errInvalidNumberOfArguments - case 1: - case 2: - section := strings.ToLower(msg.Args[1]) + // >> Args + + args := msg.Args + + msects := make(map[string]bool) + allsects := []string{ + "server", "clients", "memory", "persistence", "stats", + "replication", "cpu", "cluster", "keyspace", + } + + if len(args) == 1 { + for _, s := range allsects { + msects[s] = true + } + } + for i := 1; i < len(args); i++ { + section := strings.ToLower(args[i]) switch section { + case "all", "default": + for _, s := range allsects { + msects[s] = true + } default: - sections = []string{section} - case "all": - sections = []string{"server", "clients", "memory", "persistence", "stats", "replication", "cpu", "commandstats", "cluster", "keyspace"} - case "default": + for _, s := range allsects { + if s == section { + msects[section] = true + } + } + } + } + + // >> Operation + + var sects []string + for _, s := range allsects { + if msects[s] { + sects = append(sects, s) } } w := &bytes.Buffer{} - for i, section := range sections { + for i, section := range sects { if i > 0 { w.WriteString("\r\n") } @@ -511,8 +536,9 @@ func (s *Server) cmdInfo(msg *Message) (res resp.Value, err error) { } } - switch msg.OutputType { - case JSON: + // >> Response + + if msg.OutputType == JSON { // Create a map of all key/value info fields m := make(map[string]interface{}) for _, kv := range strings.Split(w.String(), "\r\n") { @@ -525,15 +551,11 @@ func (s *Server) cmdInfo(msg *Message) (res resp.Value, err error) { } // Marshal the map and use the output in the JSON response - data, err := json.Marshal(m) - if err != nil { - return NOMessage, err - } - res = resp.StringValue(`{"ok":true,"info":` + string(data) + `,"elapsed":"` + time.Since(start).String() + "\"}") - case RESP: - res = resp.BytesValue(w.Bytes()) + data, _ := json.Marshal(m) + return resp.StringValue(`{"ok":true,"info":` + string(data) + + `,"elapsed":"` + time.Since(start).String() + "\"}"), nil } - return res, nil + return resp.BytesValue(w.Bytes()), nil } // tryParseType attempts to parse the passed string as an integer, float64 and diff --git a/tests/keys_test.go b/tests/keys_test.go index 2bfec979..5b2de5bd 100644 --- a/tests/keys_test.go +++ b/tests/keys_test.go @@ -35,7 +35,7 @@ func subTestKeys(t *testing.T, mc *mockServer) { runStep(t, mc, "FLUSHDB", keys_FLUSHDB_test) runStep(t, mc, "HEALTHZ", keys_HEALTHZ_test) runStep(t, mc, "SERVER", keys_SERVER_test) - + runStep(t, mc, "INFO", keys_INFO_test) } func keys_BOUNDS_test(mc *mockServer) error { @@ -545,32 +545,9 @@ func keys_FLUSHDB_test(mc *mockServer) error { } func keys_HEALTHZ_test(mc *mockServer) error { - - // // follow and wait - // str, err := redis.String(mc.Do("FOLLOW", "localhost", mc.alt.port)) - // if err != nil { - // return err - // } - // if str != "OK" { - // return errors.New("not ok") - // } - // start := time.Now() - // for time.Since(start) < time.Second*5 { - // str, err = redis.String(mc.Do("HEALTHZ")) - // if str == "OK" { - // err = nil - // break - // } - // time.Sleep(time.Second / 4) - // } - // if err != nil { - // return err - // } - return mc.DoBatch( Do("HEALTHZ").OK(), Do("HEALTHZ").JSON().OK(), - // Do("FOLLOW", "no", "one").OK(), Do("HEALTHZ", "arg").Err(`wrong number of arguments for 'healthz' command`), ) } @@ -621,3 +598,51 @@ func keys_SERVER_test(mc *mockServer) error { Do("SERVER", "ett").JSON().Err(`invalid argument 'ett'`), ) } + +func keys_INFO_test(mc *mockServer) error { + return mc.DoBatch( + Do("INFO").Func(func(s string) error { + if !strings.Contains(s, "# Clients") || + !strings.Contains(s, "# Stats") { + return errors.New("looks invalid") + } + return nil + }), + Do("INFO", "all").Func(func(s string) error { + if !strings.Contains(s, "# Clients") || + !strings.Contains(s, "# Stats") { + return errors.New("looks invalid") + } + return nil + }), + Do("INFO", "default").Func(func(s string) error { + if !strings.Contains(s, "# Clients") || + !strings.Contains(s, "# Stats") { + return errors.New("looks invalid") + } + return nil + }), + Do("INFO", "cpu").Func(func(s string) error { + if !strings.Contains(s, "# CPU") || + strings.Contains(s, "# Clients") || + strings.Contains(s, "# Stats") { + return errors.New("looks invalid") + } + return nil + }), + Do("INFO", "cpu", "clients").Func(func(s string) error { + if !strings.Contains(s, "# CPU") || + !strings.Contains(s, "# Clients") || + strings.Contains(s, "# Stats") { + return errors.New("looks invalid") + } + return nil + }), + Do("INFO").JSON().Func(func(s string) error { + if gjson.Get(s, "info.tile38_version").String() == "" { + return errors.New("looks invalid") + } + return nil + }), + ) +}