// tile38/internal/controller/client.go

package controller
import (
"errors"
"fmt"
"net"
"sort"
"strings"
"time"
"github.com/tidwall/resp"
"github.com/tidwall/tile38/internal/server"
)
// Conn represents a simple resp connection.
//
// It pairs a raw network connection with a RESP reader and a RESP writer
// that both operate on that same connection (see DialTimeout).
type Conn struct {
	conn net.Conn     // underlying network connection
	rd   *resp.Reader // reads RESP reply values from conn
	wr   *resp.Writer // writes RESP multi-bulk commands to conn
}
// clientConn is the per-connection record consulted by the CLIENT command
// (LIST, GETNAME, SETNAME, KILL).
type clientConn struct {
	id     int          // CLIENT LIST output is sorted by it; matched by CLIENT KILL ID
	name   astring      // set by CLIENT SETNAME, read by GETNAME/LIST via get/set (type defined elsewhere)
	opened atime        // CLIENT LIST reports "age" as whole seconds since this time
	last   atime        // CLIENT LIST reports "idle" as whole seconds since this time; NOTE(review): presumably refreshed per command elsewhere — confirm
	conn   *server.Conn // the live server connection; its remote address is matched by CLIENT KILL ADDR
}
// DialTimeout dials a resp server at address over TCP, failing if the
// connection cannot be established within timeout. On success the returned
// Conn wraps the socket with a RESP reader and writer; the caller owns it
// and should release it with Close.
func DialTimeout(address string, timeout time.Duration) (*Conn, error) {
	nc, err := net.DialTimeout("tcp", address, timeout)
	if err != nil {
		return nil, err
	}
	return &Conn{
		conn: nc,
		rd:   resp.NewReader(nc),
		wr:   resp.NewWriter(nc),
	}, nil
}
// Close closes the connection. It first sends a QUIT command to the server,
// then closes the underlying socket and returns the error from that close.
func (conn *Conn) Close() error {
	sock := conn.conn
	// Sending QUIT is best-effort: the socket is closed either way, so a
	// failed write is deliberately ignored.
	_ = conn.wr.WriteMultiBulk("quit")
	return sock.Close()
}
// Do performs a command and returns a resp value.
//
// commandName and args are written as a single RESP multi-bulk request, and
// then exactly one reply value is read back. If the write fails, the zero
// resp.Value is returned together with the write error and nothing is read.
func (conn *Conn) Do(commandName string, args ...interface{}) (val resp.Value, err error) {
	err = conn.wr.WriteMultiBulk(commandName, args...)
	if err == nil {
		val, _, err = conn.rd.ReadValue()
	}
	return val, err
}
// byID implements sort.Interface, ordering client connections by ascending id.
type byID []*clientConn

// Len reports the number of connections.
func (s byID) Len() int { return len(s) }

// Less reports whether the connection at index a has a smaller id than the
// one at index b.
func (s byID) Less(a, b int) bool {
	left, right := s[a], s[b]
	return left.id < right.id
}

