diff --git a/server/client_resp.go b/server/client_resp.go index 8d8378f..2713609 100644 --- a/server/client_resp.go +++ b/server/client_resp.go @@ -37,6 +37,7 @@ func newClientRESP(conn net.Conn, app *App) { tcpConn.SetReadBuffer(app.cfg.ConnReadBufferSize) tcpConn.SetWriteBuffer(app.cfg.ConnWriteBufferSize) } + c.rb = bufio.NewReaderSize(conn, app.cfg.ConnReadBufferSize) c.resp = newWriterRESP(conn, app.cfg.ConnWriteBufferSize) @@ -85,57 +86,93 @@ func (c *respClient) readLine() ([]byte, error) { } //A client sends to the Redis server a RESP Array consisting of just Bulk Strings. +// func (c *respClient) readRequest() ([][]byte, error) { +// l, err := c.readLine() +// if err != nil { +// return nil, err +// } else if len(l) == 0 || l[0] != '*' { +// return nil, errReadRequest +// } + +// var nparams int +// if nparams, err = strconv.Atoi(hack.String(l[1:])); err != nil { +// return nil, err +// } else if nparams <= 0 { +// return nil, errReadRequest +// } + +// req := make([][]byte, 0, nparams) +// var n int +// for i := 0; i < nparams; i++ { +// if l, err = c.readLine(); err != nil { +// return nil, err +// } + +// if len(l) == 0 { +// return nil, errReadRequest +// } else if l[0] == '$' { +// //handle resp string +// if n, err = strconv.Atoi(hack.String(l[1:])); err != nil { +// return nil, err +// } else if n == -1 { +// req = append(req, nil) +// } else { +// buf := make([]byte, n+2) +// if _, err = io.ReadFull(c.rb, buf); err != nil { +// return nil, err +// } + +// if buf[len(buf)-2] != '\r' && buf[len(buf)-1] != '\n' { +// return nil, errors.New("bad bulk string format") +// } + +// // if l, err = c.readLine(); err != nil { +// // return nil, err +// // } else if len(l) != 0 { +// // return nil, errors.New("bad bulk string format") +// // } + +// req = append(req, buf[0:len(buf)-2]) + +// } + +// } else { +// return nil, errReadRequest +// } +// } + +// return req, nil +// } + func (c *respClient) readRequest() ([][]byte, error) { - l, err := c.readLine() + code, err := c.rb.ReadByte() if err != nil { return nil, err - } else if len(l) == 0 || l[0] != '*' { + } + + if code != '*' { return nil, errReadRequest } - var nparams int - if nparams, err = strconv.Atoi(hack.String(l[1:])); err != nil { + var nparams int64 + if nparams, err = readLong(c.rb); err != nil { return nil, err } else if nparams <= 0 { return nil, errReadRequest } - req := make([][]byte, 0, nparams) - var n int - for i := 0; i < nparams; i++ { - if l, err = c.readLine(); err != nil { + req := make([][]byte, nparams) + for i := range req { + if code, err = c.rb.ReadByte(); err != nil { + return nil, err + } else if code != '$' { + return nil, errReadRequest + } + + if req[i], err = readBytes(c.rb); err != nil { return nil, err } - - if len(l) == 0 { - return nil, errReadRequest - } else if l[0] == '$' { - //handle resp string - if n, err = strconv.Atoi(hack.String(l[1:])); err != nil { - return nil, err - } else if n == -1 { - req = append(req, nil) - } else { - buf := make([]byte, n) - if _, err = io.ReadFull(c.rb, buf); err != nil { - return nil, err - } - - if l, err = c.readLine(); err != nil { - return nil, err - } else if len(l) != 0 { - return nil, errors.New("bad bulk string format") - } - - req = append(req, buf) - - } - - } else { - return nil, errReadRequest - } } - return req, nil } diff --git a/server/util.go b/server/util.go index 44b289c..504b051 100644 --- a/server/util.go +++ b/server/util.go @@ -3,6 +3,8 @@ package server import ( "bufio" "errors" + "fmt" + "io" ) var ( @@ -19,5 +21,71 @@ func ReadLine(rb *bufio.Reader) ([]byte, error) { if i < 0 || p[i] != '\r' { return nil, errLineFormat } + return p[:i], nil } + +func readBytes(br *bufio.Reader) (bytes []byte, err error) { + size, err := readLong(br) + if err != nil { + return nil, err + } + if size == -1 { + return nil, nil + } + if size < 0 { + return nil, errors.New("Invalid size: " + fmt.Sprint("%d", size)) + } + + buf := make([]byte, size+2) + if _, err = io.ReadFull(br, buf); err != nil { + return nil, err + } + + if buf[len(buf)-2] != '\r' && buf[len(buf)-1] != '\n' { + return nil, errors.New("bad bulk string format") + } + + bytes = buf[0 : len(buf)-2] + + return +} + +func readLong(in *bufio.Reader) (result int64, err error) { + read, err := in.ReadByte() + if err != nil { + return -1, err + } + var sign int + if read == '-' { + read, err = in.ReadByte() + if err != nil { + return -1, err + } + sign = -1 + } else { + sign = 1 + } + var number int64 + for number = 0; err == nil; read, err = in.ReadByte() { + if read == '\r' { + read, err = in.ReadByte() + if err != nil { + return -1, err + } + if read == '\n' { + return number * int64(sign), nil + } else { + return -1, errors.New("Bad line ending") + } + } + value := read - '0' + if value >= 0 && value < 10 { + number *= 10 + number += int64(value) + } else { + return -1, errors.New("Invalid digit") + } + } + return -1, err +}