ledisdb/cmd/vendor/github.com/cupcake/rdb/decoder.go

782 lines
18 KiB
Go

// Package rdb implements parsing and encoding of the Redis RDB file format.
package rdb
import (
"bufio"
"bytes"
"encoding/binary"
"fmt"
"io"
"math"
"strconv"
"github.com/cupcake/rdb/crc64"
)
// A Decoder must be implemented to parse a RDB file.
type Decoder interface {
// StartRDB is called when parsing of a valid RDB file starts.
StartRDB()
// StartDatabase is called when database n starts.
// Once a database starts, another database will not start until EndDatabase is called.
StartDatabase(n int)
// Set is called once for each string key.
Set(key, value []byte, expiry int64)
// StartHash is called at the beginning of a hash.
// Hset will be called exactly length times before EndHash.
StartHash(key []byte, length, expiry int64)
// Hset is called once for each field=value pair in a hash.
Hset(key, field, value []byte)
// EndHash is called when there are no more fields in a hash.
EndHash(key []byte)
// StartSet is called at the beginning of a set.
// Sadd will be called exactly cardinality times before EndSet.
StartSet(key []byte, cardinality, expiry int64)
// Sadd is called once for each member of a set.
Sadd(key, member []byte)
// EndSet is called when there are no more fields in a set.
EndSet(key []byte)
// StartList is called at the beginning of a list.
// Rpush will be called exactly length times before EndList.
StartList(key []byte, length, expiry int64)
// Rpush is called once for each value in a list.
Rpush(key, value []byte)
// EndList is called when there are no more values in a list.
EndList(key []byte)
// StartZSet is called at the beginning of a sorted set.
// Zadd will be called exactly cardinality times before EndZSet.
StartZSet(key []byte, cardinality, expiry int64)
// Zadd is called once for each member of a sorted set.
Zadd(key []byte, score float64, member []byte)
// EndZSet is called when there are no more members in a sorted set.
EndZSet(key []byte)
// EndDatabase is called at the end of a database.
EndDatabase(n int)
// EndRDB is called when parsing of the RDB file is complete.
EndRDB()
}
// Decode parses a RDB file from r and calls the decode hooks on d.
func Decode(r io.Reader, d Decoder) error {
decoder := &decode{d, make([]byte, 8), bufio.NewReader(r)}
return decoder.decode()
}
// Decode a byte slice from the Redis DUMP command. The dump does not contain the
// database, key or expiry, so they must be included in the function call (but
// can be zero values).
func DecodeDump(dump []byte, db int, key []byte, expiry int64, d Decoder) error {
err := verifyDump(dump)
if err != nil {
return err
}
decoder := &decode{d, make([]byte, 8), bytes.NewReader(dump[1:])}
decoder.event.StartRDB()
decoder.event.StartDatabase(db)
err = decoder.readObject(key, ValueType(dump[0]), expiry)
decoder.event.EndDatabase(db)
decoder.event.EndRDB()
return err
}
type byteReader interface {
io.Reader
io.ByteReader
}
type decode struct {
event Decoder
intBuf []byte
r byteReader
}
type ValueType byte
const (
TypeString ValueType = 0
TypeList ValueType = 1
TypeSet ValueType = 2
TypeZSet ValueType = 3
TypeHash ValueType = 4
TypeHashZipmap ValueType = 9
TypeListZiplist ValueType = 10
TypeSetIntset ValueType = 11
TypeZSetZiplist ValueType = 12
TypeHashZiplist ValueType = 13
)
const (
rdb6bitLen = 0
rdb14bitLen = 1
rdb32bitLen = 2
rdbEncVal = 3
rdbFlagExpiryMS = 0xfc
rdbFlagExpiry = 0xfd
rdbFlagSelectDB = 0xfe
rdbFlagEOF = 0xff
rdbEncInt8 = 0
rdbEncInt16 = 1
rdbEncInt32 = 2
rdbEncLZF = 3
rdbZiplist6bitlenString = 0
rdbZiplist14bitlenString = 1
rdbZiplist32bitlenString = 2
rdbZiplistInt16 = 0xc0
rdbZiplistInt32 = 0xd0
rdbZiplistInt64 = 0xe0
rdbZiplistInt24 = 0xf0
rdbZiplistInt8 = 0xfe
rdbZiplistInt4 = 15
)
func (d *decode) decode() error {
err := d.checkHeader()
if err != nil {
return err
}
d.event.StartRDB()
var db uint32
var expiry int64
firstDB := true
for {
objType, err := d.r.ReadByte()
if err != nil {
return err
}
switch objType {
case rdbFlagExpiryMS:
_, err := io.ReadFull(d.r, d.intBuf)
if err != nil {
return err
}
expiry = int64(binary.LittleEndian.Uint64(d.intBuf))
case rdbFlagExpiry:
_, err := io.ReadFull(d.r, d.intBuf[:4])
if err != nil {
return err
}
expiry = int64(binary.LittleEndian.Uint32(d.intBuf)) * 1000
case rdbFlagSelectDB:
if !firstDB {
d.event.EndDatabase(int(db))
}
db, _, err = d.readLength()
if err != nil {
return err
}
d.event.StartDatabase(int(db))
case rdbFlagEOF:
d.event.EndDatabase(int(db))
d.event.EndRDB()
return nil
default:
key, err := d.readString()
if err != nil {
return err
}
err = d.readObject(key, ValueType(objType), expiry)
if err != nil {
return err
}
}
}
panic("not reached")
}
func (d *decode) readObject(key []byte, typ ValueType, expiry int64) error {
switch typ {
case TypeString:
value, err := d.readString()
if err != nil {
return err
}
d.event.Set(key, value, expiry)
case TypeList:
length, _, err := d.readLength()
if err != nil {
return err
}
d.event.StartList(key, int64(length), expiry)
for i := uint32(0); i < length; i++ {
value, err := d.readString()
if err != nil {
return err
}
d.event.Rpush(key, value)
}
d.event.EndList(key)
case TypeSet:
cardinality, _, err := d.readLength()
if err != nil {
return err
}
d.event.StartSet(key, int64(cardinality), expiry)
for i := uint32(0); i < cardinality; i++ {
member, err := d.readString()
if err != nil {
return err
}
d.event.Sadd(key, member)
}
d.event.EndSet(key)
case TypeZSet:
cardinality, _, err := d.readLength()
if err != nil {
return err
}
d.event.StartZSet(key, int64(cardinality), expiry)
for i := uint32(0); i < cardinality; i++ {
member, err := d.readString()
if err != nil {
return err
}
score, err := d.readFloat64()
if err != nil {
return err
}
d.event.Zadd(key, score, member)
}
d.event.EndZSet(key)
case TypeHash:
length, _, err := d.readLength()
if err != nil {
return err
}
d.event.StartHash(key, int64(length), expiry)
for i := uint32(0); i < length; i++ {
field, err := d.readString()
if err != nil {
return err
}
value, err := d.readString()
if err != nil {
return err
}
d.event.Hset(key, field, value)
}
d.event.EndHash(key)
case TypeHashZipmap:
return d.readZipmap(key, expiry)
case TypeListZiplist:
return d.readZiplist(key, expiry)
case TypeSetIntset:
return d.readIntset(key, expiry)
case TypeZSetZiplist:
return d.readZiplistZset(key, expiry)
case TypeHashZiplist:
return d.readZiplistHash(key, expiry)
default:
return fmt.Errorf("rdb: unknown object type %d for key %s", typ, key)
}
return nil
}
func (d *decode) readZipmap(key []byte, expiry int64) error {
var length int
zipmap, err := d.readString()
if err != nil {
return err
}
buf := newSliceBuffer(zipmap)
lenByte, err := buf.ReadByte()
if err != nil {
return err
}
if lenByte >= 254 { // we need to count the items manually
length, err = countZipmapItems(buf)
length /= 2
if err != nil {
return err
}
} else {
length = int(lenByte)
}
d.event.StartHash(key, int64(length), expiry)
for i := 0; i < length; i++ {
field, err := readZipmapItem(buf, false)
if err != nil {
return err
}
value, err := readZipmapItem(buf, true)
if err != nil {
return err
}
d.event.Hset(key, field, value)
}
d.event.EndHash(key)
return nil
}
func readZipmapItem(buf *sliceBuffer, readFree bool) ([]byte, error) {
length, free, err := readZipmapItemLength(buf, readFree)
if err != nil {
return nil, err
}
if length == -1 {
return nil, nil
}
value, err := buf.Slice(length)
if err != nil {
return nil, err
}
_, err = buf.Seek(int64(free), 1)
return value, err
}
func countZipmapItems(buf *sliceBuffer) (int, error) {
n := 0
for {
strLen, free, err := readZipmapItemLength(buf, n%2 != 0)
if err != nil {
return 0, err
}
if strLen == -1 {
break
}
_, err = buf.Seek(int64(strLen)+int64(free), 1)
if err != nil {
return 0, err
}
n++
}
_, err := buf.Seek(0, 0)
return n, err
}
func readZipmapItemLength(buf *sliceBuffer, readFree bool) (int, int, error) {
b, err := buf.ReadByte()
if err != nil {
return 0, 0, err
}
switch b {
case 253:
s, err := buf.Slice(5)
if err != nil {
return 0, 0, err
}
return int(binary.BigEndian.Uint32(s)), int(s[4]), nil
case 254:
return 0, 0, fmt.Errorf("rdb: invalid zipmap item length")
case 255:
return -1, 0, nil
}
var free byte
if readFree {
free, err = buf.ReadByte()
}
return int(b), int(free), err
}
func (d *decode) readZiplist(key []byte, expiry int64) error {
ziplist, err := d.readString()
if err != nil {
return err
}
buf := newSliceBuffer(ziplist)
length, err := readZiplistLength(buf)
if err != nil {
return err
}
d.event.StartList(key, length, expiry)
for i := int64(0); i < length; i++ {
entry, err := readZiplistEntry(buf)
if err != nil {
return err
}
d.event.Rpush(key, entry)
}
d.event.EndList(key)
return nil
}
func (d *decode) readZiplistZset(key []byte, expiry int64) error {
ziplist, err := d.readString()
if err != nil {
return err
}
buf := newSliceBuffer(ziplist)
cardinality, err := readZiplistLength(buf)
if err != nil {
return err
}
cardinality /= 2
d.event.StartZSet(key, cardinality, expiry)
for i := int64(0); i < cardinality; i++ {
member, err := readZiplistEntry(buf)
if err != nil {
return err
}
scoreBytes, err := readZiplistEntry(buf)
if err != nil {
return err
}
score, err := strconv.ParseFloat(string(scoreBytes), 64)
if err != nil {
return err
}
d.event.Zadd(key, score, member)
}
d.event.EndZSet(key)
return nil
}
func (d *decode) readZiplistHash(key []byte, expiry int64) error {
ziplist, err := d.readString()
if err != nil {
return err
}
buf := newSliceBuffer(ziplist)
length, err := readZiplistLength(buf)
if err != nil {
return err
}
length /= 2
d.event.StartHash(key, length, expiry)
for i := int64(0); i < length; i++ {
field, err := readZiplistEntry(buf)
if err != nil {
return err
}
value, err := readZiplistEntry(buf)
if err != nil {
return err
}
d.event.Hset(key, field, value)
}
d.event.EndHash(key)
return nil
}
func readZiplistLength(buf *sliceBuffer) (int64, error) {
buf.Seek(8, 0) // skip the zlbytes and zltail
lenBytes, err := buf.Slice(2)
if err != nil {
return 0, err
}
return int64(binary.LittleEndian.Uint16(lenBytes)), nil
}
func readZiplistEntry(buf *sliceBuffer) ([]byte, error) {
prevLen, err := buf.ReadByte()
if err != nil {
return nil, err
}
if prevLen == 254 {
buf.Seek(4, 1) // skip the 4-byte prevlen
}
header, err := buf.ReadByte()
if err != nil {
return nil, err
}
switch {
case header>>6 == rdbZiplist6bitlenString:
return buf.Slice(int(header & 0x3f))
case header>>6 == rdbZiplist14bitlenString:
b, err := buf.ReadByte()
if err != nil {
return nil, err
}
return buf.Slice((int(header&0x3f) << 8) | int(b))
case header>>6 == rdbZiplist32bitlenString:
lenBytes, err := buf.Slice(4)
if err != nil {
return nil, err
}
return buf.Slice(int(binary.BigEndian.Uint32(lenBytes)))
case header == rdbZiplistInt16:
intBytes, err := buf.Slice(2)
if err != nil {
return nil, err
}
return []byte(strconv.FormatInt(int64(int16(binary.LittleEndian.Uint16(intBytes))), 10)), nil
case header == rdbZiplistInt32:
intBytes, err := buf.Slice(4)
if err != nil {
return nil, err
}
return []byte(strconv.FormatInt(int64(int32(binary.LittleEndian.Uint32(intBytes))), 10)), nil
case header == rdbZiplistInt64:
intBytes, err := buf.Slice(8)
if err != nil {
return nil, err
}
return []byte(strconv.FormatInt(int64(binary.LittleEndian.Uint64(intBytes)), 10)), nil
case header == rdbZiplistInt24:
intBytes := make([]byte, 4)
_, err := buf.Read(intBytes[1:])
if err != nil {
return nil, err
}
return []byte(strconv.FormatInt(int64(int32(binary.LittleEndian.Uint32(intBytes))>>8), 10)), nil
case header == rdbZiplistInt8:
b, err := buf.ReadByte()
return []byte(strconv.FormatInt(int64(int8(b)), 10)), err
case header>>4 == rdbZiplistInt4:
return []byte(strconv.FormatInt(int64(header&0x0f)-1, 10)), nil
}
return nil, fmt.Errorf("rdb: unknown ziplist header byte: %d", header)
}
func (d *decode) readIntset(key []byte, expiry int64) error {
intset, err := d.readString()
if err != nil {
return err
}
buf := newSliceBuffer(intset)
intSizeBytes, err := buf.Slice(4)
if err != nil {
return err
}
intSize := binary.LittleEndian.Uint32(intSizeBytes)
if intSize != 2 && intSize != 4 && intSize != 8 {
return fmt.Errorf("rdb: unknown intset encoding: %d", intSize)
}
lenBytes, err := buf.Slice(4)
if err != nil {
return err
}
cardinality := binary.LittleEndian.Uint32(lenBytes)
d.event.StartSet(key, int64(cardinality), expiry)
for i := uint32(0); i < cardinality; i++ {
intBytes, err := buf.Slice(int(intSize))
if err != nil {
return err
}
var intString string
switch intSize {
case 2:
intString = strconv.FormatInt(int64(int16(binary.LittleEndian.Uint16(intBytes))), 10)
case 4:
intString = strconv.FormatInt(int64(int32(binary.LittleEndian.Uint32(intBytes))), 10)
case 8:
intString = strconv.FormatInt(int64(int64(binary.LittleEndian.Uint64(intBytes))), 10)
}
d.event.Sadd(key, []byte(intString))
}
d.event.EndSet(key)
return nil
}
func (d *decode) checkHeader() error {
header := make([]byte, 9)
_, err := io.ReadFull(d.r, header)
if err != nil {
return err
}
if !bytes.Equal(header[:5], []byte("REDIS")) {
return fmt.Errorf("rdb: invalid file format")
}
version, _ := strconv.ParseInt(string(header[5:]), 10, 64)
if version < 1 || version > 6 {
return fmt.Errorf("rdb: invalid RDB version number %d", version)
}
return nil
}
func (d *decode) readString() ([]byte, error) {
length, encoded, err := d.readLength()
if err != nil {
return nil, err
}
if encoded {
switch length {
case rdbEncInt8:
i, err := d.readUint8()
return []byte(strconv.FormatInt(int64(int8(i)), 10)), err
case rdbEncInt16:
i, err := d.readUint16()
return []byte(strconv.FormatInt(int64(int16(i)), 10)), err
case rdbEncInt32:
i, err := d.readUint32()
return []byte(strconv.FormatInt(int64(int32(i)), 10)), err
case rdbEncLZF:
clen, _, err := d.readLength()
if err != nil {
return nil, err
}
ulen, _, err := d.readLength()
if err != nil {
return nil, err
}
compressed := make([]byte, clen)
_, err = io.ReadFull(d.r, compressed)
if err != nil {
return nil, err
}
decompressed := lzfDecompress(compressed, int(ulen))
if len(decompressed) != int(ulen) {
return nil, fmt.Errorf("decompressed string length %d didn't match expected length %d", len(decompressed), ulen)
}
return decompressed, nil
}
}
str := make([]byte, length)
_, err = io.ReadFull(d.r, str)
return str, err
}
func (d *decode) readUint8() (uint8, error) {
b, err := d.r.ReadByte()
return uint8(b), err
}
func (d *decode) readUint16() (uint16, error) {
_, err := io.ReadFull(d.r, d.intBuf[:2])
if err != nil {
return 0, err
}
return binary.LittleEndian.Uint16(d.intBuf), nil
}
func (d *decode) readUint32() (uint32, error) {
_, err := io.ReadFull(d.r, d.intBuf[:4])
if err != nil {
return 0, err
}
return binary.LittleEndian.Uint32(d.intBuf), nil
}
func (d *decode) readUint64() (uint64, error) {
_, err := io.ReadFull(d.r, d.intBuf)
if err != nil {
return 0, err
}
return binary.LittleEndian.Uint64(d.intBuf), nil
}
func (d *decode) readUint32Big() (uint32, error) {
_, err := io.ReadFull(d.r, d.intBuf[:4])
if err != nil {
return 0, err
}
return binary.BigEndian.Uint32(d.intBuf), nil
}
// Doubles are saved as strings prefixed by an unsigned
// 8 bit integer specifying the length of the representation.
// This 8 bit integer has special values in order to specify the following
// conditions:
// 253: not a number
// 254: + inf
// 255: - inf
func (d *decode) readFloat64() (float64, error) {
length, err := d.readUint8()
if err != nil {
return 0, err
}
switch length {
case 253:
return math.NaN(), nil
case 254:
return math.Inf(0), nil
case 255:
return math.Inf(-1), nil
default:
floatBytes := make([]byte, length)
_, err := io.ReadFull(d.r, floatBytes)
if err != nil {
return 0, err
}
f, err := strconv.ParseFloat(string(floatBytes), 64)
return f, err
}
panic("not reached")
}
func (d *decode) readLength() (uint32, bool, error) {
b, err := d.r.ReadByte()
if err != nil {
return 0, false, err
}
// The first two bits of the first byte are used to indicate the length encoding type
switch (b & 0xc0) >> 6 {
case rdb6bitLen:
// When the first two bits are 00, the next 6 bits are the length.
return uint32(b & 0x3f), false, nil
case rdb14bitLen:
// When the first two bits are 01, the next 14 bits are the length.
bb, err := d.r.ReadByte()
if err != nil {
return 0, false, err
}
return (uint32(b&0x3f) << 8) | uint32(bb), false, nil
case rdbEncVal:
// When the first two bits are 11, the next object is encoded.
// The next 6 bits indicate the encoding type.
return uint32(b & 0x3f), true, nil
default:
// When the first two bits are 10, the next 6 bits are discarded.
// The next 4 bytes are the length.
length, err := d.readUint32Big()
return length, false, err
}
panic("not reached")
}
func verifyDump(d []byte) error {
if len(d) < 10 {
return fmt.Errorf("rdb: invalid dump length")
}
version := binary.LittleEndian.Uint16(d[len(d)-10:])
if version != uint16(Version) {
return fmt.Errorf("rdb: invalid version %d, expecting %d", version, Version)
}
if binary.LittleEndian.Uint64(d[len(d)-8:]) != crc64.Digest(d[:len(d)-8]) {
return fmt.Errorf("rdb: invalid CRC checksum")
}
return nil
}
func lzfDecompress(in []byte, outlen int) []byte {
out := make([]byte, outlen)
for i, o := 0, 0; i < len(in); {
ctrl := int(in[i])
i++
if ctrl < 32 {
for x := 0; x <= ctrl; x++ {
out[o] = in[i]
i++
o++
}
} else {
length := ctrl >> 5
if length == 7 {
length = length + int(in[i])
i++
}
ref := o - ((ctrl & 0x1f) << 8) - int(in[i]) - 1
i++
for x := 0; x <= length+1; x++ {
out[o] = out[ref]
ref++
o++
}
}
}
return out
}