forked from mirror/ledisdb
361 lines
6.5 KiB
Go
361 lines
6.5 KiB
Go
package server
|
|
|
|
import (
|
|
"bufio"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"runtime"
|
|
"strconv"
|
|
"syscall"
|
|
"time"
|
|
|
|
"git.internal/re/ledisdb/ledis"
|
|
"github.com/siddontang/go/hack"
|
|
"github.com/siddontang/go/log"
|
|
"github.com/siddontang/go/num"
|
|
"github.com/siddontang/goredis"
|
|
)
|
|
|
|
// Sentinel errors shared by the RESP connection code.
var (
	// errReadRequest carries the message used for a malformed request.
	// NOTE(review): not referenced in the visible part of this file.
	errReadRequest = errors.New("invalid request protocol")

	// errClientQuit is returned by handleRequest after a QUIT or SHUTDOWN
	// command so that the client's read loop terminates.
	errClientQuit = errors.New("remote client quit")
)
|
|
|
|
type respClient struct {
|
|
*client
|
|
|
|
conn net.Conn
|
|
|
|
respReader *goredis.RespReader
|
|
|
|
activeQuit bool
|
|
}
|
|
|
|
// respWriter encodes replies into a buffered writer; nothing reaches the
// underlying writer until the buffer fills up or flush is called.
type respWriter struct {
	// buff collects the encoded reply bytes.
	buff *bufio.Writer
}
|
|
|
|
func (app *App) addRespClient(c *respClient) {
|
|
app.rcm.Lock()
|
|
app.rcs[c] = struct{}{}
|
|
app.rcm.Unlock()
|
|
}
|
|
|
|
func (app *App) delRespClient(c *respClient) {
|
|
app.rcm.Lock()
|
|
delete(app.rcs, c)
|
|
app.rcm.Unlock()
|
|
}
|
|
|
|
func (app *App) closeAllRespClients() {
|
|
app.rcm.Lock()
|
|
|
|
for c := range app.rcs {
|
|
c.conn.Close()
|
|
}
|
|
|
|
app.rcm.Unlock()
|
|
}
|
|
|
|
func (app *App) respClientNum() int {
|
|
app.rcm.Lock()
|
|
n := len(app.rcs)
|
|
app.rcm.Unlock()
|
|
return n
|
|
}
|
|
|
|
func newClientRESP(conn net.Conn, app *App) {
|
|
c := new(respClient)
|
|
|
|
c.client = newClient(app)
|
|
c.conn = conn
|
|
|
|
c.activeQuit = false
|
|
|
|
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
|
tcpConn.SetReadBuffer(app.cfg.ConnReadBufferSize)
|
|
tcpConn.SetWriteBuffer(app.cfg.ConnWriteBufferSize)
|
|
}
|
|
|
|
br := bufio.NewReaderSize(conn, app.cfg.ConnReadBufferSize)
|
|
c.respReader = goredis.NewRespReader(br)
|
|
|
|
c.resp = newWriterRESP(conn, app.cfg.ConnWriteBufferSize)
|
|
c.remoteAddr = conn.RemoteAddr().String()
|
|
|
|
app.connWait.Add(1)
|
|
|
|
app.addRespClient(c)
|
|
|
|
go c.run()
|
|
}
|
|
|
|
func (c *respClient) run() {
|
|
defer func() {
|
|
if e := recover(); e != nil {
|
|
buf := make([]byte, 4096)
|
|
n := runtime.Stack(buf, false)
|
|
buf = buf[0:n]
|
|
|
|
log.Fatalf("client run panic %s:%v", buf, e)
|
|
}
|
|
|
|
c.client.close()
|
|
|
|
c.conn.Close()
|
|
|
|
// if c.tx != nil {
|
|
// c.tx.Rollback()
|
|
// c.tx = nil
|
|
// }
|
|
|
|
c.app.removeSlave(c.client, c.activeQuit)
|
|
|
|
c.app.delRespClient(c)
|
|
|
|
c.app.connWait.Done()
|
|
}()
|
|
|
|
select {
|
|
case <-c.app.quit:
|
|
//check app closed
|
|
return
|
|
default:
|
|
break
|
|
}
|
|
|
|
kc := time.Duration(c.app.cfg.ConnKeepaliveInterval) * time.Second
|
|
for {
|
|
if kc > 0 {
|
|
c.conn.SetReadDeadline(time.Now().Add(kc))
|
|
}
|
|
|
|
c.cmd = ""
|
|
c.args = nil
|
|
|
|
reqData, err := c.respReader.ParseRequest()
|
|
if err == nil {
|
|
err = c.handleRequest(reqData)
|
|
}
|
|
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *respClient) handleRequest(reqData [][]byte) error {
|
|
if len(reqData) == 0 {
|
|
c.cmd = ""
|
|
c.args = reqData[0:0]
|
|
} else {
|
|
c.cmd = hack.String(lowerSlice(reqData[0]))
|
|
c.args = reqData[1:]
|
|
}
|
|
|
|
if c.cmd == "xselect" {
|
|
err := c.handleXSelectCmd()
|
|
if err != nil {
|
|
c.resp.writeError(err)
|
|
c.resp.flush()
|
|
return nil
|
|
}
|
|
}
|
|
|
|
if c.cmd == "quit" {
|
|
c.activeQuit = true
|
|
c.resp.writeStatus(OK)
|
|
c.resp.flush()
|
|
c.conn.Close()
|
|
return errClientQuit
|
|
} else if c.cmd == "shutdown" {
|
|
c.conn.Close()
|
|
|
|
// send kill signal
|
|
|
|
p, _ := os.FindProcess(os.Getpid())
|
|
|
|
p.Signal(syscall.SIGTERM)
|
|
p.Signal(os.Interrupt)
|
|
|
|
return errClientQuit
|
|
}
|
|
|
|
c.perform()
|
|
|
|
return nil
|
|
}
|
|
|
|
// XSELECT db THEN command
|
|
func (c *respClient) handleXSelectCmd() error {
|
|
if len(c.args) <= 2 {
|
|
// invalid command format
|
|
return fmt.Errorf("invalid format for XSELECT, must XSELECT db THEN your command")
|
|
}
|
|
|
|
if hack.String(upperSlice(c.args[1])) != "THEN" {
|
|
// invalid command format, just resturn here
|
|
return fmt.Errorf("invalid format for XSELECT, must XSELECT db THEN your command")
|
|
}
|
|
|
|
index, err := strconv.Atoi(hack.String(c.args[0]))
|
|
if err != nil {
|
|
return fmt.Errorf("invalid db for XSELECT, err %v", err)
|
|
}
|
|
|
|
db, err := c.app.ldb.Select(index)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid db for XSELECT, err %v", err)
|
|
}
|
|
|
|
c.db = db
|
|
|
|
c.cmd = hack.String(lowerSlice(c.args[2]))
|
|
c.args = c.args[3:]
|
|
|
|
return nil
|
|
}
|
|
|
|
// response writer
|
|
|
|
func newWriterRESP(conn net.Conn, size int) *respWriter {
|
|
w := new(respWriter)
|
|
w.buff = bufio.NewWriterSize(conn, size)
|
|
return w
|
|
}
|
|
|
|
func (w *respWriter) writeError(err error) {
|
|
w.buff.Write(hack.Slice("-"))
|
|
if err != nil {
|
|
w.buff.Write(hack.Slice(err.Error()))
|
|
}
|
|
w.buff.Write(Delims)
|
|
}
|
|
|
|
func (w *respWriter) writeStatus(status string) {
|
|
w.buff.WriteByte('+')
|
|
w.buff.Write(hack.Slice(status))
|
|
w.buff.Write(Delims)
|
|
}
|
|
|
|
func (w *respWriter) writeInteger(n int64) {
|
|
w.buff.WriteByte(':')
|
|
w.buff.Write(num.FormatInt64ToSlice(n))
|
|
w.buff.Write(Delims)
|
|
}
|
|
|
|
func (w *respWriter) writeBulk(b []byte) {
|
|
w.buff.WriteByte('$')
|
|
if b == nil {
|
|
w.buff.Write(NullBulk)
|
|
} else {
|
|
w.buff.Write(hack.Slice(strconv.Itoa(len(b))))
|
|
w.buff.Write(Delims)
|
|
w.buff.Write(b)
|
|
}
|
|
|
|
w.buff.Write(Delims)
|
|
}
|
|
|
|
func (w *respWriter) writeArray(lst []interface{}) {
|
|
w.buff.WriteByte('*')
|
|
if lst == nil {
|
|
w.buff.Write(NullArray)
|
|
w.buff.Write(Delims)
|
|
} else {
|
|
w.buff.Write(hack.Slice(strconv.Itoa(len(lst))))
|
|
w.buff.Write(Delims)
|
|
|
|
for i := 0; i < len(lst); i++ {
|
|
switch v := lst[i].(type) {
|
|
case []interface{}:
|
|
w.writeArray(v)
|
|
case [][]byte:
|
|
w.writeSliceArray(v)
|
|
case []byte:
|
|
w.writeBulk(v)
|
|
case nil:
|
|
w.writeBulk(nil)
|
|
case int64:
|
|
w.writeInteger(v)
|
|
case string:
|
|
w.writeStatus(v)
|
|
case error:
|
|
w.writeError(v)
|
|
default:
|
|
panic(fmt.Sprintf("invalid array type %T %v", lst[i], v))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (w *respWriter) writeSliceArray(lst [][]byte) {
|
|
w.buff.WriteByte('*')
|
|
if lst == nil {
|
|
w.buff.Write(NullArray)
|
|
w.buff.Write(Delims)
|
|
} else {
|
|
w.buff.Write(hack.Slice(strconv.Itoa(len(lst))))
|
|
w.buff.Write(Delims)
|
|
|
|
for i := 0; i < len(lst); i++ {
|
|
w.writeBulk(lst[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
func (w *respWriter) writeFVPairArray(lst []ledis.FVPair) {
|
|
w.buff.WriteByte('*')
|
|
if lst == nil {
|
|
w.buff.Write(NullArray)
|
|
w.buff.Write(Delims)
|
|
} else {
|
|
w.buff.Write(hack.Slice(strconv.Itoa(len(lst) * 2)))
|
|
w.buff.Write(Delims)
|
|
|
|
for i := 0; i < len(lst); i++ {
|
|
w.writeBulk(lst[i].Field)
|
|
w.writeBulk(lst[i].Value)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (w *respWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) {
|
|
w.buff.WriteByte('*')
|
|
if lst == nil {
|
|
w.buff.Write(NullArray)
|
|
w.buff.Write(Delims)
|
|
} else {
|
|
if withScores {
|
|
w.buff.Write(hack.Slice(strconv.Itoa(len(lst) * 2)))
|
|
w.buff.Write(Delims)
|
|
} else {
|
|
w.buff.Write(hack.Slice(strconv.Itoa(len(lst))))
|
|
w.buff.Write(Delims)
|
|
|
|
}
|
|
|
|
for i := 0; i < len(lst); i++ {
|
|
w.writeBulk(lst[i].Member)
|
|
|
|
if withScores {
|
|
w.writeBulk(num.FormatInt64ToSlice(lst[i].Score))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (w *respWriter) writeBulkFrom(n int64, rb io.Reader) {
|
|
w.buff.WriteByte('$')
|
|
w.buff.Write(hack.Slice(strconv.FormatInt(n, 10)))
|
|
w.buff.Write(Delims)
|
|
|
|
io.Copy(w.buff, rb)
|
|
w.buff.Write(Delims)
|
|
}
|
|
|
|
func (w *respWriter) flush() {
|
|
w.buff.Flush()
|
|
}
|