diff --git a/server/app.go b/server/app.go index cc95575..781e097 100644 --- a/server/app.go +++ b/server/app.go @@ -123,7 +123,7 @@ func (app *App) Run() { continue } - newTcpClient(conn, app) + newClientRESP(conn, app) } } diff --git a/server/client.go b/server/client.go deleted file mode 100644 index 22200a1..0000000 --- a/server/client.go +++ /dev/null @@ -1,160 +0,0 @@ -package server - -import ( - "bytes" - "errors" - "github.com/siddontang/go-log/log" - "github.com/siddontang/ledisdb/ledis" - "io" - "net" - "runtime" - "strings" - "time" -) - -var errReadRequest = errors.New("invalid request protocol") - -type client struct { - app *App - ldb *ledis.Ledis - - db *ledis.DB - - ctx clientContext - resp responseWriter - req requestReader - - cmd string - args [][]byte - - reqC chan error - - compressBuf []byte - syncBuf bytes.Buffer - logBuf bytes.Buffer -} - -type clientContext interface { - addr() string - release() -} - -type requestReader interface { - // readLine func() ([]byte, error) - read() ([][]byte, error) -} - -type responseWriter interface { - writeError(error) - writeStatus(string) - writeInteger(int64) - writeBulk([]byte) - writeArray([]interface{}) - writeSliceArray([][]byte) - writeFVPairArray([]ledis.FVPair) - writeScorePairArray([]ledis.ScorePair, bool) - writeBulkFrom(int64, io.Reader) - flush() -} - -func newClient(app *App) *client { - c := new(client) - - c.app = app - c.ldb = app.ldb - c.db, _ = app.ldb.Select(0) - - c.reqC = make(chan error, 1) - - c.compressBuf = make([]byte, 256) - - return c -} - -func (c *client) run() { - defer func() { - if e := recover(); e != nil { - buf := make([]byte, 4096) - n := runtime.Stack(buf, false) - buf = buf[0:n] - - log.Fatal("client run panic %s:%v", buf, e) - } - - c.ctx.release() - }() - - for { - req, err := c.req.read() - if err != nil { - return - } - - c.handleRequest(req) - } -} - -func (c *client) handleRequest(req [][]byte) { - var err error - - start := time.Now() - - if len(req) == 0 { - err = ErrEmptyCommand - } else { - c.cmd = strings.ToLower(ledis.String(req[0])) - c.args = req[1:] - - f, ok := regCmds[c.cmd] - if !ok { - err = ErrNotFound - } else { - go func() { - c.reqC <- f(c) - }() - err = <-c.reqC - } - } - - duration := time.Since(start) - - if c.app.access != nil { - c.logBuf.Reset() - for i, r := range req { - left := 256 - c.logBuf.Len() - if left <= 0 { - break - } else if len(r) <= left { - c.logBuf.Write(r) - if i != len(req)-1 { - c.logBuf.WriteByte(' ') - } - } else { - c.logBuf.Write(r[0:left]) - } - } - - c.app.access.Log(c.ctx.addr(), duration.Nanoseconds()/1000000, c.logBuf.Bytes(), err) - } - - if err != nil { - c.resp.writeError(err) - } - - c.resp.flush() -} - -func newTcpClient(conn net.Conn, app *App) { - c := newClient(app) - - c.ctx = newTcpContext(conn) - c.req = newTcpReader(conn) - c.resp = newTcpWriter(conn) - - go c.run() -} - -// func newHttpClient(w http.ResponseWriter, r *http.Request, app *App) { -// c := newClient(app) -// go c.run() -// } diff --git a/server/tcp_io.go b/server/client_resp.go similarity index 58% rename from server/tcp_io.go rename to server/client_resp.go index 62095b9..4d0110a 100644 --- a/server/tcp_io.go +++ b/server/client_resp.go @@ -3,58 +3,81 @@ package server import ( "bufio" "errors" + "github.com/siddontang/go-log/log" "github.com/siddontang/ledisdb/ledis" "io" "net" + "runtime" "strconv" + "strings" ) -type tcpContext struct { +var errReadRequest = errors.New("invalid request protocol") + +type respClient struct { + app *App + ldb *ledis.Ledis + db *ledis.DB + conn net.Conn + rb *bufio.Reader + + req *requestContext + hdl *requestHandler } -type tcpWriter struct { +type respWriter struct { buff *bufio.Writer } -type tcpReader struct { - buff *bufio.Reader +func newClientRESP(conn net.Conn, app *App) { + c := new(respClient) + + c.app = app + c.conn = conn + c.ldb = app.ldb + c.db, _ = app.ldb.Select(0) + + c.rb = bufio.NewReaderSize(conn, 256) + + c.req = newRequestContext(app) + c.req.resp = newWriterRESP(conn) + + c.hdl = newRequestHandler(app) + + go c.run() } -// tcp context +func (c *respClient) run() { + defer func() { + if e := recover(); e != nil { + buf := make([]byte, 4096) + n := runtime.Stack(buf, false) + buf = buf[0:n] -func newTcpContext(conn net.Conn) *tcpContext { - ctx := new(tcpContext) - ctx.conn = conn - return ctx -} + log.Fatal("client run panic %s:%v", buf, e) + } -func (ctx *tcpContext) addr() string { - return ctx.conn.RemoteAddr().String() -} + c.conn.Close() + }() -func (ctx *tcpContext) release() { - if ctx.conn != nil { - ctx.conn.Close() - ctx.conn = nil + for { + reqData, err := c.readRequest() + if err != nil { + return + } + + c.handleRequest(reqData) } } -// tcp reader - -func newTcpReader(conn net.Conn) *tcpReader { - r := new(tcpReader) - r.buff = bufio.NewReaderSize(conn, 256) - return r -} - -func (r *tcpReader) readLine() ([]byte, error) { - return ReadLine(r.buff) +func (c *respClient) readLine() ([]byte, error) { + return ReadLine(c.rb) } //A client sends to the Redis server a RESP Array consisting of just Bulk Strings. -func (r *tcpReader) read() ([][]byte, error) { - l, err := r.readLine() +func (c *respClient) readRequest() ([][]byte, error) { + l, err := c.readLine() if err != nil { return nil, err } else if len(l) == 0 || l[0] != '*' { @@ -68,10 +91,10 @@ func (r *tcpReader) read() ([][]byte, error) { return nil, errReadRequest } - reqData := make([][]byte, 0, nparams) + req := make([][]byte, 0, nparams) var n int for i := 0; i < nparams; i++ { - if l, err = r.readLine(); err != nil { + if l, err = c.readLine(); err != nil { return nil, err } @@ -82,20 +105,20 @@ func (r *tcpReader) read() ([][]byte, error) { if n, err = strconv.Atoi(ledis.String(l[1:])); err != nil { return nil, err } else if n == -1 { - reqData = append(reqData, nil) + req = append(req, nil) } else { buf := make([]byte, n) - if _, err = io.ReadFull(r.buff, buf); err != nil { + if _, err = io.ReadFull(c.rb, buf); err != nil { return nil, err } - if l, err = r.readLine(); err != nil { + if l, err = c.readLine(); err != nil { return nil, err } else if len(l) != 0 { return nil, errors.New("bad bulk string format") } - reqData = append(reqData, buf) + req = append(req, buf) } @@ -104,18 +127,35 @@ func (r *tcpReader) read() ([][]byte, error) { } } - return reqData, nil + return req, nil } -// tcp writer +func (c *respClient) handleRequest(reqData [][]byte) { + req := c.req -func newTcpWriter(conn net.Conn) *tcpWriter { - w := new(tcpWriter) + req.db = c.db + req.remoteAddr = c.conn.RemoteAddr().String() + + if len(reqData) == 0 { + c.req.cmd = "" + c.req.args = reqData[0:0] + } else { + c.req.cmd = strings.ToLower(ledis.String(reqData[0])) + c.req.args = reqData[1:] + } + + c.hdl.postRequest(req) +} + +// response writer + +func newWriterRESP(conn net.Conn) *respWriter { + w := new(respWriter) w.buff = bufio.NewWriterSize(conn, 256) return w } -func (w *tcpWriter) writeError(err error) { +func (w *respWriter) writeError(err error) { w.buff.Write(ledis.Slice("-ERR")) if err != nil { w.buff.WriteByte(' ') @@ -124,19 +164,19 @@ func (w *tcpWriter) writeError(err error) { w.buff.Write(Delims) } -func (w *tcpWriter) writeStatus(status string) { +func (w *respWriter) writeStatus(status string) { w.buff.WriteByte('+') w.buff.Write(ledis.Slice(status)) w.buff.Write(Delims) } -func (w *tcpWriter) writeInteger(n int64) { +func (w *respWriter) writeInteger(n int64) { w.buff.WriteByte(':') w.buff.Write(ledis.StrPutInt64(n)) w.buff.Write(Delims) } -func (w *tcpWriter) writeBulk(b []byte) { +func (w *respWriter) writeBulk(b []byte) { w.buff.WriteByte('$') if b == nil { w.buff.Write(NullBulk) @@ -149,7 +189,7 @@ func (w *tcpWriter) writeBulk(b []byte) { w.buff.Write(Delims) } -func (w *tcpWriter) writeArray(lst []interface{}) { +func (w *respWriter) writeArray(lst []interface{}) { w.buff.WriteByte('*') if lst == nil { w.buff.Write(NullArray) @@ -175,7 +215,7 @@ func (w *tcpWriter) writeArray(lst []interface{}) { } } -func (w *tcpWriter) writeSliceArray(lst [][]byte) { +func (w *respWriter) writeSliceArray(lst [][]byte) { w.buff.WriteByte('*') if lst == nil { w.buff.Write(NullArray) @@ -190,7 +230,7 @@ func (w *tcpWriter) writeSliceArray(lst [][]byte) { } } -func (w *tcpWriter) writeFVPairArray(lst []ledis.FVPair) { +func (w *respWriter) writeFVPairArray(lst []ledis.FVPair) { w.buff.WriteByte('*') if lst == nil { w.buff.Write(NullArray) @@ -206,7 +246,7 @@ func (w *tcpWriter) writeFVPairArray(lst []ledis.FVPair) { } } -func (w *tcpWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) { +func (w *respWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) { w.buff.WriteByte('*') if lst == nil { w.buff.Write(NullArray) @@ -231,7 +271,7 @@ func (w *tcpWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) } } -func (w *tcpWriter) writeBulkFrom(n int64, rb io.Reader) { +func (w *respWriter) writeBulkFrom(n int64, rb io.Reader) { w.buff.WriteByte('$') w.buff.Write(ledis.Slice(strconv.FormatInt(n, 10))) w.buff.Write(Delims) @@ -240,6 +280,6 @@ func (w *tcpWriter) writeBulkFrom(n int64, rb io.Reader) { w.buff.Write(Delims) } -func (w *tcpWriter) flush() { +func (w *respWriter) flush() { w.buff.Flush() } diff --git a/server/cmd_bit.go b/server/cmd_bit.go index ec887a2..d95e416 100644 --- a/server/cmd_bit.go +++ b/server/cmd_bit.go @@ -5,36 +5,36 @@ import ( "strings" ) -func bgetCommand(c *client) error { - args := c.args +func bgetCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.BGet(args[0]); err != nil { + if v, err := req.db.BGet(args[0]); err != nil { return err } else { - c.resp.writeBulk(v) + req.resp.writeBulk(v) } return nil } -func bdeleteCommand(c *client) error { - args := c.args +func bdeleteCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.BDelete(args[0]); err != nil { + if n, err := req.db.BDelete(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func bsetbitCommand(c *client) error { - args := c.args +func bsetbitCommand(req *requestContext) error { + args := req.args if len(args) != 3 { return ErrCmdParams } @@ -53,16 +53,16 @@ func bsetbitCommand(c *client) error { return err } - if ori, err := c.db.BSetBit(args[0], offset, uint8(val)); err != nil { + if ori, err := req.db.BSetBit(args[0], offset, uint8(val)); err != nil { return err } else { - c.resp.writeInteger(int64(ori)) + req.resp.writeInteger(int64(ori)) } return nil } -func bgetbitCommand(c *client) error { - args := c.args +func bgetbitCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -72,16 +72,16 @@ func bgetbitCommand(c *client) error { return err } - if v, err := c.db.BGetBit(args[0], offset); err != nil { + if v, err := req.db.BGetBit(args[0], offset); err != nil { return err } else { - c.resp.writeInteger(int64(v)) + req.resp.writeInteger(int64(v)) } return nil } -func bmsetbitCommand(c *client) error { - args := c.args +func bmsetbitCommand(req *requestContext) error { + args := req.args if len(args) < 3 { return ErrCmdParams } @@ -113,16 +113,16 @@ func bmsetbitCommand(c *client) error { pairs[i].Val = uint8(val) } - if place, err := c.db.BMSetBit(key, pairs...); err != nil { + if place, err := req.db.BMSetBit(key, pairs...); err != nil { return err } else { - c.resp.writeInteger(place) + req.resp.writeInteger(place) } return nil } -func bcountCommand(c *client) error { - args := c.args +func bcountCommand(req *requestContext) error { + args := req.args argCnt := len(args) if !(argCnt > 0 && argCnt <= 3) { @@ -148,16 +148,16 @@ func bcountCommand(c *client) error { } } - if cnt, err := c.db.BCount(args[0], start, end); err != nil { + if cnt, err := req.db.BCount(args[0], start, end); err != nil { return err } else { - c.resp.writeInteger(int64(cnt)) + req.resp.writeInteger(int64(cnt)) } return nil } -func boptCommand(c *client) error { - args := c.args +func boptCommand(req *requestContext) error { + args := req.args if len(args) < 2 { return ErrCmdParams } @@ -180,16 +180,16 @@ func boptCommand(c *client) error { return ErrCmdParams } - if blen, err := c.db.BOperation(op, dstKey, srcKeys...); err != nil { + if blen, err := req.db.BOperation(op, dstKey, srcKeys...); err != nil { return err } else { - c.resp.writeInteger(int64(blen)) + req.resp.writeInteger(int64(blen)) } return nil } -func bexpireCommand(c *client) error { - args := c.args +func bexpireCommand(req *requestContext) error { + args := req.args if len(args) == 0 { return ErrCmdParams } @@ -199,17 +199,17 @@ func bexpireCommand(c *client) error { return err } - if v, err := c.db.BExpire(args[0], duration); err != nil { + if v, err := req.db.BExpire(args[0], duration); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func bexpireatCommand(c *client) error { - args := c.args +func bexpireatCommand(req *requestContext) error { + args := req.args if len(args) == 0 { return ErrCmdParams } @@ -219,40 +219,40 @@ func bexpireatCommand(c *client) error { return err } - if v, err := c.db.BExpireAt(args[0], when); err != nil { + if v, err := req.db.BExpireAt(args[0], when); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func bttlCommand(c *client) error { - args := c.args +func bttlCommand(req *requestContext) error { + args := req.args if len(args) == 0 { return ErrCmdParams } - if v, err := c.db.BTTL(args[0]); err != nil { + if v, err := req.db.BTTL(args[0]); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func bpersistCommand(c *client) error { - args := c.args +func bpersistCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.BPersist(args[0]); err != nil { + if n, err := req.db.BPersist(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil diff --git a/server/cmd_hash.go b/server/cmd_hash.go index 0a736fc..e6dc433 100644 --- a/server/cmd_hash.go +++ b/server/cmd_hash.go @@ -4,87 +4,87 @@ import ( "github.com/siddontang/ledisdb/ledis" ) -func hsetCommand(c *client) error { - args := c.args +func hsetCommand(req *requestContext) error { + args := req.args if len(args) != 3 { return ErrCmdParams } - if n, err := c.db.HSet(args[0], args[1], args[2]); err != nil { + if n, err := req.db.HSet(args[0], args[1], args[2]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func hgetCommand(c *client) error { - args := c.args +func hgetCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } - if v, err := c.db.HGet(args[0], args[1]); err != nil { + if v, err := req.db.HGet(args[0], args[1]); err != nil { return err } else { - c.resp.writeBulk(v) + req.resp.writeBulk(v) } return nil } -func hexistsCommand(c *client) error { - args := c.args +func hexistsCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } var n int64 = 1 - if v, err := c.db.HGet(args[0], args[1]); err != nil { + if v, err := req.db.HGet(args[0], args[1]); err != nil { return err } else { if v == nil { n = 0 } - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func hdelCommand(c *client) error { - args := c.args +func hdelCommand(req *requestContext) error { + args := req.args if len(args) < 2 { return ErrCmdParams } - if n, err := c.db.HDel(args[0], args[1:]...); err != nil { + if n, err := req.db.HDel(args[0], args[1:]...); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func hlenCommand(c *client) error { - args := c.args +func hlenCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.HLen(args[0]); err != nil { + if n, err := req.db.HLen(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func hincrbyCommand(c *client) error { - args := c.args +func hincrbyCommand(req *requestContext) error { + args := req.args if len(args) != 3 { return ErrCmdParams } @@ -95,16 +95,16 @@ func hincrbyCommand(c *client) error { } var n int64 - if n, err = c.db.HIncrBy(args[0], args[1], delta); err != nil { + if n, err = req.db.HIncrBy(args[0], args[1], delta); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func hmsetCommand(c *client) error { - args := c.args +func hmsetCommand(req *requestContext) error { + args := req.args if len(args) < 3 { return ErrCmdParams } @@ -123,107 +123,107 @@ func hmsetCommand(c *client) error { kvs[i].Value = args[2*i+1] } - if err := c.db.HMset(key, kvs...); err != nil { + if err := req.db.HMset(key, kvs...); err != nil { return err } else { - c.resp.writeStatus(OK) + req.resp.writeStatus(OK) } return nil } -func hmgetCommand(c *client) error { - args := c.args +func hmgetCommand(req *requestContext) error { + args := req.args if len(args) < 2 { return ErrCmdParams } - if v, err := c.db.HMget(args[0], args[1:]...); err != nil { + if v, err := req.db.HMget(args[0], args[1:]...); err != nil { return err } else { - c.resp.writeSliceArray(v) + req.resp.writeSliceArray(v) } return nil } -func hgetallCommand(c *client) error { - args := c.args +func hgetallCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.HGetAll(args[0]); err != nil { + if v, err := req.db.HGetAll(args[0]); err != nil { return err } else { - c.resp.writeFVPairArray(v) + req.resp.writeFVPairArray(v) } return nil } -func hkeysCommand(c *client) error { - args := c.args +func hkeysCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.HKeys(args[0]); err != nil { + if v, err := req.db.HKeys(args[0]); err != nil { return err } else { - c.resp.writeSliceArray(v) + req.resp.writeSliceArray(v) } return nil } -func hvalsCommand(c *client) error { - args := c.args +func hvalsCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.HValues(args[0]); err != nil { + if v, err := req.db.HValues(args[0]); err != nil { return err } else { - c.resp.writeSliceArray(v) + req.resp.writeSliceArray(v) } return nil } -func hclearCommand(c *client) error { - args := c.args +func hclearCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.HClear(args[0]); err != nil { + if n, err := req.db.HClear(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func hmclearCommand(c *client) error { - args := c.args +func hmclearCommand(req *requestContext) error { + args := req.args if len(args) < 1 { return ErrCmdParams } - if n, err := c.db.HMclear(args...); err != nil { + if n, err := req.db.HMclear(args...); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func hexpireCommand(c *client) error { - args := c.args +func hexpireCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -233,17 +233,17 @@ func hexpireCommand(c *client) error { return err } - if v, err := c.db.HExpire(args[0], duration); err != nil { + if v, err := req.db.HExpire(args[0], duration); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func hexpireAtCommand(c *client) error { - args := c.args +func hexpireAtCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -253,40 +253,40 @@ func hexpireAtCommand(c *client) error { return err } - if v, err := c.db.HExpireAt(args[0], when); err != nil { + if v, err := req.db.HExpireAt(args[0], when); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func httlCommand(c *client) error { - args := c.args +func httlCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.HTTL(args[0]); err != nil { + if v, err := req.db.HTTL(args[0]); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func hpersistCommand(c *client) error { - args := c.args +func hpersistCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.HPersist(args[0]); err != nil { + if n, err := req.db.HPersist(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil diff --git a/server/cmd_kv.go b/server/cmd_kv.go index 4462c83..9ac922c 100644 --- a/server/cmd_kv.go +++ b/server/cmd_kv.go @@ -4,112 +4,112 @@ import ( "github.com/siddontang/ledisdb/ledis" ) -func getCommand(c *client) error { - args := c.args +func getCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.Get(args[0]); err != nil { + if v, err := req.db.Get(args[0]); err != nil { return err } else { - c.resp.writeBulk(v) + req.resp.writeBulk(v) } return nil } -func setCommand(c *client) error { - args := c.args +func setCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } - if err := c.db.Set(args[0], args[1]); err != nil { + if err := req.db.Set(args[0], args[1]); err != nil { return err } else { - c.resp.writeStatus(OK) + req.resp.writeStatus(OK) } return nil } -func getsetCommand(c *client) error { - args := c.args +func getsetCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } - if v, err := c.db.GetSet(args[0], args[1]); err != nil { + if v, err := req.db.GetSet(args[0], args[1]); err != nil { return err } else { - c.resp.writeBulk(v) + req.resp.writeBulk(v) } return nil } -func setnxCommand(c *client) error { - args := c.args +func setnxCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } - if n, err := c.db.SetNX(args[0], args[1]); err != nil { + if n, err := req.db.SetNX(args[0], args[1]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func existsCommand(c *client) error { - args := c.args +func existsCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.Exists(args[0]); err != nil { + if n, err := req.db.Exists(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func incrCommand(c *client) error { - args := c.args +func incrCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.Incr(c.args[0]); err != nil { + if n, err := req.db.Incr(req.args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func decrCommand(c *client) error { - args := c.args +func decrCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.Decr(c.args[0]); err != nil { + if n, err := req.db.Decr(req.args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func incrbyCommand(c *client) error { - args := c.args +func incrbyCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -119,17 +119,17 @@ func incrbyCommand(c *client) error { return err } - if n, err := c.db.IncryBy(c.args[0], delta); err != nil { + if n, err := req.db.IncryBy(req.args[0], delta); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func decrbyCommand(c *client) error { - args := c.args +func decrbyCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -139,32 +139,32 @@ func decrbyCommand(c *client) error { return err } - if n, err := c.db.DecrBy(c.args[0], delta); err != nil { + if n, err := req.db.DecrBy(req.args[0], delta); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func delCommand(c *client) error { - args := c.args +func delCommand(req *requestContext) error { + args := req.args if len(args) == 0 { return ErrCmdParams } - if n, err := c.db.Del(args...); err != nil { + if n, err := req.db.Del(args...); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func msetCommand(c *client) error { - args := c.args +func msetCommand(req *requestContext) error { + args := req.args if len(args) == 0 || len(args)%2 != 0 { return ErrCmdParams } @@ -175,36 +175,36 @@ func msetCommand(c *client) error { kvs[i].Value = args[2*i+1] } - if err := c.db.MSet(kvs...); err != nil { + if err := req.db.MSet(kvs...); err != nil { return err } else { - c.resp.writeStatus(OK) + req.resp.writeStatus(OK) } return nil } -// func setexCommand(c *client) error { +// func setexCommand(req *requestContext) error { // return nil // } -func mgetCommand(c *client) error { - args := c.args +func mgetCommand(req *requestContext) error { + args := req.args if len(args) == 0 { return ErrCmdParams } - if v, err := c.db.MGet(args...); err != nil { + if v, err := req.db.MGet(args...); err != nil { return err } else { - c.resp.writeSliceArray(v) + req.resp.writeSliceArray(v) } return nil } -func expireCommand(c *client) error { - args := c.args +func expireCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -214,17 +214,17 @@ func expireCommand(c *client) error { return err } - if v, err := c.db.Expire(args[0], duration); err != nil { + if v, err := req.db.Expire(args[0], duration); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func expireAtCommand(c *client) error { - args := c.args +func expireAtCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -234,40 +234,40 @@ func expireAtCommand(c *client) error { return err } - if v, err := c.db.ExpireAt(args[0], when); err != nil { + if v, err := req.db.ExpireAt(args[0], when); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func ttlCommand(c *client) error { - args := c.args +func ttlCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.TTL(args[0]); err != nil { + if v, err := req.db.TTL(args[0]); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func persistCommand(c *client) error { - args := c.args +func persistCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.Persist(args[0]); err != nil { + if n, err := req.db.Persist(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil diff --git a/server/cmd_list.go b/server/cmd_list.go index a4e971b..dd9f6ef 100644 --- a/server/cmd_list.go +++ b/server/cmd_list.go @@ -4,83 +4,83 @@ import ( "github.com/siddontang/ledisdb/ledis" ) -func lpushCommand(c *client) error { - args := c.args +func lpushCommand(req *requestContext) error { + args := req.args if len(args) < 2 { return ErrCmdParams } - if n, err := c.db.LPush(args[0], args[1:]...); err != nil { + if n, err := req.db.LPush(args[0], args[1:]...); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func rpushCommand(c *client) error { - args := c.args +func rpushCommand(req *requestContext) error { + args := req.args if len(args) < 2 { return ErrCmdParams } - if n, err := c.db.RPush(args[0], args[1:]...); err != nil { + if n, err := req.db.RPush(args[0], args[1:]...); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func lpopCommand(c *client) error { - args := c.args +func lpopCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.LPop(args[0]); err != nil { + if v, err := req.db.LPop(args[0]); err != nil { return err } else { - c.resp.writeBulk(v) + req.resp.writeBulk(v) } return nil } -func rpopCommand(c *client) error { - args := c.args +func rpopCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.RPop(args[0]); err != nil { + if v, err := req.db.RPop(args[0]); err != nil { return err } else { - c.resp.writeBulk(v) + req.resp.writeBulk(v) } return nil } -func llenCommand(c *client) error { - args := c.args +func llenCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.LLen(args[0]); err != nil { + if n, err := req.db.LLen(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func lindexCommand(c *client) error { - args := c.args +func lindexCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -90,17 +90,17 @@ func lindexCommand(c *client) error { return err } - if v, err := c.db.LIndex(args[0], int32(index)); err != nil { + if v, err := req.db.LIndex(args[0], int32(index)); err != nil { return err } else { - c.resp.writeBulk(v) + req.resp.writeBulk(v) } return nil } -func lrangeCommand(c *client) error { - args := c.args +func lrangeCommand(req *requestContext) error { + args := req.args if len(args) != 3 { return ErrCmdParams } @@ -119,47 +119,47 @@ func lrangeCommand(c *client) error { return err } - if v, err := c.db.LRange(args[0], int32(start), int32(stop)); err != nil { + if v, err := req.db.LRange(args[0], int32(start), int32(stop)); err != nil { return err } else { - c.resp.writeSliceArray(v) + req.resp.writeSliceArray(v) } return nil } -func lclearCommand(c *client) error { - args := c.args +func lclearCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.LClear(args[0]); err != nil { + if n, err := req.db.LClear(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func lmclearCommand(c *client) error { - args := c.args +func lmclearCommand(req *requestContext) error { + args := req.args if len(args) < 1 { return ErrCmdParams } - if n, err := c.db.LMclear(args...); err != nil { + if n, err := req.db.LMclear(args...); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func lexpireCommand(c *client) error { - args := c.args +func lexpireCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -169,17 +169,17 @@ func lexpireCommand(c *client) error { return err } - if v, err := c.db.LExpire(args[0], duration); err != nil { + if v, err := req.db.LExpire(args[0], duration); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func lexpireAtCommand(c *client) error { - args := c.args +func lexpireAtCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -189,40 +189,40 @@ func lexpireAtCommand(c *client) error { return err } - if v, err := c.db.LExpireAt(args[0], when); err != nil { + if v, err := req.db.LExpireAt(args[0], when); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func lttlCommand(c *client) error { - args := c.args +func lttlCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.LTTL(args[0]); err != nil { + if v, err := req.db.LTTL(args[0]); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func lpersistCommand(c *client) error { - args := c.args +func lpersistCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.LPersist(args[0]); err != nil { + if n, err := req.db.LPersist(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil diff --git a/server/cmd_replication.go b/server/cmd_replication.go index fe84191..85c0861 100644 --- a/server/cmd_replication.go +++ b/server/cmd_replication.go @@ -11,8 +11,8 @@ import ( "strings" ) -func slaveofCommand(c *client) error { - args := c.args +func slaveofCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams @@ -31,23 +31,23 @@ func slaveofCommand(c *client) error { masterAddr = fmt.Sprintf("%s:%s", args[0], args[1]) } - if err := c.app.slaveof(masterAddr); err != nil { + if err := req.app.slaveof(masterAddr); err != nil { return err } - c.resp.writeStatus(OK) + req.resp.writeStatus(OK) return nil } -func fullsyncCommand(c *client) error { +func fullsyncCommand(req *requestContext) error { //todo, multi fullsync may use same dump file - dumpFile, err := ioutil.TempFile(c.app.cfg.DataDir, "dump_") + dumpFile, err := ioutil.TempFile(req.app.cfg.DataDir, "dump_") if err != nil { return err } - if err = c.app.ldb.Dump(dumpFile); err != nil { + if err = req.app.ldb.Dump(dumpFile); err != nil { return err } @@ -56,7 +56,7 @@ func fullsyncCommand(c *client) error { dumpFile.Seek(0, os.SEEK_SET) - c.resp.writeBulkFrom(n, dumpFile) + req.resp.writeBulkFrom(n, dumpFile) name := dumpFile.Name() dumpFile.Close() @@ -68,8 +68,8 @@ func fullsyncCommand(c *client) error { var reserveInfoSpace = make([]byte, 16) -func syncCommand(c *client) error { - args := c.args +func syncCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -87,32 +87,32 @@ func syncCommand(c *client) error { return ErrCmdParams } - c.syncBuf.Reset() + req.syncBuf.Reset() //reserve space to write master info - if _, err := c.syncBuf.Write(reserveInfoSpace); err != nil { + if _, err := req.syncBuf.Write(reserveInfoSpace); err != nil { return err } m := &ledis.MasterInfo{logIndex, logPos} - if _, err := c.app.ldb.ReadEventsTo(m, &c.syncBuf); err != nil { + if _, err := req.app.ldb.ReadEventsTo(m, &req.syncBuf); err != nil { return err } else { - buf := c.syncBuf.Bytes() + buf := req.syncBuf.Bytes() binary.BigEndian.PutUint64(buf[0:], uint64(m.LogFileIndex)) binary.BigEndian.PutUint64(buf[8:], uint64(m.LogPos)) - if len(c.compressBuf) < snappy.MaxEncodedLen(len(buf)) { - c.compressBuf = make([]byte, snappy.MaxEncodedLen(len(buf))) + if len(req.compressBuf) < snappy.MaxEncodedLen(len(buf)) { + req.compressBuf = make([]byte, snappy.MaxEncodedLen(len(buf))) } - if buf, err = snappy.Encode(c.compressBuf, buf); err != nil { + if buf, err = snappy.Encode(req.compressBuf, buf); err != nil { return err } - c.resp.writeBulk(buf) + req.resp.writeBulk(buf) } return nil diff --git a/server/cmd_zset.go b/server/cmd_zset.go index b697453..d2facd3 100644 --- a/server/cmd_zset.go +++ b/server/cmd_zset.go @@ -12,8 +12,8 @@ import ( var errScoreOverflow = errors.New("zset score overflow") -func zaddCommand(c *client) error { - args := c.args +func zaddCommand(req *requestContext) error { + args := req.args if len(args) < 3 { return ErrCmdParams } @@ -36,66 +36,66 @@ func zaddCommand(c *client) error { params[i].Member = args[2*i+1] } - if n, err := c.db.ZAdd(key, params...); err != nil { + if n, err := req.db.ZAdd(key, params...); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func zcardCommand(c *client) error { - args := c.args +func zcardCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.ZCard(args[0]); err != nil { + if n, err := req.db.ZCard(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func zscoreCommand(c *client) error { - args := c.args +func zscoreCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } - if s, err := c.db.ZScore(args[0], args[1]); err != nil { + if s, err := req.db.ZScore(args[0], args[1]); err != nil { if err == ledis.ErrScoreMiss { - c.resp.writeBulk(nil) + req.resp.writeBulk(nil) } else { return err } } else { - c.resp.writeBulk(ledis.StrPutInt64(s)) + req.resp.writeBulk(ledis.StrPutInt64(s)) } return nil } -func zremCommand(c *client) error { - args := c.args +func zremCommand(req *requestContext) error { + args := req.args if len(args) < 2 { return ErrCmdParams } - if n, err := c.db.ZRem(args[0], args[1:]...); err != nil { + if n, err := req.db.ZRem(args[0], args[1:]...); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func zincrbyCommand(c *client) error { - args := c.args +func zincrbyCommand(req *requestContext) error { + args := req.args if len(args) != 3 { return ErrCmdParams } @@ -107,10 +107,10 @@ func zincrbyCommand(c *client) error { return err } - if v, err := c.db.ZIncrBy(key, delta, args[2]); err != nil { + if v, err := req.db.ZIncrBy(key, delta, args[2]); err != nil { return err } else { - c.resp.writeBulk(ledis.StrPutInt64(v)) + req.resp.writeBulk(ledis.StrPutInt64(v)) } return nil @@ -178,8 +178,8 @@ func zparseScoreRange(minBuf []byte, maxBuf []byte) (min int64, max int64, err e return } -func zcountCommand(c *client) error { - args := c.args +func zcountCommand(req *requestContext) error { + args := req.args if len(args) != 3 { return ErrCmdParams } @@ -190,77 +190,77 @@ func zcountCommand(c *client) error { } if min > max { - c.resp.writeInteger(0) + req.resp.writeInteger(0) return nil } - if n, err := c.db.ZCount(args[0], min, max); err != nil { + if n, err := req.db.ZCount(args[0], min, max); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func zrankCommand(c *client) error { - args := c.args +func zrankCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } - if n, err := c.db.ZRank(args[0], args[1]); err != nil { + if n, err := req.db.ZRank(args[0], args[1]); err != nil { return err } else if n == -1 { - c.resp.writeBulk(nil) + req.resp.writeBulk(nil) } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func zrevrankCommand(c *client) error { - args := c.args +func zrevrankCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } - if n, err := c.db.ZRevRank(args[0], args[1]); err != nil { + if n, err := req.db.ZRevRank(args[0], args[1]); err != nil { return err } else if n == -1 { - c.resp.writeBulk(nil) + req.resp.writeBulk(nil) } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func zremrangebyrankCommand(c *client) error { - args := c.args +func zremrangebyrankCommand(req *requestContext) error { + args := req.args if len(args) != 3 { return ErrCmdParams } key := args[0] - start, stop, err := zparseRange(c, args[1], args[2]) + start, stop, err := zparseRange(req, args[1], args[2]) if err != nil { return err } - if n, err := c.db.ZRemRangeByRank(key, start, stop); err != nil { + if n, err := req.db.ZRemRangeByRank(key, start, stop); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func zremrangebyscoreCommand(c *client) error { - args := c.args +func zremrangebyscoreCommand(req *requestContext) error { + args := req.args if len(args) != 3 { return ErrCmdParams } @@ -271,16 +271,16 @@ func zremrangebyscoreCommand(c *client) error { return err } - if n, err := c.db.ZRemRangeByScore(key, min, max); err != nil { + if n, err := req.db.ZRemRangeByScore(key, min, max); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func zparseRange(c *client, a1 []byte, a2 []byte) (start int, stop int, err error) { +func zparseRange(req *requestContext, a1 []byte, a2 []byte) (start int, stop int, err error) { if start, err = strconv.Atoi(ledis.String(a1)); err != nil { return } @@ -292,15 +292,15 @@ func zparseRange(c *client, a1 []byte, a2 []byte) (start int, stop int, err erro return } -func zrangeGeneric(c *client, reverse bool) error { - args := c.args +func zrangeGeneric(req *requestContext, reverse bool) error { + args := req.args if len(args) < 3 { return ErrCmdParams } key := args[0] - start, stop, err := zparseRange(c, args[1], args[2]) + start, stop, err := zparseRange(req, args[1], args[2]) if err != nil { return err } @@ -312,24 +312,24 @@ func zrangeGeneric(c *client, reverse bool) error { withScores = true } - if datas, err := c.db.ZRangeGeneric(key, start, stop, reverse); err != nil { + if datas, err := req.db.ZRangeGeneric(key, start, stop, reverse); err != nil { return err } else { - c.resp.writeScorePairArray(datas, withScores) + req.resp.writeScorePairArray(datas, withScores) } return nil } -func zrangeCommand(c *client) error { - return zrangeGeneric(c, false) +func zrangeCommand(req *requestContext) error { + return zrangeGeneric(req, false) } -func zrevrangeCommand(c *client) error { - return zrangeGeneric(c, true) +func zrevrangeCommand(req *requestContext) error { + return zrangeGeneric(req, true) } -func zrangebyscoreGeneric(c *client, reverse bool) error { - args := c.args +func zrangebyscoreGeneric(req *requestContext, reverse bool) error { + args := req.args if len(args) < 3 { return ErrCmdParams } @@ -383,59 +383,59 @@ func zrangebyscoreGeneric(c *client, reverse bool) error { if offset < 0 { //for ledis, if offset < 0, a empty will return //so here we directly return a empty array - c.resp.writeArray([]interface{}{}) + req.resp.writeArray([]interface{}{}) return nil } - if datas, err := c.db.ZRangeByScoreGeneric(key, min, max, offset, count, reverse); err != nil { + if datas, err := req.db.ZRangeByScoreGeneric(key, min, max, offset, count, reverse); err != nil { return err } else { - c.resp.writeScorePairArray(datas, withScores) + req.resp.writeScorePairArray(datas, withScores) } return nil } -func zrangebyscoreCommand(c *client) error { - return zrangebyscoreGeneric(c, false) +func zrangebyscoreCommand(req *requestContext) error { + return zrangebyscoreGeneric(req, false) } -func zrevrangebyscoreCommand(c *client) error { - return zrangebyscoreGeneric(c, true) +func zrevrangebyscoreCommand(req *requestContext) error { + return zrangebyscoreGeneric(req, true) } -func zclearCommand(c *client) error { - args := c.args +func zclearCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.ZClear(args[0]); err != nil { + if n, err := req.db.ZClear(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func zmclearCommand(c *client) error { - args := c.args +func zmclearCommand(req *requestContext) error { + args := req.args if len(args) < 1 { return ErrCmdParams } - if n, err := c.db.ZMclear(args...); err != nil { + if n, err := req.db.ZMclear(args...); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil } -func zexpireCommand(c *client) error { - args := c.args +func zexpireCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -445,17 +445,17 @@ func zexpireCommand(c *client) error { return err } - if v, err := c.db.ZExpire(args[0], duration); err != nil { + if v, err := req.db.ZExpire(args[0], duration); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func zexpireAtCommand(c *client) error { - args := c.args +func zexpireAtCommand(req *requestContext) error { + args := req.args if len(args) != 2 { return ErrCmdParams } @@ -465,40 +465,40 @@ func zexpireAtCommand(c *client) error { return err } - if v, err := c.db.ZExpireAt(args[0], when); err != nil { + if v, err := req.db.ZExpireAt(args[0], when); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func zttlCommand(c *client) error { - args := c.args +func zttlCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if v, err := c.db.ZTTL(args[0]); err != nil { + if v, err := req.db.ZTTL(args[0]); err != nil { return err } else { - c.resp.writeInteger(v) + req.resp.writeInteger(v) } return nil } -func zpersistCommand(c *client) error { - args := c.args +func zpersistCommand(req *requestContext) error { + args := req.args if len(args) != 1 { return ErrCmdParams } - if n, err := c.db.ZPersist(args[0]); err != nil { + if n, err := req.db.ZPersist(args[0]); err != nil { return err } else { - c.resp.writeInteger(n) + req.resp.writeInteger(n) } return nil diff --git a/server/command.go b/server/command.go index 23ca7bd..440a177 100644 --- a/server/command.go +++ b/server/command.go @@ -8,7 +8,7 @@ import ( "strings" ) -type CommandFunc func(c *client) error +type CommandFunc func(req *requestContext) error var regCmds = map[string]CommandFunc{} @@ -20,33 +20,33 @@ func register(name string, f CommandFunc) { regCmds[name] = f } -func pingCommand(c *client) error { - c.resp.writeStatus(PONG) +func pingCommand(req *requestContext) error { + req.resp.writeStatus(PONG) return nil } -func echoCommand(c *client) error { - if len(c.args) != 1 { +func echoCommand(req *requestContext) error { + if len(req.args) != 1 { return ErrCmdParams } - c.resp.writeBulk(c.args[0]) + req.resp.writeBulk(req.args[0]) return nil } -func selectCommand(c *client) error { - if len(c.args) != 1 { +func selectCommand(req *requestContext) error { + if len(req.args) != 1 { return ErrCmdParams } - if index, err := strconv.Atoi(ledis.String(c.args[0])); err != nil { + if index, err := strconv.Atoi(ledis.String(req.args[0])); err != nil { return err } else { - if db, err := c.ldb.Select(index); err != nil { + if db, err := req.ldb.Select(index); err != nil { return err } else { - c.db = db - c.resp.writeStatus(OK) + req.db = db + req.resp.writeStatus(OK) } } return nil diff --git a/server/request.go b/server/request.go new file mode 100644 index 0000000..5d6dacf --- /dev/null +++ b/server/request.go @@ -0,0 +1,200 @@ +package server + +import ( + "bytes" + "github.com/siddontang/go-log/log" + "github.com/siddontang/ledisdb/ledis" + "io" + "runtime" + "sync" + "time" +) + +type responseWriter interface { + writeError(error) + writeStatus(string) + writeInteger(int64) + writeBulk([]byte) + writeArray([]interface{}) + writeSliceArray([][]byte) + writeFVPairArray([]ledis.FVPair) + writeScorePairArray([]ledis.ScorePair, bool) + writeBulkFrom(int64, io.Reader) + flush() +} + +type requestContext struct { + app *App + ldb *ledis.Ledis + db *ledis.DB + + remoteAddr string + cmd string + args [][]byte + + resp responseWriter + + syncBuf bytes.Buffer + compressBuf []byte + + finish chan interface{} +} + +type requestHandler struct { + app *App + + async bool + quit chan struct{} + jobs *sync.WaitGroup + + reqs chan *requestContext + reqErr chan error + + buf bytes.Buffer +} + +func newRequestContext(app *App) *requestContext { + req := new(requestContext) + + req.app = app + req.ldb = app.ldb + req.db, _ = app.ldb.Select(0) //use default db + + req.compressBuf = make([]byte, 256) + req.finish = make(chan interface{}, 1) + + return req +} + +func newRequestHandler(app *App) *requestHandler { + hdl := new(requestHandler) + + hdl.app = app + + hdl.async = false + hdl.jobs = new(sync.WaitGroup) + hdl.quit = make(chan struct{}) + + hdl.reqs = make(chan *requestContext) + hdl.reqErr = make(chan error) + + return hdl +} + +func (h *requestHandler) asyncRun() { + if !h.async { + // todo ... not safe + h.async = true + go h.run() + } +} + +func (h *requestHandler) close() { + if h.async { + close(h.quit) + h.jobs.Wait() + } +} + +func (h *requestHandler) run() { + defer func() { + if e := recover(); e != nil { + buf := make([]byte, 4096) + n := runtime.Stack(buf, false) + buf = buf[0:n] + + log.Fatal("request handler run panic %s:%v", buf, e) + } + }() + + h.jobs.Add(1) + + var req *requestContext + for !h.async { + select { + case req = <-h.reqs: + if req != nil { + h.performance(req) + } + case <-h.quit: + h.async = true + break + } + } + + h.jobs.Done() + return +} + +func (h *requestHandler) postRequest(req *requestContext) { + if h.async { + h.reqs <- req + } else { + h.performance(req) + } + + <-req.finish +} + +func (h *requestHandler) performance(req *requestContext) { + var err error + + start := time.Now() + + if len(req.cmd) == 0 { + err = ErrEmptyCommand + } else if exeCmd, ok := regCmds[req.cmd]; !ok { + err = ErrNotFound + } else { + go func() { + h.reqErr <- exeCmd(req) + }() + + err = <-h.reqErr + } + + duration := time.Since(start) + + if h.app.access != nil { + fullCmd := h.catGenericCommand(req) + cost := duration.Nanoseconds() / 1000000 + + h.app.access.Log(req.remoteAddr, cost, fullCmd[:256], err) + } + + if err != nil { + req.resp.writeError(err) + } + req.resp.flush() + + req.finish <- nil + return +} + +// func (h *requestHandler) catFullCommand(req *requestContext) []byte { +// +// // if strings.HasSuffix(cmd, "expire") { +// // catExpireCommand(c, buffer) +// // } else { +// // catGenericCommand(c, buffer) +// // } +// +// return h.catGenericCommand(req) +// } + +func (h *requestHandler) catGenericCommand(req *requestContext) []byte { + buffer := h.buf + buffer.Reset() + + buffer.Write([]byte(req.cmd)) + + nargs := len(req.args) + for i, arg := range req.args { + buffer.Write(arg) + if i != nargs-1 { + buffer.WriteByte(' ') + } + } + + return buffer.Bytes() +}