diff --git a/controller/aof.go b/controller/aof.go index 084a35d6..6cd25ef5 100644 --- a/controller/aof.go +++ b/controller/aof.go @@ -1,7 +1,6 @@ package controller import ( - "bufio" "errors" "fmt" "io" @@ -13,6 +12,7 @@ import ( "time" "github.com/tidwall/buntdb" + "github.com/tidwall/redcon" "github.com/tidwall/resp" "github.com/tidwall/tile38/controller/log" "github.com/tidwall/tile38/controller/server" @@ -53,86 +53,57 @@ func (c *Controller) loadAOF() error { log.Infof("AOF loaded %d commands: %.2fs, %.0f/s, %s", count, float64(d)/float64(time.Second), ps, byteSpeed) }() + var buf []byte + var args [][]byte + var packet [0xFFFF]byte var msg server.Message - rd := bufio.NewReader(c.aof) for { - var nn int - ch, err := rd.ReadByte() + n, err := c.aof.Read(packet[:]) if err != nil { if err == io.EOF { + if len(buf) > 0 { + return io.ErrUnexpectedEOF + } return nil } return err } - nn += 1 - if ch != '*' { - return errInvalidAOF + c.aofsz += n + data := packet[:n] + if len(buf) > 0 { + data = append(buf, data...) } - ns, err := rd.ReadString('\n') - if err != nil { - return err - } - nn += len(ns) - if len(ns) < 2 || ns[len(ns)-2] != '\r' { - return errInvalidAOF - } - n, err := strconv.ParseUint(ns[:len(ns)-2], 10, 64) - if err != nil { - return err - } - if int(n) == 0 { - continue - } - msg.Values = msg.Values[:0] - for i := 0; i < int(n); i++ { - ch, err := rd.ReadByte() + var complete bool + for { + complete, args, _, data, err = redcon.ReadNextCommand(data, args[:0]) if err != nil { return err } - if ch != '$' { - return errInvalidAOF + if !complete { + break } - ns, err := rd.ReadString('\n') - if err != nil { - return err - } - if len(ns) < 2 || ns[len(ns)-2] != '\r' { - return errInvalidAOF - } - n, err := strconv.ParseUint(ns[:len(ns)-2], 10, 64) - if err != nil { - return err - } - b := make([]byte, int(n)) - _, err = io.ReadFull(rd, b) - if err != nil { - return err - } - if ch, err := rd.ReadByte(); err != nil { - return err - } else if ch != '\r' { - return errInvalidAOF - } - if ch, err := rd.ReadByte(); err != nil { - return err - } else if ch != '\n' { - return errInvalidAOF - } - msg.Values = append(msg.Values, resp.BytesValue(b)) - if i == 0 { - msg.Command = qlower(b) - } - nn += 1 + len(ns) + int(n) + 2 - } - if _, _, err := c.command(&msg, nil, nil); err != nil { - if commandErrIsFatal(err) { - return err + if len(args) > 0 { + msg.Values = msg.Values[:0] + for _, arg := range args { + msg.Values = append(msg.Values, resp.BytesValue(arg)) + } + msg.Command = qlower(args[0]) + if _, _, err := c.command(&msg, nil, nil); err != nil { + if commandErrIsFatal(err) { + return err + } + } + count++ } } - c.aofsz += nn - count++ + if len(data) > 0 { + buf = append(buf[:0], data...) + } else if len(buf) > 0 { + buf = buf[:0] + } } } + func qlower(s []byte) string { if len(s) == 3 { if s[0] == 'S' && s[1] == 'E' && s[2] == 'T' {