// Swap exchanges the connections at indices a and b.
func (s byID) Swap(a, b int) { s[b], s[a] = s[a], s[b] }
// cmdClient implements the CLIENT command and its subcommands:
//
//	CLIENT LIST                      one line per connection: id, addr, name, age, idle
//	CLIENT GETNAME                   name of the calling connection ("" if unset)
//	CLIENT SETNAME name              name the calling connection (no spaces/control chars)
//	CLIENT KILL ip:port              close the connection with that remote address
//	CLIENT KILL [ID id] [ADDR addr]  close a connection matching every given filter
//
// The subcommand is read from msg.Values[1] (msg.Values[0] presumably holds
// the command name itself). Replies are formatted according to
// msg.OutputType; JSON replies include the elapsed handling time.
// GETNAME and SETNAME return an "invalid output type" error for any output
// type other than JSON or RESP.
func (c *Controller) cmdClient(msg *server.Message, conn *server.Conn) (resp.Value, error) {
	start := time.Now()
	// elapsed renders the time spent so far; evaluated when a reply is built.
	elapsed := func() string { return time.Since(start).String() }

	// "< 2" rather than "== 1" so an empty argument slice yields an error
	// instead of an index-out-of-range panic on msg.Values[1].
	if len(msg.Values) < 2 {
		return server.NOMessage, errInvalidNumberOfArguments
	}
	switch strings.ToLower(msg.Values[1].String()) {
	default:
		return server.NOMessage, errors.New("Syntax error, try CLIENT " +
			"(LIST | KILL | GETNAME | SETNAME)")
	case "list":
		if len(msg.Values) != 2 {
			return server.NOMessage, errInvalidNumberOfArguments
		}
		// Snapshot the connection records under the read lock so the
		// critical section is only the copy, not the formatting below.
		c.connsmu.RLock()
		list := make([]*clientConn, 0, len(c.conns))
		for _, cc := range c.conns {
			list = append(list, cc)
		}
		c.connsmu.RUnlock()
		sort.Sort(byID(list))
		now := time.Now()
		var buf []byte
		for _, cc := range list {
			// age/idle are Duration/Duration, i.e. whole seconds (truncated).
			buf = append(buf,
				fmt.Sprintf("id=%d addr=%s name=%s age=%d idle=%d\n",
					cc.id, cc.conn.RemoteAddr().String(), cc.name.get(),
					now.Sub(cc.opened.get())/time.Second,
					now.Sub(cc.last.get())/time.Second,
				)...,
			)
		}
		switch msg.OutputType {
		case server.JSON:
			return resp.StringValue(`{"ok":true,"list":` + jsonString(string(buf)) + `,"elapsed":"` + elapsed() + "\"}"), nil
		case server.RESP:
			return resp.BytesValue(buf), nil
		}
		return server.NOMessage, nil
	case "getname":
		if len(msg.Values) != 2 {
			return server.NOMessage, errInvalidNumberOfArguments
		}
		name := ""
		c.connsmu.RLock()
		if cc, ok := c.conns[conn]; ok {
			name = cc.name.get()
		}
		c.connsmu.RUnlock()
		switch msg.OutputType {
		case server.JSON:
			return resp.StringValue(`{"ok":true,"name":` + jsonString(name) + `,"elapsed":"` + elapsed() + "\"}"), nil
		case server.RESP:
			return resp.StringValue(name), nil
		}
		// Other output types fall through to the "invalid output type" error.
	case "setname":
		if len(msg.Values) != 3 {
			return server.NOMessage, errInvalidNumberOfArguments
		}
		name := msg.Values[2].String()
		// Accept only printable, non-space ASCII ('!'..'~'), so names stay
		// safe to embed in the space-separated CLIENT LIST output.
		for i := 0; i < len(name); i++ {
			if name[i] < '!' || name[i] > '~' {
				errstr := "Client names cannot contain spaces, newlines or special characters."
				return server.NOMessage, errors.New(errstr)
			}
		}
		// The read lock only guards the map lookup; the name is changed via
		// cc.name.set. NOTE(review): assumes astring is safe for concurrent
		// get/set — confirm against its definition.
		c.connsmu.RLock()
		if cc, ok := c.conns[conn]; ok {
			cc.name.set(name)
		}
		c.connsmu.RUnlock()
		switch msg.OutputType {
		case server.JSON:
			return resp.StringValue(`{"ok":true,"elapsed":"` + elapsed() + "\"}"), nil
		case server.RESP:
			return resp.SimpleStringValue("OK"), nil
		}
		// Other output types fall through to the "invalid output type" error.
	case "kill":
		if len(msg.Values) < 3 {
			return server.NOMessage, errInvalidNumberOfArguments
		}
		var useAddr, useID bool
		var addr, id string
		for i := 2; i < len(msg.Values); i++ {
			arg := msg.Values[i].String()
			if strings.Contains(arg, ":") {
				// Bare "ip:port" argument: treat as an address filter and
				// ignore anything after it.
				addr = arg
				useAddr = true
				break
			}
			switch strings.ToLower(arg) {
			default:
				return server.NOMessage, errors.New("No such client")
			case "addr":
				i++
				if i == len(msg.Values) {
					return server.NOMessage, errors.New("syntax error")
				}
				addr = msg.Values[i].String()
				useAddr = true
			case "id":
				i++
				if i == len(msg.Values) {
					return server.NOMessage, errors.New("syntax error")
				}
				id = msg.Values[i].String()
				useID = true
			}
		}
		// Filters combine with logical AND (as in Redis CLIENT KILL): a
		// connection qualifies only if it satisfies every filter supplied.
		// Previously ID and ADDR were OR-ed, so "ID 3 ADDR x" could close an
		// unrelated connection that merely matched the address.
		var cclose *clientConn
		c.connsmu.RLock()
		for _, cc := range c.conns {
			if useID && fmt.Sprintf("%d", cc.id) != id {
				continue
			}
			if useAddr && cc.conn.RemoteAddr().String() != addr {
				continue
			}
			cclose = cc
			break
		}
		c.connsmu.RUnlock()
		if cclose == nil {
			return server.NOMessage, errors.New("No such client")
		}
		var res resp.Value
		switch msg.OutputType {
		case server.JSON:
			res = resp.StringValue(`{"ok":true,"elapsed":"` + elapsed() + "\"}")
		case server.RESP:
			res = resp.SimpleStringValue("OK")
		}
		if cclose.conn == conn {
			// closing self, return response now
			// NOTE: This is the only exception where we do convert response to a string
			var outBytes []byte
			switch msg.OutputType {
			case server.JSON:
				outBytes = res.Bytes()
			case server.RESP:
				outBytes, _ = res.MarshalRESP()
			}
			// Best-effort write: the connection is closed immediately below,
			// so there is nothing useful to do with a write failure.
			cclose.conn.Write(outBytes)
		}
		cclose.conn.Close()
		return res, nil
	}
	return server.NOMessage, errors.New("invalid output type")
}
/*
func (c *Controller) cmdClientList(msg *server.Message) (string, error) {
var ok bool
var key string
if vs, key, ok = tokenval(vs); !ok || key == "" {
return "", errInvalidNumberOfArguments
}
col := c.getCol(key)
if col == nil {
if msg.OutputType == server.RESP {
return "+none\r\n", nil
}
return "", errKeyNotFound
}
typ := "hash"
switch msg.OutputType {
case server.JSON:
return `{"ok":true,"type":` + string(typ) + `,"elapsed":"` + time.Now().Sub(start).String() + "\"}", nil
case server.RESP:
return "+" + typ + "\r\n", nil
}
return "", nil
}
*/