diff --git a/cmd/ledis-server/main.go b/cmd/ledis-server/main.go index 8a76984..8513559 100644 --- a/cmd/ledis-server/main.go +++ b/cmd/ledis-server/main.go @@ -9,7 +9,7 @@ import ( "syscall" ) -var configFile = flag.String("config", "", "ledisdb config file") +var configFile = flag.String("config", "/etc/ledis.json", "ledisdb config file") func main() { runtime.GOMAXPROCS(runtime.NumCPU()) diff --git a/etc/ledis.json b/etc/ledis.json index 92d9f93..8f230c3 100644 --- a/etc/ledis.json +++ b/etc/ledis.json @@ -1,12 +1,14 @@ { "addr": "127.0.0.1:6380", + "data_dir": "/tmp/ledis_server", "db": { "data_db" : { - "path": "/tmp/ledisdb", "compression": false, "block_size": 32768, "write_buffer_size": 67108864, "cache_size": 524288000 } - } + }, + + "access_log" : "access.log" } \ No newline at end of file diff --git a/ledis/binlog.go b/ledis/binlog.go new file mode 100644 index 0000000..d6e99f0 --- /dev/null +++ b/ledis/binlog.go @@ -0,0 +1,329 @@ +package ledis + +import ( + "bufio" + "encoding/binary" + "encoding/json" + "fmt" + "github.com/siddontang/go-log/log" + "io/ioutil" + "os" + "path" + "strconv" + "strings" + "time" +) + +const ( + MaxBinLogFileSize int = 1024 * 1024 * 1024 + MaxBinLogFileNum int = 10000 + + DefaultBinLogFileSize int = MaxBinLogFileSize + DefaultBinLogFileNum int = 10 +) + +/* +index file format: +ledis-bin.00001 +ledis-bin.00002 +ledis-bin.00003 + +log file format + +timestamp(bigendian uint32, seconds)|PayloadLen(bigendian uint32)|PayloadData + +*/ + +type BinLogConfig struct { + Path string `json:"path"` + MaxFileSize int `json:"max_file_size"` + MaxFileNum int `json:"max_file_num"` +} + +func (cfg *BinLogConfig) adjust() { + if cfg.MaxFileSize <= 0 { + cfg.MaxFileSize = DefaultBinLogFileSize + } else if cfg.MaxFileSize > MaxBinLogFileSize { + cfg.MaxFileSize = MaxBinLogFileSize + } + + if cfg.MaxFileNum <= 0 { + cfg.MaxFileNum = DefaultBinLogFileNum + } else if cfg.MaxFileNum > MaxBinLogFileNum { + cfg.MaxFileNum = MaxBinLogFileNum + } +} + +type BinLog struct { + cfg *BinLogConfig + + logFile *os.File + + logWb *bufio.Writer + + indexName string + logNames []string + lastLogIndex int64 +} + +func NewBinLog(data json.RawMessage) (*BinLog, error) { + var cfg BinLogConfig + + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, err + } + + return NewBinLogWithConfig(&cfg) +} + +func NewBinLogWithConfig(cfg *BinLogConfig) (*BinLog, error) { + cfg.adjust() + + l := new(BinLog) + + l.cfg = cfg + + if err := os.MkdirAll(cfg.Path, os.ModePerm); err != nil { + return nil, err + } + + l.logNames = make([]string, 0, 16) + + if err := l.loadIndex(); err != nil { + return nil, err + } + + return l, nil +} + +func (l *BinLog) flushIndex() error { + data := strings.Join(l.logNames, "\n") + + bakName := fmt.Sprintf("%s.bak", l.indexName) + f, err := os.OpenFile(bakName, os.O_WRONLY|os.O_CREATE, 0666) + if err != nil { + log.Error("create binlog bak index error %s", err.Error()) + return err + } + + if _, err := f.WriteString(data); err != nil { + log.Error("write binlog index error %s", err.Error()) + f.Close() + return err + } + + f.Close() + + if err := os.Rename(bakName, l.indexName); err != nil { + log.Error("rename binlog bak index error %s", err.Error()) + return err + } + + return nil +} + +func (l *BinLog) loadIndex() error { + l.indexName = path.Join(l.cfg.Path, fmt.Sprintf("ledis-bin.index")) + if _, err := os.Stat(l.indexName); os.IsNotExist(err) { + //no index file, nothing to do + } else { + indexData, err := ioutil.ReadFile(l.indexName) + if err != nil 
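// For reference, the on-disk layout this loader expects, assuming
// cfg.Path is "/tmp/ledis_binlog" (names follow FormatLogFileName's
// ledis-bin.%07d pattern; the short forms in the comment above are
// abbreviated):
//
//	/tmp/ledis_binlog/ledis-bin.index      one log file name per line
//	/tmp/ledis_binlog/ledis-bin.0000001
//	/tmp/ledis_binlog/ledis-bin.0000002
//
// loadIndex trims blank lines and verifies that every file named in the
// index still exists before trusting it.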
{ + return err + } + + lines := strings.Split(string(indexData), "\n") + for _, line := range lines { + line = strings.Trim(line, "\r\n ") + if len(line) == 0 { + continue + } + + if _, err := os.Stat(path.Join(l.cfg.Path, line)); err != nil { + log.Error("load index line %s error %s", line, err.Error()) + return err + } else { + l.logNames = append(l.logNames, line) + } + } + } + if l.cfg.MaxFileNum > 0 && len(l.logNames) > l.cfg.MaxFileNum { + //remove oldest logfile + if err := l.Purge(len(l.logNames) - l.cfg.MaxFileNum); err != nil { + return err + } + } + + var err error + if len(l.logNames) == 0 { + l.lastLogIndex = 1 + } else { + lastName := l.logNames[len(l.logNames)-1] + + if l.lastLogIndex, err = strconv.ParseInt(path.Ext(lastName)[1:], 10, 64); err != nil { + log.Error("invalid logfile name %s", err.Error()) + return err + } + + //like mysql, if server restart, a new binlog will create + l.lastLogIndex++ + } + + return nil +} + +func (l *BinLog) getLogFile() string { + return l.FormatLogFileName(l.lastLogIndex) +} + +func (l *BinLog) openNewLogFile() error { + var err error + lastName := l.getLogFile() + + logPath := path.Join(l.cfg.Path, lastName) + if l.logFile, err = os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY, 0666); err != nil { + log.Error("open new logfile error %s", err.Error()) + return err + } + + if l.cfg.MaxFileNum > 0 && len(l.logNames) == l.cfg.MaxFileNum { + l.purge(1) + } + + l.logNames = append(l.logNames, lastName) + + if l.logWb == nil { + l.logWb = bufio.NewWriterSize(l.logFile, 1024) + } else { + l.logWb.Reset(l.logFile) + } + + if err = l.flushIndex(); err != nil { + return err + } + + return nil +} + +func (l *BinLog) checkLogFileSize() bool { + if l.logFile == nil { + return false + } + + st, _ := l.logFile.Stat() + if st.Size() >= int64(l.cfg.MaxFileSize) { + l.lastLogIndex++ + + l.logFile.Close() + l.logFile = nil + return true + } + + return false +} + +func (l *BinLog) purge(n int) { + for i := 0; i < n; i++ { + logPath := path.Join(l.cfg.Path, l.logNames[i]) + os.Remove(logPath) + } + + copy(l.logNames[0:], l.logNames[n:]) + l.logNames = l.logNames[0 : len(l.logNames)-n] +} + +func (l *BinLog) Close() { + if l.logFile != nil { + l.logFile.Close() + l.logFile = nil + } +} + +func (l *BinLog) LogNames() []string { + return l.logNames +} + +func (l *BinLog) LogFileName() string { + return l.getLogFile() +} + +func (l *BinLog) LogFilePos() int64 { + if l.logFile == nil { + return 0 + } else { + st, _ := l.logFile.Stat() + return st.Size() + } +} + +func (l *BinLog) LogFileIndex() int64 { + return l.lastLogIndex +} + +func (l *BinLog) FormatLogFileName(index int64) string { + return fmt.Sprintf("ledis-bin.%07d", index) +} + +func (l *BinLog) FormatLogFilePath(index int64) string { + return path.Join(l.cfg.Path, l.FormatLogFileName(index)) +} + +func (l *BinLog) LogPath() string { + return l.cfg.Path +} + +func (l *BinLog) Purge(n int) error { + if len(l.logNames) == 0 { + return nil + } + + if n >= len(l.logNames) { + n = len(l.logNames) + //can not purge current log file + if l.logNames[n-1] == l.getLogFile() { + n = n - 1 + } + } + + l.purge(n) + + return l.flushIndex() +} + +func (l *BinLog) Log(args ...[]byte) error { + var err error + + if l.logFile == nil { + if err = l.openNewLogFile(); err != nil { + return err + } + } + + //we treat log many args as a batch, so use same createTime + createTime := uint32(time.Now().Unix()) + + for _, data := range args { + payLoadLen := uint32(len(data)) + + if err := binary.Write(l.logWb, binary.BigEndian, 
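// A minimal sketch of reading back one record in the format Log writes
// here: a 4-byte big-endian create time, a 4-byte big-endian payload
// length, then the payload. readRecord is a hypothetical helper (not part
// of this patch; "io" would need to be imported), shown only to make the
// framing concrete. ReplicateFromReader in replication.go does the same
// parsing on the slave side.
func readRecord(r io.Reader) (createTime uint32, payload []byte, err error) {
	if err = binary.Read(r, binary.BigEndian, &createTime); err != nil {
		return
	}
	var payloadLen uint32
	if err = binary.Read(r, binary.BigEndian, &payloadLen); err != nil {
		return
	}
	payload = make([]byte, payloadLen)
	_, err = io.ReadFull(r, payload)
	return
}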
createTime); err != nil { + return err + } + + if err := binary.Write(l.logWb, binary.BigEndian, payLoadLen); err != nil { + return err + } + + if _, err := l.logWb.Write(data); err != nil { + return err + } + } + + if err = l.logWb.Flush(); err != nil { + log.Error("write log error %s", err.Error()) + return err + } + + l.checkLogFileSize() + + return nil +} diff --git a/ledis/binlog_test.go b/ledis/binlog_test.go new file mode 100644 index 0000000..7fc89b4 --- /dev/null +++ b/ledis/binlog_test.go @@ -0,0 +1,36 @@ +package ledis + +import ( + "io/ioutil" + "os" + "testing" +) + +func TestBinLog(t *testing.T) { + cfg := new(BinLogConfig) + + cfg.MaxFileNum = 1 + cfg.MaxFileSize = 1024 + cfg.Path = "/tmp/ledis_binlog" + + os.RemoveAll(cfg.Path) + + b, err := NewBinLogWithConfig(cfg) + if err != nil { + t.Fatal(err) + } + + if err := b.Log(make([]byte, 1024)); err != nil { + t.Fatal(err) + } + + if err := b.Log(make([]byte, 1024)); err != nil { + t.Fatal(err) + } + + if fs, err := ioutil.ReadDir(cfg.Path); err != nil { + t.Fatal(err) + } else if len(fs) != 2 { + t.Fatal(len(fs)) + } +} diff --git a/ledis/binlog_util.go b/ledis/binlog_util.go new file mode 100644 index 0000000..bc1cd63 --- /dev/null +++ b/ledis/binlog_util.go @@ -0,0 +1,62 @@ +package ledis + +import ( + "encoding/binary" + "errors" +) + +var ( + errBinLogDeleteType = errors.New("invalid bin log delete type") + errBinLogPutType = errors.New("invalid bin log put type") + errBinLogCommandType = errors.New("invalid bin log command type") +) + +func encodeBinLogDelete(key []byte) []byte { + buf := make([]byte, 1+len(key)) + buf[0] = BinLogTypeDeletion + copy(buf[1:], key) + return buf +} + +func decodeBinLogDelete(sz []byte) ([]byte, error) { + if len(sz) < 1 || sz[0] != BinLogTypeDeletion { + return nil, errBinLogDeleteType + } + + return sz[1:], nil +} + +func encodeBinLogPut(key []byte, value []byte) []byte { + buf := make([]byte, 3+len(key)+len(value)) + buf[0] = BinLogTypePut + pos := 1 + binary.BigEndian.PutUint16(buf[pos:], uint16(len(key))) + pos += 2 + copy(buf[pos:], key) + pos += len(key) + copy(buf[pos:], value) + + return buf +} + +func decodeBinLogPut(sz []byte) ([]byte, []byte, error) { + if len(sz) < 3 || sz[0] != BinLogTypePut { + return nil, nil, errBinLogPutType + } + + keyLen := int(binary.BigEndian.Uint16(sz[1:])) + if 3+keyLen > len(sz) { + return nil, nil, errBinLogPutType + } + + return sz[3 : 3+keyLen], sz[3+keyLen:], nil +} + +func encodeBinLogCommand(commandType uint8, args ...[]byte) []byte { + //to do + return nil +} + +func decodeBinLogCommand(sz []byte) (uint8, [][]byte, error) { + return 0, nil, errBinLogCommandType +} diff --git a/ledis/const.go b/ledis/const.go index ba0cf05..e6d3b6d 100644 --- a/ledis/const.go +++ b/ledis/const.go @@ -13,24 +13,44 @@ const ( zsetType zSizeType zScoreType + + kvExpType + kvExpMetaType + lExpType + lExpMetaType + hExpType + hExpMetaType + zExpType + zExpMetaType ) const ( defaultScanCount int = 10 ) +var ( + errKeySize = errors.New("invalid key size") + errValueSize = errors.New("invalid value size") + errHashFieldSize = errors.New("invalid hash field size") + errZSetMemberSize = errors.New("invalid zset member size") + errExpireValue = errors.New("invalid expire value") +) + const ( //we don't support too many databases MaxDBNumber uint8 = 16 //max key size - MaxKeySize int = 1<<16 - 1 + MaxKeySize int = 1024 //max hash field size - MaxHashFieldSize int = 1<<16 - 1 + MaxHashFieldSize int = 1024 //max zset member size - MaxZSetMemberSize int = 1<<16 - 1 + 
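// A quick round-trip through the binlog_util.go helpers above, to make
// the put-event layout concrete: one type byte (BinLogTypePut), a 2-byte
// big-endian key length, the key, then the value. The 2-byte length field
// is also why MaxKeySize must stay below 1<<16. Hypothetical usage:
func exampleBinLogPutRoundTrip() error {
	ev := encodeBinLogPut([]byte("key"), []byte("value"))
	key, value, err := decodeBinLogPut(ev)
	if err != nil {
		return err
	}
	_, _ = key, value // "key", "value"
	return nil
}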
MaxZSetMemberSize int = 1024 + + //max value size + MaxValueSize int = 10 * 1024 * 1024 ) var ( @@ -39,3 +59,9 @@ var ( ErrZSetMemberSize = errors.New("invalid zset member size") ErrScoreMiss = errors.New("zset score miss") ) + +const ( + BinLogTypeDeletion uint8 = 0x0 + BinLogTypePut uint8 = 0x1 + BinLogTypeCommand uint8 = 0x2 +) diff --git a/ledis/dump.go b/ledis/dump.go new file mode 100644 index 0000000..47bca19 --- /dev/null +++ b/ledis/dump.go @@ -0,0 +1,167 @@ +package ledis + +import ( + "bufio" + "bytes" + "encoding/binary" + "github.com/siddontang/go-leveldb/leveldb" + "io" + "os" +) + +//dump format +// fileIndex(bigendian int64)|filePos(bigendian int64) +// |keylen(bigendian int32)|key|valuelen(bigendian int32)|value...... + +type MasterInfo struct { + LogFileIndex int64 + LogPos int64 +} + +func (m *MasterInfo) WriteTo(w io.Writer) error { + if err := binary.Write(w, binary.BigEndian, m.LogFileIndex); err != nil { + return err + } + + if err := binary.Write(w, binary.BigEndian, m.LogPos); err != nil { + return err + } + return nil +} + +func (m *MasterInfo) ReadFrom(r io.Reader) error { + err := binary.Read(r, binary.BigEndian, &m.LogFileIndex) + if err != nil { + return err + } + + err = binary.Read(r, binary.BigEndian, &m.LogPos) + if err != nil { + return err + } + + return nil +} + +func (l *Ledis) DumpFile(path string) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + return l.Dump(f) +} + +func (l *Ledis) Dump(w io.Writer) error { + var sp *leveldb.Snapshot + var m *MasterInfo = new(MasterInfo) + if l.binlog == nil { + sp = l.ldb.NewSnapshot() + } else { + l.Lock() + sp = l.ldb.NewSnapshot() + m.LogFileIndex = l.binlog.LogFileIndex() + m.LogPos = l.binlog.LogFilePos() + l.Unlock() + } + + var err error + + wb := bufio.NewWriterSize(w, 4096) + if err = m.WriteTo(wb); err != nil { + return err + } + + it := sp.Iterator(nil, nil, leveldb.RangeClose, 0, -1) + var key []byte + var value []byte + for ; it.Valid(); it.Next() { + key = it.Key() + value = it.Value() + + if err = binary.Write(wb, binary.BigEndian, uint16(len(key))); err != nil { + return err + } + + if _, err = wb.Write(key); err != nil { + return err + } + + if err = binary.Write(wb, binary.BigEndian, uint32(len(value))); err != nil { + return err + } + + if _, err = wb.Write(value); err != nil { + return err + } + } + + if err = wb.Flush(); err != nil { + return err + } + + return nil +} + +func (l *Ledis) LoadDumpFile(path string) (*MasterInfo, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + return l.LoadDump(f) +} + +func (l *Ledis) LoadDump(r io.Reader) (*MasterInfo, error) { + l.Lock() + defer l.Unlock() + + info := new(MasterInfo) + + rb := bufio.NewReaderSize(r, 4096) + + err := info.ReadFrom(rb) + if err != nil { + return nil, err + } + + var keyLen uint16 + var valueLen uint32 + + var keyBuf bytes.Buffer + var valueBuf bytes.Buffer + for { + if err = binary.Read(rb, binary.BigEndian, &keyLen); err != nil && err != io.EOF { + return nil, err + } else if err == io.EOF { + break + } + + if _, err = io.CopyN(&keyBuf, rb, int64(keyLen)); err != nil { + return nil, err + } + + if err = binary.Read(rb, binary.BigEndian, &valueLen); err != nil { + return nil, err + } + + if _, err = io.CopyN(&valueBuf, rb, int64(valueLen)); err != nil { + return nil, err + } + + if err = l.ldb.Put(keyBuf.Bytes(), valueBuf.Bytes()); err != nil { + return nil, err + } + + if l.binlog != nil { + err = 
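// A minimal usage sketch for the dump/load pair in this file, assuming
// master and slave were opened elsewhere with Open: the returned
// MasterInfo records the binlog file index and position at which the
// snapshot was taken, i.e. where the slave should resume replication.
func exampleDumpLoad(master *Ledis, slave *Ledis) (*MasterInfo, error) {
	if err := master.DumpFile("/tmp/testdb.dump"); err != nil {
		return nil, err
	}
	return slave.LoadDumpFile("/tmp/testdb.dump")
}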
l.binlog.Log(encodeBinLogPut(keyBuf.Bytes(), valueBuf.Bytes())) + } + + keyBuf.Reset() + valueBuf.Reset() + } + + return info, nil +} diff --git a/ledis/dump_test.go b/ledis/dump_test.go new file mode 100644 index 0000000..f15f8f4 --- /dev/null +++ b/ledis/dump_test.go @@ -0,0 +1,73 @@ +package ledis + +import ( + "bytes" + "github.com/siddontang/go-leveldb/leveldb" + "os" + "testing" +) + +func TestDump(t *testing.T) { + os.RemoveAll("/tmp/test_ledis_master") + os.RemoveAll("/tmp/test_ledis_slave") + + var masterConfig = []byte(` + { + "data_dir" : "/tmp/test_ledis_master", + "data_db" : { + "compression":true, + "block_size" : 32768, + "write_buffer_size" : 2097152, + "cache_size" : 20971520 + } + } + `) + + master, err := Open(masterConfig) + if err != nil { + t.Fatal(err) + } + + var slaveConfig = []byte(` + { + "data_dir" : "/tmp/test_ledis_slave", + "data_db" : { + "compression":true, + "block_size" : 32768, + "write_buffer_size" : 2097152, + "cache_size" : 20971520 + } + } + `) + + var slave *Ledis + if slave, err = Open(slaveConfig); err != nil { + t.Fatal(err) + } + + db, _ := master.Select(0) + + db.Set([]byte("a"), []byte("1")) + db.Set([]byte("b"), []byte("2")) + db.Set([]byte("c"), []byte("3")) + + if err := master.DumpFile("/tmp/testdb.dump"); err != nil { + t.Fatal(err) + } + + if _, err := slave.LoadDumpFile("/tmp/testdb.dump"); err != nil { + t.Fatal(err) + } + + it := master.ldb.Iterator(nil, nil, leveldb.RangeClose, 0, -1) + for ; it.Valid(); it.Next() { + key := it.Key() + value := it.Value() + + if v, err := slave.ldb.Get(key); err != nil { + t.Fatal(err) + } else if !bytes.Equal(v, value) { + t.Fatal("load dump error") + } + } +} diff --git a/ledis/ledis.go b/ledis/ledis.go index 330f24c..baf152b 100644 --- a/ledis/ledis.go +++ b/ledis/ledis.go @@ -4,13 +4,27 @@ import ( "encoding/json" "fmt" "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/go-log/log" + "path" + "sync" + "time" ) type Config struct { + DataDir string `json:"data_dir"` + + //if you not set leveldb path, use data_dir/data DataDB leveldb.Config `json:"data_db"` + + UseBinLog bool `json:"use_bin_log"` + + //if you not set bin log path, use data_dir/bin_log + BinLog BinLogConfig `json:"bin_log"` } type DB struct { + l *Ledis + db *leveldb.DB index uint8 @@ -22,10 +36,16 @@ type DB struct { } type Ledis struct { + sync.Mutex + cfg *Config ldb *leveldb.DB dbs [MaxDBNumber]*DB + + binlog *BinLog + + quit chan struct{} } func Open(configJson json.RawMessage) (*Ledis, error) { @@ -39,38 +59,72 @@ func Open(configJson json.RawMessage) (*Ledis, error) { } func OpenWithConfig(cfg *Config) (*Ledis, error) { + if len(cfg.DataDir) == 0 { + return nil, fmt.Errorf("must set correct data_dir") + } + + if len(cfg.DataDB.Path) == 0 { + cfg.DataDB.Path = path.Join(cfg.DataDir, "data") + } + ldb, err := leveldb.OpenWithConfig(&cfg.DataDB) if err != nil { return nil, err } l := new(Ledis) + + l.quit = make(chan struct{}) + l.ldb = ldb + if cfg.UseBinLog { + if len(cfg.BinLog.Path) == 0 { + cfg.BinLog.Path = path.Join(cfg.DataDir, "bin_log") + } + l.binlog, err = NewBinLogWithConfig(&cfg.BinLog) + if err != nil { + return nil, err + } + } else { + l.binlog = nil + } + for i := uint8(0); i < MaxDBNumber; i++ { l.dbs[i] = newDB(l, i) } + l.activeExpireCycle() + return l, nil } func newDB(l *Ledis, index uint8) *DB { d := new(DB) + d.l = l + d.db = l.ldb d.index = index - d.kvTx = &tx{wb: d.db.NewWriteBatch()} - d.listTx = &tx{wb: d.db.NewWriteBatch()} - d.hashTx = &tx{wb: d.db.NewWriteBatch()} - d.zsetTx = 
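// A sample configuration for the Config struct above, in the JSON form
// Open accepts; values are illustrative. With use_bin_log set and no
// explicit paths, the leveldb data lands in data_dir/data and the binlog
// in data_dir/bin_log:
//
//	{
//	    "data_dir" : "/tmp/ledis_server",
//	    "use_bin_log" : true,
//	    "bin_log" : {
//	        "max_file_size" : 1073741824,
//	        "max_file_num" : 10
//	    }
//	}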
&tx{wb: d.db.NewWriteBatch()}
+	d.kvTx = newTx(l)
+	d.listTx = newTx(l)
+	d.hashTx = newTx(l)
+	d.zsetTx = newTx(l)

 	return d
 }

 func (l *Ledis) Close() {
+	close(l.quit)
+
 	l.ldb.Close()
+
+	if l.binlog != nil {
+		l.binlog.Close()
+		l.binlog = nil
+	}
 }

 func (l *Ledis) Select(index int) (*DB, error) {
@@ -80,3 +134,41 @@ func (l *Ledis) Select(index int) (*DB, error) {
 	return l.dbs[index], nil
 }
+
+func (l *Ledis) FlushAll() error {
+	for index, db := range l.dbs {
+		if _, err := db.FlushAll(); err != nil {
+			log.Error("flush db %d error %s", index, err.Error())
+		}
+	}
+
+	return nil
+}
+
+//very dangerous to use
+func (l *Ledis) DataDB() *leveldb.DB {
+	return l.ldb
+}
+
+func (l *Ledis) activeExpireCycle() {
+	var executors []*elimination = make([]*elimination, len(l.dbs))
+	for i, db := range l.dbs {
+		executors[i] = db.newEliminator()
+	}
+
+	go func() {
+		tick := time.NewTicker(1 * time.Second)
+		defer tick.Stop()
+
+		for {
+			select {
+			case <-tick.C:
+				for _, eli := range executors {
+					eli.active()
+				}
+			case <-l.quit:
+				return
+			}
+		}
+	}()
+}
diff --git a/ledis/ledis_db.go b/ledis/ledis_db.go
index af583c1..0e76d17 100644
--- a/ledis/ledis_db.go
+++ b/ledis/ledis_db.go
@@ -1,11 +1,11 @@
 package ledis

-func (db *DB) Flush() (drop int64, err error) {
+func (db *DB) FlushAll() (drop int64, err error) {
 	all := [...](func() (int64, error)){
-		db.KvFlush,
-		db.LFlush,
-		db.HFlush,
-		db.ZFlush}
+		db.flush,
+		db.lFlush,
+		db.hFlush,
+		db.zFlush}

 	for _, flush := range all {
 		if n, e := flush(); e != nil {
@@ -15,5 +15,16 @@
 			drop += n
 		}
 	}
+
 	return
 }
+
+func (db *DB) newEliminator() *elimination {
+	eliminator := newEliminator(db)
+	eliminator.regRetireContext(kvExpType, db.kvTx, db.delete)
+	eliminator.regRetireContext(lExpType, db.listTx, db.lDelete)
+	eliminator.regRetireContext(hExpType, db.hashTx, db.hDelete)
+	eliminator.regRetireContext(zExpType, db.zsetTx, db.zDelete)
+
+	return eliminator
+}
diff --git a/ledis/ledis_test.go b/ledis/ledis_test.go
index 2512632..dc28b5c 100644
--- a/ledis/ledis_test.go
+++ b/ledis/ledis_test.go
@@ -1,6 +1,7 @@
 package ledis

 import (
+	"os"
 	"sync"
 	"testing"
 )
@@ -12,23 +13,29 @@ func getTestDB() *DB {
 	f := func() {
 		var d = []byte(`
 		{
+			"data_dir" : "/tmp/test_ledis",
 			"data_db" : {
-				"path" : "/tmp/testdb",
 				"compression":true,
 				"block_size" : 32768,
 				"write_buffer_size" : 2097152,
 				"cache_size" : 20971520
-			}
+			},
+
+			"bin_log" : {
+				"max_file_size" : 1073741824,
+				"max_file_num" : 3
+			}
 		}
 		`)
+
+		os.RemoveAll("/tmp/test_ledis")
+
 		var err error
 		testLedis, err = Open(d)
 		if err != nil {
 			println(err.Error())
 			panic(err)
 		}
-
-		testLedis.ldb.Clear()
 	}

 	testLedisOnce.Do(f)
@@ -75,7 +82,7 @@ func TestFlush(t *testing.T) {
 	db1.LPush([]byte("lst"), []byte("a1"), []byte("b2"))
 	db1.ZAdd([]byte("zset_0"), ScorePair{int64(3), []byte("mc")})

-	db1.Flush()
+	db1.FlushAll()

 	// 0 - existing
 	if exists, _ := db0.Exists([]byte("a")); exists <= 0 {
diff --git a/ledis/replication.go b/ledis/replication.go
new file mode 100644
index 0000000..e19da6a
--- /dev/null
+++ b/ledis/replication.go
@@ -0,0 +1,238 @@
+package ledis
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"github.com/siddontang/go-log/log"
+	"io"
+	"os"
+)
+
+var (
+	errInvalidBinLogEvent = errors.New("invalid binlog event")
+	errInvalidBinLogFile  = errors.New("invalid binlog file")
+)
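// A small sketch of applying events to a slave by hand; ReplicateEvent
// below dispatches on the leading type byte. exampleReplicate is
// hypothetical usage built on the binlog_util.go helpers, not part of
// this patch:
func exampleReplicate(slave *Ledis) error {
	ev := encodeBinLogPut([]byte("k"), []byte("v"))
	if err := slave.ReplicateEvent(ev); err != nil {
		return err
	}
	return slave.ReplicateEvent(encodeBinLogDelete([]byte("k")))
}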
+
+func (l *Ledis) ReplicateEvent(event []byte) error {
+	if len(event) == 0 {
+		return errInvalidBinLogEvent
+	}
+
+	logType := uint8(event[0])
+	switch logType {
+	case BinLogTypePut:
+		return l.replicatePutEvent(event)
+	case BinLogTypeDeletion:
+		return l.replicateDeleteEvent(event)
+	case BinLogTypeCommand:
+		return l.replicateCommandEvent(event)
+	default:
+		return errInvalidBinLogEvent
+	}
+}
+
+func (l *Ledis) replicatePutEvent(event []byte) error {
+	key, value, err := decodeBinLogPut(event)
+	if err != nil {
+		return err
+	}
+
+	if err = l.ldb.Put(key, value); err != nil {
+		return err
+	}
+
+	if l.binlog != nil {
+		err = l.binlog.Log(event)
+	}
+
+	return err
+}
+
+func (l *Ledis) replicateDeleteEvent(event []byte) error {
+	key, err := decodeBinLogDelete(event)
+	if err != nil {
+		return err
+	}
+
+	if err = l.ldb.Delete(key); err != nil {
+		return err
+	}
+
+	if l.binlog != nil {
+		err = l.binlog.Log(event)
+	}
+
+	return err
+}
+
+func (l *Ledis) replicateCommandEvent(event []byte) error {
+	return errors.New("command event not supported now")
+}
+
+func (l *Ledis) ReplicateFromReader(rb io.Reader) error {
+	var createTime uint32
+	var dataLen uint32
+	var dataBuf bytes.Buffer
+	var err error
+
+	for {
+		if err = binary.Read(rb, binary.BigEndian, &createTime); err != nil {
+			if err == io.EOF {
+				break
+			} else {
+				return err
+			}
+		}
+
+		if err = binary.Read(rb, binary.BigEndian, &dataLen); err != nil {
+			return err
+		}
+
+		if _, err = io.CopyN(&dataBuf, rb, int64(dataLen)); err != nil {
+			return err
+		}
+
+		err = l.ReplicateEvent(dataBuf.Bytes())
+		if err != nil {
+			log.Error("replication error %s, skip to next", err.Error())
+		}
+
+		dataBuf.Reset()
+	}
+
+	return nil
+}
+
+func (l *Ledis) ReplicateFromData(data []byte) error {
+	rb := bytes.NewReader(data)
+
+	l.Lock()
+	err := l.ReplicateFromReader(rb)
+	l.Unlock()
+
+	return err
+}
+
+func (l *Ledis) ReplicateFromBinLog(filePath string) error {
+	f, err := os.Open(filePath)
+	if err != nil {
+		return err
+	}
+
+	rb := bufio.NewReaderSize(f, 4096)
+
+	l.Lock()
+	err = l.ReplicateFromReader(rb)
+	l.Unlock()
+
+	f.Close()
+
+	return err
+}
+
+const maxSyncEvents = 64
+
+func (l *Ledis) ReadEventsTo(info *MasterInfo, w io.Writer) (n int, err error) {
+	n = 0
+	if l.binlog == nil {
+		//binlog not supported
+		info.LogFileIndex = 0
+		info.LogPos = 0
+		return
+	}
+
+	index := info.LogFileIndex
+	offset := info.LogPos
+
+	filePath := l.binlog.FormatLogFilePath(index)
+
+	var f *os.File
+	f, err = os.Open(filePath)
+	if os.IsNotExist(err) {
+		lastIndex := l.binlog.LogFileIndex()
+
+		if index == lastIndex {
+			//no binlog at all
+			info.LogPos = 0
+		} else {
+			//the slave's binlog info has been lost
+			info.LogFileIndex = -1
+		}
+	}
+
+	if err != nil {
+		if os.IsNotExist(err) {
+			err = nil
+		}
+		return
+	}
+
+	defer f.Close()
+
+	var fileSize int64
+	st, _ := f.Stat()
+	fileSize = st.Size()
+
+	if fileSize == info.LogPos {
+		return
+	}
+
+	if _, err = f.Seek(offset, os.SEEK_SET); err != nil {
+		//maybe an invalid seek offset
+		return
+	}
+
+	var lastCreateTime uint32 = 0
+	var createTime uint32
+	var dataLen uint32
+
+	var eventsNum int = 0
+
+	for {
+		if err = binary.Read(f, binary.BigEndian, &createTime); err != nil {
+			if err == io.EOF {
+				//we will try the next binlog file
+				if index < l.binlog.LogFileIndex() {
+					info.LogFileIndex += 1
+					info.LogPos = 0
+				}
+				err = nil
+				return
+			} else {
+				return
+			}
+		}

+		eventsNum++
+		if lastCreateTime == 0 {
+			lastCreateTime = createTime
+		} else if lastCreateTime != createTime {
+			return
+		} else if eventsNum > maxSyncEvents {
+			return
+		}
+
+		if err = binary.Read(f, binary.BigEndian, &dataLen); err != nil {
+			return
+		}
+
+		if err = binary.Write(w, binary.BigEndian, 
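// From here ReadEventsTo echoes each record to w in exactly the binlog
// wire format (createTime, length, payload), so the slave can feed the
// stream straight into ReplicateFromReader. Events that share one
// createTime are treated as a batch: the loop stops when createTime
// changes or after maxSyncEvents records, and info.LogPos advances by
// the 8 header bytes plus the payload length for every event copied.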
createTime); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, dataLen); err != nil { + return + } + + if _, err = io.CopyN(w, f, int64(dataLen)); err != nil { + return + } + + n += (8 + int(dataLen)) + info.LogPos = info.LogPos + 8 + int64(dataLen) + } + + return +} diff --git a/ledis/replication_test.go b/ledis/replication_test.go new file mode 100644 index 0000000..21d4dbc --- /dev/null +++ b/ledis/replication_test.go @@ -0,0 +1,117 @@ +package ledis + +import ( + "bytes" + "fmt" + "github.com/siddontang/go-leveldb/leveldb" + "os" + "path" + "testing" +) + +func checkLedisEqual(master *Ledis, slave *Ledis) error { + it := master.ldb.Iterator(nil, nil, leveldb.RangeClose, 0, -1) + for ; it.Valid(); it.Next() { + key := it.Key() + value := it.Value() + + if v, err := slave.ldb.Get(key); err != nil { + return err + } else if !bytes.Equal(v, value) { + return fmt.Errorf("replication error %d != %d", len(v), len(value)) + } + } + + return nil +} + +func TestReplication(t *testing.T) { + var master *Ledis + var slave *Ledis + var err error + + os.RemoveAll("/tmp/test_repl") + + master, err = Open([]byte(` + { + "data_dir" : "/tmp/test_repl/master", + "use_bin_log" : true, + "bin_log" : { + "max_file_size" : 50 + } + } + `)) + if err != nil { + t.Fatal(err) + } + + slave, err = Open([]byte(` + { + "data_dir" : "/tmp/test_repl/slave" + } + `)) + if err != nil { + t.Fatal(err) + } + + db, _ := master.Select(0) + db.Set([]byte("a"), []byte("value")) + db.Set([]byte("b"), []byte("value")) + db.Set([]byte("c"), []byte("value")) + + db.HSet([]byte("a"), []byte("1"), []byte("value")) + db.HSet([]byte("b"), []byte("2"), []byte("value")) + db.HSet([]byte("c"), []byte("3"), []byte("value")) + + for _, name := range master.binlog.LogNames() { + p := path.Join(master.binlog.cfg.Path, name) + + err = slave.ReplicateFromBinLog(p) + if err != nil { + t.Fatal(err) + } + } + + if err = checkLedisEqual(master, slave); err != nil { + t.Fatal(err) + } + + slave.FlushAll() + + db.Set([]byte("a1"), []byte("1")) + db.Set([]byte("b1"), []byte("2")) + db.Set([]byte("c1"), []byte("3")) + + db.HSet([]byte("a1"), []byte("1"), []byte("value")) + db.HSet([]byte("b1"), []byte("2"), []byte("value")) + db.HSet([]byte("c1"), []byte("3"), []byte("value")) + + info := new(MasterInfo) + info.LogFileIndex = 1 + info.LogPos = 0 + var buf bytes.Buffer + var n int + + for { + buf.Reset() + n, err = master.ReadEventsTo(info, &buf) + if err != nil { + t.Fatal(err) + } else if info.LogFileIndex == -1 { + t.Fatal("invalid log file index -1") + } else if info.LogFileIndex == 0 { + t.Fatal("invalid log file index 0") + } else { + if err = slave.ReplicateFromReader(&buf); err != nil { + t.Fatal(err) + } + if n == 0 { + break + } + } + } + + if err = checkLedisEqual(master, slave); err != nil { + t.Fatal(err) + } +} diff --git a/ledis/t_hash.go b/ledis/t_hash.go index 2d6550f..d64a1f6 100644 --- a/ledis/t_hash.go +++ b/ledis/t_hash.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "github.com/siddontang/go-leveldb/leveldb" + "time" ) type FVPair struct { @@ -21,9 +22,9 @@ const ( func checkHashKFSize(key []byte, field []byte) error { if len(key) > MaxKeySize || len(key) == 0 { - return ErrKeySize + return errKeySize } else if len(field) > MaxHashFieldSize || len(field) == 0 { - return ErrHashFieldSize + return errHashFieldSize } return nil } @@ -105,10 +106,6 @@ func (db *DB) hEncodeStopKey(key []byte) []byte { return k } -func (db *DB) HLen(key []byte) (int64, error) { - return 
Int64(db.db.Get(db.hEncodeSizeKey(key)))
-}
-
 func (db *DB) hSetItem(key []byte, field []byte, value []byte) (int64, error) {
 	t := db.hashTx
@@ -127,9 +124,54 @@ func (db *DB) hSetItem(key []byte, field []byte, value []byte) (int64, error) {
 	return n, nil
 }

+//hDelete removes only the hash data itself; any related
+//expire entries are left for the caller to clean up.
+func (db *DB) hDelete(t *tx, key []byte) int64 {
+	sk := db.hEncodeSizeKey(key)
+	start := db.hEncodeStartKey(key)
+	stop := db.hEncodeStopKey(key)
+
+	var num int64 = 0
+	it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1)
+	for ; it.Valid(); it.Next() {
+		t.Delete(it.Key())
+		num++
+	}
+	it.Close()
+
+	t.Delete(sk)
+	return num
+}
+
+func (db *DB) hExpireAt(key []byte, when int64) (int64, error) {
+	t := db.hashTx
+	t.Lock()
+	defer t.Unlock()
+
+	if hlen, err := db.HLen(key); err != nil || hlen == 0 {
+		return 0, err
+	} else {
+		db.expireAt(t, hExpType, key, when)
+		if err := t.Commit(); err != nil {
+			return 0, err
+		}
+	}
+	return 1, nil
+}
+
+func (db *DB) HLen(key []byte) (int64, error) {
+	if err := checkKeySize(key); err != nil {
+		return 0, err
+	}
+
+	return Int64(db.db.Get(db.hEncodeSizeKey(key)))
+}
+
 func (db *DB) HSet(key []byte, field []byte, value []byte) (int64, error) {
 	if err := checkHashKFSize(key, field); err != nil {
 		return 0, err
+	} else if err := checkValueSize(value); err != nil {
+		return 0, err
 	}

 	t := db.hashTx
@@ -166,6 +208,8 @@ func (db *DB) HMset(key []byte, args ...FVPair) error {
 	for i := 0; i < len(args); i++ {
 		if err := checkHashKFSize(key, args[i].Field); err != nil {
 			return err
+		} else if err := checkValueSize(args[i].Value); err != nil {
+			return err
 		}

 		ek = db.hEncodeHashKey(key, args[i].Field)
@@ -261,6 +305,7 @@ func (db *DB) hIncrSize(key []byte, delta int64) (int64, error) {
 		if size <= 0 {
 			size = 0
 			t.Delete(sk)
+			db.rmExpire(t, hExpType, key)
 		} else {
 			t.Put(sk, PutInt64(size))
 		}
@@ -374,30 +419,18 @@ func (db *DB) HClear(key []byte) (int64, error) {
 		return 0, err
 	}

-	sk := db.hEncodeSizeKey(key)
-	start := db.hEncodeStartKey(key)
-	stop := db.hEncodeStopKey(key)
-
 	t := db.hashTx
 	t.Lock()
 	defer t.Unlock()

-	var num int64 = 0
-	it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1)
-	for ; it.Valid(); it.Next() {
-		t.Delete(it.Key())
-		num++
-	}
-
-	it.Close()
-
-	t.Delete(sk)
+	num := db.hDelete(t, key)
+	db.rmExpire(t, hExpType, key)

 	err := t.Commit()
 	return num, err
 }

-func (db *DB) HFlush() (drop int64, err error) {
+func (db *DB) hFlush() (drop int64, err error) {
 	t := db.kvTx
 	t.Lock()
 	defer t.Unlock()
@@ -414,7 +447,15 @@
 	for ; it.Valid(); it.Next() {
 		t.Delete(it.Key())
 		drop++
+		if drop&1023 == 0 {
+			if err = t.Commit(); err != nil {
+				return
+			}
+		}
 	}
+	it.Close()
+
+	db.expFlush(t, hExpType)

 	err = t.Commit()
 	return
@@ -452,6 +493,31 @@ func (db *DB) HScan(key []byte, field []byte, count int, inclusive bool) ([]FVPa
 			v = append(v, FVPair{Field: f, Value: it.Value()})
 		}
 	}
+	it.Close()

 	return v, nil
 }
+
+func (db *DB) HExpire(key []byte, duration int64) (int64, error) {
+	if duration <= 0 {
+		return 0, errExpireValue
+	}
+
+	return db.hExpireAt(key, time.Now().Unix()+duration)
+}
+
+func (db *DB) HExpireAt(key []byte, when int64) (int64, error) {
+	if when <= time.Now().Unix() {
+		return 0, errExpireValue
+	}
+
+	return db.hExpireAt(key, when)
+}
+
+func (db *DB) HTTL(key []byte) (int64, error) {
+	if err := checkKeySize(key); err != nil {
+		return -1, err
+	}
+
+	return db.ttl(hExpType, key)
+}
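// A brief usage sketch for the hash expiration API added above; the key
// is illustrative, and the kv variants in t_kv.go follow the same shape:
func exampleHashTTL(db *DB) (int64, error) {
	if _, err := db.HSet([]byte("h"), []byte("f"), []byte("v")); err != nil {
		return 0, err
	}
	if _, err := db.HExpire([]byte("h"), 10); err != nil {
		return 0, err
	}
	return db.HTTL([]byte("h")) // seconds remaining, or -1 if none
}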
diff --git a/ledis/t_hash_test.go b/ledis/t_hash_test.go
index e753361..4648c09 100644
--- a/ledis/t_hash_test.go
+++ b/ledis/t_hash_test.go
@@ -42,7 +42,7 @@ func TestDBHash(t *testing.T) {
 func TestDBHScan(t *testing.T) {
 	db := getTestDB()

-	db.HFlush()
+	db.hFlush()

 	key := []byte("a")
 	db.HSet(key, []byte("1"), []byte{})
diff --git a/ledis/t_kv.go b/ledis/t_kv.go
index c7e228e..2ae2a63 100644
--- a/ledis/t_kv.go
+++ b/ledis/t_kv.go
@@ -3,6 +3,7 @@ package ledis
 import (
 	"errors"
 	"github.com/siddontang/go-leveldb/leveldb"
+	"time"
 )

 type KVPair struct {
@@ -14,11 +15,19 @@
 var errKVKey = errors.New("invalid encode kv key")

 func checkKeySize(key []byte) error {
 	if len(key) > MaxKeySize || len(key) == 0 {
-		return ErrKeySize
+		return errKeySize
 	}
 	return nil
 }

+func checkValueSize(value []byte) error {
+	if len(value) > MaxValueSize {
+		return errValueSize
+	}
+
+	return nil
+}
+
 func (db *DB) encodeKVKey(key []byte) []byte {
 	ek := make([]byte, len(key)+2)
 	ek[0] = db.index
@@ -75,6 +84,30 @@ func (db *DB) incr(key []byte, delta int64) (int64, error) {
 	return n, err
 }

+//delete removes only the key-value data itself; any related
+//expire entries are left for the caller to clean up.
+func (db *DB) delete(t *tx, key []byte) int64 {
+	key = db.encodeKVKey(key)
+	t.Delete(key)
+	return 1
+}
+
+func (db *DB) setExpireAt(key []byte, when int64) (int64, error) {
+	t := db.kvTx
+	t.Lock()
+	defer t.Unlock()
+
+	if exist, err := db.Exists(key); err != nil || exist == 0 {
+		return 0, err
+	} else {
+		db.expireAt(t, kvExpType, key, when)
+		if err := t.Commit(); err != nil {
+			return 0, err
+		}
+	}
+	return 1, nil
+}
+
 func (db *DB) Decr(key []byte) (int64, error) {
 	return db.incr(key, -1)
 }
@@ -88,22 +121,21 @@ func (db *DB) Del(keys ...[]byte) (int64, error) {
 		return 0, nil
 	}

-	var err error
-	for i := range keys {
-		keys[i] = db.encodeKVKey(keys[i])
+	codedKeys := make([][]byte, len(keys))
+	for i, k := range keys {
+		codedKeys[i] = db.encodeKVKey(k)
 	}

 	t := db.kvTx
-
 	t.Lock()
 	defer t.Unlock()

-	for i := range keys {
-		t.Delete(keys[i])
-		//todo binlog
+	for i, k := range keys {
+		t.Delete(codedKeys[i])
+		db.rmExpire(t, kvExpType, k)
 	}

-	err = t.Commit()
+	err := t.Commit()
 	return int64(len(keys)), err
 }
@@ -137,6 +169,8 @@ func (db *DB) Get(key []byte) ([]byte, error) {
 func (db *DB) GetSet(key []byte, value []byte) ([]byte, error) {
 	if err := checkKeySize(key); err != nil {
 		return nil, err
+	} else if err := checkValueSize(value); err != nil {
+		return nil, err
 	}

 	key = db.encodeKVKey(key)
@@ -204,6 +238,8 @@ func (db *DB) MSet(args ...KVPair) error {
 	for i := 0; i < len(args); i++ {
 		if err := checkKeySize(args[i].Key); err != nil {
 			return err
+		} else if err := checkValueSize(args[i].Value); err != nil {
+			return err
 		}

 		key = db.encodeKVKey(args[i].Key)
@@ -222,6 +258,8 @@ func (db *DB) Set(key []byte, value []byte) error {
 	if err := checkKeySize(key); err != nil {
 		return err
+	} else if err := checkValueSize(value); err != nil {
+		return err
 	}

 	var err error
@@ -244,6 +282,8 @@ func (db *DB) SetNX(key []byte, value []byte) (int64, error) {
 	if err := checkKeySize(key); err != nil {
 		return 0, err
+	} else if err := checkValueSize(value); err != nil {
+		return 0, err
 	}

 	var err error
@@ -271,7 +311,7 @@ func (db *DB) SetNX(key []byte, value []byte) (int64, error) {
 	return n, err
 }

-func (db *DB) KvFlush() 
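// A note on the Del rewrite above: keys are encoded into a separate
// codedKeys slice instead of being encoded in place, because rmExpire
// needs the raw user key to locate the expire metadata, and overwriting
// the caller's slice would be surprising. The expire bookkeeping itself
// (expireAt, rmExpire, ttl) lives in t_ttl.go.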
(drop int64, err error) { for ; it.Valid(); it.Next() { t.Delete(it.Key()) drop++ + + if drop&1023 == 0 { + if err = t.Commit(); err != nil { + return + } + } } + it.Close() err = t.Commit() + err = db.expFlush(t, kvExpType) return } @@ -322,6 +370,31 @@ func (db *DB) Scan(key []byte, count int, inclusive bool) ([]KVPair, error) { v = append(v, KVPair{Key: key, Value: it.Value()}) } } + it.Close() return v, nil } + +func (db *DB) Expire(key []byte, duration int64) (int64, error) { + if duration <= 0 { + return 0, errExpireValue + } + + return db.setExpireAt(key, time.Now().Unix()+duration) +} + +func (db *DB) ExpireAt(key []byte, when int64) (int64, error) { + if when <= time.Now().Unix() { + return 0, errExpireValue + } + + return db.setExpireAt(key, when) +} + +func (db *DB) TTL(key []byte) (int64, error) { + if err := checkKeySize(key); err != nil { + return -1, err + } + + return db.ttl(kvExpType, key) +} diff --git a/ledis/t_kv_test.go b/ledis/t_kv_test.go index 967eab6..0252421 100644 --- a/ledis/t_kv_test.go +++ b/ledis/t_kv_test.go @@ -29,7 +29,7 @@ func TestDBKV(t *testing.T) { func TestDBScan(t *testing.T) { db := getTestDB() - db.Flush() + db.FlushAll() if v, err := db.Scan(nil, 10, true); err != nil { t.Fatal(err) diff --git a/ledis/t_list.go b/ledis/t_list.go index 34cbb66..df5f576 100644 --- a/ledis/t_list.go +++ b/ledis/t_list.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "github.com/siddontang/go-leveldb/leveldb" + "time" ) const ( @@ -84,62 +85,53 @@ func (db *DB) lpush(key []byte, whereSeq int32, args ...[]byte) (int64, error) { var err error metaKey := db.lEncodeMetaKey(key) + headSeq, tailSeq, size, err = db.lGetMeta(metaKey) + if err != nil { + return 0, err + } - if len(args) == 0 { - _, _, size, err := db.lGetMeta(metaKey) - return int64(size), err + var pushCnt int = len(args) + if pushCnt == 0 { + return int64(size), nil + } + + var seq int32 = headSeq + var delta int32 = -1 + if whereSeq == listTailSeq { + seq = tailSeq + delta = 1 } t := db.listTx t.Lock() defer t.Unlock() - if headSeq, tailSeq, size, err = db.lGetMeta(metaKey); err != nil { - return 0, err - } - - var delta int32 = 1 - var seq int32 = 0 - if whereSeq == listHeadSeq { - delta = -1 - seq = headSeq - } else { - seq = tailSeq - } - - if size == 0 { - headSeq = listInitialSeq - tailSeq = listInitialSeq - seq = headSeq - } else { + // append elements + if size > 0 { seq += delta } - for i := 0; i < len(args); i++ { + for i := 0; i < pushCnt; i++ { ek := db.lEncodeListKey(key, seq+int32(i)*delta) t.Put(ek, args[i]) - //to do add binlog } - seq += int32(len(args)-1) * delta - + seq += int32(pushCnt-1) * delta if seq <= listMinSeq || seq >= listMaxSeq { return 0, errListSeq } - size += int32(len(args)) - + // set meta info if whereSeq == listHeadSeq { headSeq = seq } else { tailSeq = seq } - db.lSetMeta(metaKey, headSeq, tailSeq, size) + db.lSetMeta(metaKey, headSeq, tailSeq) err = t.Commit() - - return int64(size), err + return int64(size) + int64(pushCnt), err } func (db *DB) lpop(key []byte, whereSeq int32) ([]byte, error) { @@ -153,54 +145,73 @@ func (db *DB) lpop(key []byte, whereSeq int32) ([]byte, error) { var headSeq int32 var tailSeq int32 - var size int32 var err error metaKey := db.lEncodeMetaKey(key) - - headSeq, tailSeq, size, err = db.lGetMeta(metaKey) - + headSeq, tailSeq, _, err = db.lGetMeta(metaKey) if err != nil { return nil, err } - var seq int32 = 0 - var delta int32 = 1 - if whereSeq == listHeadSeq { - seq = headSeq - } else { - delta = -1 + var value []byte + + var seq 
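// The list layout used by lpush/lpop, for reference: every element sits
// at a sequence number between headSeq and tailSeq, the meta value keeps
// just those two sequences, and size is derived as tailSeq-headSeq+1
// (see lGetMeta below). Popping at the head advances headSeq, popping at
// the tail moves tailSeq back, and lSetMeta deletes the meta key once
// the list becomes empty.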
int32 = headSeq
+	if whereSeq == listTailSeq {
 		seq = tailSeq
 	}

 	itemKey := db.lEncodeListKey(key, seq)
-	var value []byte
 	value, err = db.db.Get(itemKey)
 	if err != nil {
 		return nil, err
 	}

-	t.Delete(itemKey)
-	seq += delta
-
-	size--
-	if size <= 0 {
-		t.Delete(metaKey)
+	if whereSeq == listHeadSeq {
+		headSeq += 1
 	} else {
-		if whereSeq == listHeadSeq {
-			headSeq = seq
-		} else {
-			tailSeq = seq
-		}
-
-		db.lSetMeta(metaKey, headSeq, tailSeq, size)
+		tailSeq -= 1
+	}
+
+	t.Delete(itemKey)
+	size := db.lSetMeta(metaKey, headSeq, tailSeq)
+	if size == 0 {
+		db.rmExpire(t, lExpType, key)
 	}

-	//todo add binlog
 	err = t.Commit()
 	return value, err
 }

+//lDelete removes only the list data itself; any related
+//expire entries are left for the caller to clean up.
+func (db *DB) lDelete(t *tx, key []byte) int64 {
+	mk := db.lEncodeMetaKey(key)
+
+	var headSeq int32
+	var tailSeq int32
+	var err error
+
+	headSeq, tailSeq, _, err = db.lGetMeta(mk)
+	if err != nil {
+		return 0
+	}
+
+	var num int64 = 0
+	startKey := db.lEncodeListKey(key, headSeq)
+	stopKey := db.lEncodeListKey(key, tailSeq)
+
+	it := db.db.Iterator(startKey, stopKey, leveldb.RangeClose, 0, -1)
+	for ; it.Valid(); it.Next() {
+		t.Delete(it.Key())
+		num++
+	}
+	it.Close()
+
+	t.Delete(mk)
+
+	return num
+}
+
 func (db *DB) lGetSeq(key []byte, whereSeq int32) (int64, error) {
 	ek := db.lEncodeListKey(key, whereSeq)

@@ -213,26 +224,52 @@ func (db *DB) lGetMeta(ek []byte) (headSeq int32, tailSeq int32, size int32, err
 	if err != nil {
 		return
 	} else if v == nil {
+		headSeq = listInitialSeq
+		tailSeq = listInitialSeq
 		size = 0
 		return
 	} else {
 		headSeq = int32(binary.LittleEndian.Uint32(v[0:4]))
 		tailSeq = int32(binary.LittleEndian.Uint32(v[4:8]))
-		size = int32(binary.LittleEndian.Uint32(v[8:]))
+		size = tailSeq - headSeq + 1
 	}
 	return
 }

-func (db *DB) lSetMeta(ek []byte, headSeq int32, tailSeq int32, size int32) {
+func (db *DB) lSetMeta(ek []byte, headSeq int32, tailSeq int32) int32 {
 	t := db.listTx

-	buf := make([]byte, 12)
+	var size int32 = tailSeq - headSeq + 1
+	if size < 0 {
+		// todo : log the error and panic
+	} else if size == 0 {
+		t.Delete(ek)
+	} else {
+		buf := make([]byte, 8)

-	binary.LittleEndian.PutUint32(buf[0:4], uint32(headSeq))
-	binary.LittleEndian.PutUint32(buf[4:8], uint32(tailSeq))
-	binary.LittleEndian.PutUint32(buf[8:], uint32(size))
+		binary.LittleEndian.PutUint32(buf[0:4], uint32(headSeq))
+		binary.LittleEndian.PutUint32(buf[4:8], uint32(tailSeq))

-	t.Put(ek, buf)
+		t.Put(ek, buf)
+	}
+
+	return size
+}
+
+func (db *DB) lExpireAt(key []byte, when int64) (int64, error) {
+	t := db.listTx
+	t.Lock()
+	defer t.Unlock()
+
+	if llen, err := db.LLen(key); err != nil || llen == 0 {
+		return 0, err
+	} else {
+		db.expireAt(t, lExpType, key, when)
+		if err := t.Commit(); err != nil {
+			return 0, err
+		}
+	}
+	return 1, nil
 }

 func (db *DB) LIndex(key []byte, index int32) ([]byte, error) {
@@ -347,41 +384,18 @@ func (db *DB) LClear(key []byte) (int64, error) {
 		return 0, err
 	}

-	mk := db.lEncodeMetaKey(key)
-
 	t := db.listTx
 	t.Lock()
 	defer t.Unlock()

-	var headSeq int32
-	var tailSeq int32
-	var err error
+	num := db.lDelete(t, key)
+	db.rmExpire(t, lExpType, key)

-	headSeq, tailSeq, _, err = db.lGetMeta(mk)
-
-	if err != nil {
-		return 0, err
-	}
-
-	var num int64 = 0
-	startKey := db.lEncodeListKey(key, headSeq)
-	stopKey := db.lEncodeListKey(key, tailSeq)
-
-	it := db.db.Iterator(startKey, stopKey, leveldb.RangeClose, 0, -1)
-	for ; it.Valid(); it.Next() {
-		t.Delete(it.Key())
-		num++
-	}
-
-	it.Close()
-
-	t.Delete(mk)
-
-	err = t.Commit()
+	err := 
t.Commit() return num, err } -func (db *DB) LFlush() (drop int64, err error) { +func (db *DB) lFlush() (drop int64, err error) { t := db.listTx t.Lock() defer t.Unlock() @@ -398,8 +412,40 @@ func (db *DB) LFlush() (drop int64, err error) { for ; it.Valid(); it.Next() { t.Delete(it.Key()) drop++ + if drop&1023 == 0 { + if err = t.Commit(); err != nil { + return + } + } } + it.Close() + + db.expFlush(t, lExpType) err = t.Commit() return } + +func (db *DB) LExpire(key []byte, duration int64) (int64, error) { + if duration <= 0 { + return 0, errExpireValue + } + + return db.lExpireAt(key, time.Now().Unix()+duration) +} + +func (db *DB) LExpireAt(key []byte, when int64) (int64, error) { + if when <= time.Now().Unix() { + return 0, errExpireValue + } + + return db.lExpireAt(key, when) +} + +func (db *DB) LTTL(key []byte) (int64, error) { + if err := checkKeySize(key); err != nil { + return -1, err + } + + return db.ttl(lExpType, key) +} diff --git a/ledis/t_list_test.go b/ledis/t_list_test.go index fe5d24d..b7cb660 100644 --- a/ledis/t_list_test.go +++ b/ledis/t_list_test.go @@ -31,9 +31,37 @@ func TestDBList(t *testing.T) { key := []byte("testdb_list_a") - if n, err := db.RPush(key, []byte("1"), []byte("2")); err != nil { + if n, err := db.RPush(key, []byte("1"), []byte("2"), []byte("3")); err != nil { t.Fatal(err) - } else if n != 2 { + } else if n != 3 { t.Fatal(n) } + + if k, err := db.RPop(key); err != nil { + t.Fatal(err) + } else if string(k) != "3" { + t.Fatal(string(k)) + } + + if k, err := db.LPop(key); err != nil { + t.Fatal(err) + } else if string(k) != "1" { + t.Fatal(string(k)) + } + + if llen, err := db.LLen(key); err != nil { + t.Fatal(err) + } else if llen != 1 { + t.Fatal(llen) + } + + if num, err := db.LClear(key); err != nil { + t.Fatal(err) + } else if num != 1 { + t.Fatal(num) + } + + if llen, _ := db.LLen(key); llen != 0 { + t.Fatal(llen) + } } diff --git a/ledis/t_ttl.go b/ledis/t_ttl.go new file mode 100644 index 0000000..b353ce7 --- /dev/null +++ b/ledis/t_ttl.go @@ -0,0 +1,212 @@ +package ledis + +import ( + "encoding/binary" + "errors" + "github.com/siddontang/go-leveldb/leveldb" + "time" +) + +var mapExpMetaType = map[byte]byte{ + kvExpType: kvExpMetaType, + lExpType: lExpMetaType, + hExpType: hExpMetaType, + zExpType: zExpMetaType} + +type retireCallback func(*tx, []byte) int64 + +type elimination struct { + db *DB + exp2Tx map[byte]*tx + exp2Retire map[byte]retireCallback +} + +var errExpType = errors.New("invalid expire type") + +func (db *DB) expEncodeTimeKey(expType byte, key []byte, when int64) []byte { + // format : db[8] / expType[8] / when[64] / key[...] + buf := make([]byte, len(key)+10) + + buf[0] = db.index + buf[1] = expType + pos := 2 + + binary.BigEndian.PutUint64(buf[pos:], uint64(when)) + pos += 8 + + copy(buf[pos:], key) + + return buf +} + +func (db *DB) expEncodeMetaKey(expType byte, key []byte) []byte { + // format : db[8] / expType[8] / key[...] + buf := make([]byte, len(key)+2) + + buf[0] = db.index + buf[1] = expType + pos := 2 + + copy(buf[pos:], key) + + return buf +} + +// usage : separate out the original key +func (db *DB) expDecodeMetaKey(mk []byte) []byte { + if len(mk) <= 2 { + // check db ? check type ? 
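// For reference, the two key shapes the TTL machinery builds, using an
// illustrative kv key "foo" in db 0 that expires at unix time 1400000000:
//
//	time key: 0x00 | kvExpType | bigendian uint64(1400000000) | "foo"
//	meta key: 0x00 | kvExpMetaType | "foo"
//
// Time keys sort by deadline, so the eliminator below can range-scan all
// keys due by "now"; the meta key maps the user key back to its deadline
// so ttl and rmExpire can find it without scanning.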
+ return nil + } + + return mk[2:] +} + +func (db *DB) expire(t *tx, expType byte, key []byte, duration int64) { + db.expireAt(t, expType, key, time.Now().Unix()+duration) +} + +func (db *DB) expireAt(t *tx, expType byte, key []byte, when int64) { + mk := db.expEncodeMetaKey(expType+1, key) + tk := db.expEncodeTimeKey(expType, key, when) + + t.Put(tk, mk) + t.Put(mk, PutInt64(when)) +} + +func (db *DB) ttl(expType byte, key []byte) (t int64, err error) { + mk := db.expEncodeMetaKey(expType+1, key) + + if t, err = Int64(db.db.Get(mk)); err != nil || t == 0 { + t = -1 + } else { + t -= time.Now().Unix() + if t <= 0 { + t = -1 + } + // if t == -1 : to remove ???? + } + + return t, err +} + +func (db *DB) rmExpire(t *tx, expType byte, key []byte) { + mk := db.expEncodeMetaKey(expType+1, key) + if v, err := db.db.Get(mk); err != nil || v == nil { + return + } else if when, err2 := Int64(v, nil); err2 != nil { + return + } else { + tk := db.expEncodeTimeKey(expType, key, when) + t.Delete(mk) + t.Delete(tk) + } +} + +func (db *DB) expFlush(t *tx, expType byte) (err error) { + expMetaType, ok := mapExpMetaType[expType] + if !ok { + return errExpType + } + + drop := 0 + + minKey := make([]byte, 2) + minKey[0] = db.index + minKey[1] = expType + + maxKey := make([]byte, 2) + maxKey[0] = db.index + maxKey[1] = expMetaType + 1 + + it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) + for ; it.Valid(); it.Next() { + t.Delete(it.Key()) + drop++ + if drop&1023 == 0 { + if err = t.Commit(); err != nil { + return + } + } + } + it.Close() + + err = t.Commit() + return +} + +////////////////////////////////////////////////////////// +// +////////////////////////////////////////////////////////// + +func newEliminator(db *DB) *elimination { + eli := new(elimination) + eli.db = db + eli.exp2Tx = make(map[byte]*tx) + eli.exp2Retire = make(map[byte]retireCallback) + return eli +} + +func (eli *elimination) regRetireContext(expType byte, t *tx, onRetire retireCallback) { + eli.exp2Tx[expType] = t + eli.exp2Retire[expType] = onRetire +} + +// call by outside ... 
(from *db to another *db) +func (eli *elimination) active() { + now := time.Now().Unix() + db := eli.db + dbGet := db.db.Get + expKeys := make([][]byte, 0, 1024) + expTypes := [...]byte{kvExpType, lExpType, hExpType, zExpType} + + for _, et := range expTypes { + // search those keys' which expire till the moment + minKey := db.expEncodeTimeKey(et, nil, 0) + maxKey := db.expEncodeTimeKey(et, nil, now+1) + expKeys = expKeys[0:0] + + t, _ := eli.exp2Tx[et] + onRetire, _ := eli.exp2Retire[et] + if t == nil || onRetire == nil { + // todo : log error + continue + } + + it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) + for it.Valid() { + for i := 1; i < 512 && it.Valid(); i++ { + expKeys = append(expKeys, it.Key(), it.Value()) + it.Next() + } + + var cnt int = len(expKeys) + if cnt == 0 { + continue + } + + t.Lock() + var mk, ek, k []byte + for i := 0; i < cnt; i += 2 { + ek, mk = expKeys[i], expKeys[i+1] + if exp, err := Int64(dbGet(mk)); err == nil { + // check expire again + if exp > now { + continue + } + + // delete keys + k = db.expDecodeMetaKey(mk) + onRetire(t, k) + t.Delete(ek) + t.Delete(mk) + } + } + t.Commit() + t.Unlock() + } // end : it + it.Close() + } // end : expType + + return +} diff --git a/ledis/t_ttl_test.go b/ledis/t_ttl_test.go new file mode 100644 index 0000000..e7ec920 --- /dev/null +++ b/ledis/t_ttl_test.go @@ -0,0 +1,362 @@ +package ledis + +import ( + "fmt" + "sync" + "testing" + "time" +) + +var m sync.Mutex + +type adaptor struct { + set func([]byte, []byte) (int64, error) + del func([]byte) (int64, error) + exists func([]byte) (int64, error) + + expire func([]byte, int64) (int64, error) + expireAt func([]byte, int64) (int64, error) + ttl func([]byte) (int64, error) + + showIdent func() string +} + +func kvAdaptor(db *DB) *adaptor { + adp := new(adaptor) + adp.showIdent = func() string { + return "kv-adptor" + } + + adp.set = db.SetNX + adp.exists = db.Exists + adp.del = func(k []byte) (int64, error) { + return db.Del(k) + } + + adp.expire = db.Expire + adp.expireAt = db.ExpireAt + adp.ttl = db.TTL + + return adp +} + +func listAdaptor(db *DB) *adaptor { + adp := new(adaptor) + adp.showIdent = func() string { + return "list-adptor" + } + + adp.set = func(k []byte, v []byte) (int64, error) { + eles := make([][]byte, 0) + for i := 0; i < 3; i++ { + e := []byte(String(v) + fmt.Sprintf("_%d", i)) + eles = append(eles, e) + } + + if n, err := db.LPush(k, eles...); err != nil { + return 0, err + } else { + return n, nil + } + } + + adp.exists = func(k []byte) (int64, error) { + if llen, err := db.LLen(k); err != nil || llen <= 0 { + return 0, err + } else { + return 1, nil + } + } + + adp.del = db.LClear + adp.expire = db.LExpire + adp.expireAt = db.LExpireAt + adp.ttl = db.LTTL + + return adp +} + +func hashAdaptor(db *DB) *adaptor { + adp := new(adaptor) + adp.showIdent = func() string { + return "hash-adptor" + } + + adp.set = func(k []byte, v []byte) (int64, error) { + datas := make([]FVPair, 0) + for i := 0; i < 3; i++ { + suffix := fmt.Sprintf("_%d", i) + pair := FVPair{ + Field: []byte(String(k) + suffix), + Value: []byte(String(v) + suffix)} + + datas = append(datas, pair) + } + + if err := db.HMset(k, datas...); err != nil { + return 0, err + } else { + return int64(len(datas)), nil + } + } + + adp.exists = func(k []byte) (int64, error) { + if hlen, err := db.HLen(k); err != nil || hlen <= 0 { + return 0, err + } else { + return 1, nil + } + } + + adp.del = db.HClear + adp.expire = db.HExpire + adp.expireAt = db.HExpireAt + adp.ttl = db.HTTL + + 
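// The adaptor pattern used throughout this test file: each data type
// exposes set/exists/del/expire/expireAt/ttl behind one struct, so the
// TTL tests below (TestExpire, TestExpireAt, TestTTL, TestExpCompose)
// can run identical assertions against kv, list, hash and zset via
// allAdaptors.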
return adp +} + +func zsetAdaptor(db *DB) *adaptor { + adp := new(adaptor) + adp.showIdent = func() string { + return "zset-adptor" + } + + adp.set = func(k []byte, v []byte) (int64, error) { + datas := make([]ScorePair, 0) + for i := 0; i < 3; i++ { + memb := []byte(String(k) + fmt.Sprintf("_%d", i)) + pair := ScorePair{ + Score: int64(i), + Member: memb} + + datas = append(datas, pair) + } + + if n, err := db.ZAdd(k, datas...); err != nil { + return 0, err + } else { + return n, nil + } + } + + adp.exists = func(k []byte) (int64, error) { + if cnt, err := db.ZCard(k); err != nil || cnt <= 0 { + return 0, err + } else { + return 1, nil + } + } + + adp.del = db.ZClear + adp.expire = db.ZExpire + adp.expireAt = db.ZExpireAt + adp.ttl = db.ZTTL + + return adp +} + +func allAdaptors(db *DB) []*adaptor { + adps := make([]*adaptor, 4) + adps[0] = kvAdaptor(db) + adps[1] = listAdaptor(db) + adps[2] = hashAdaptor(db) + adps[3] = zsetAdaptor(db) + return adps +} + +/////////////////////////////////////////////////////// + +func TestExpire(t *testing.T) { + db := getTestDB() + m.Lock() + defer m.Unlock() + + k := []byte("ttl_a") + ek := []byte("ttl_b") + + dbEntrys := allAdaptors(db) + for _, entry := range dbEntrys { + ident := entry.showIdent() + + entry.set(k, []byte("1")) + + if ok, _ := entry.expire(k, 10); ok != 1 { + t.Fatal(ident, ok) + } + + // err - expire on an inexisting key + if ok, _ := entry.expire(ek, 10); ok != 0 { + t.Fatal(ident, ok) + } + + // err - duration is zero + if ok, err := entry.expire(k, 0); err == nil || ok != 0 { + t.Fatal(ident, fmt.Sprintf("res = %d, err = %s", ok, err)) + } + + // err - duration is negative + if ok, err := entry.expire(k, -10); err == nil || ok != 0 { + t.Fatal(ident, fmt.Sprintf("res = %d, err = %s", ok, err)) + } + } +} + +func TestExpireAt(t *testing.T) { + db := getTestDB() + m.Lock() + defer m.Unlock() + + k := []byte("ttl_a") + ek := []byte("ttl_b") + + dbEntrys := allAdaptors(db) + for _, entry := range dbEntrys { + ident := entry.showIdent() + now := time.Now().Unix() + + entry.set(k, []byte("1")) + + if ok, _ := entry.expireAt(k, now+5); ok != 1 { + t.Fatal(ident, ok) + } + + // err - expire on an inexisting key + if ok, _ := entry.expireAt(ek, now+5); ok != 0 { + t.Fatal(ident, ok) + } + + // err - expire with the current time + if ok, err := entry.expireAt(k, now); err == nil || ok != 0 { + t.Fatal(ident, fmt.Sprintf("res = %d, err = %s", ok, err)) + } + + // err - expire with the time before + if ok, err := entry.expireAt(k, now-5); err == nil || ok != 0 { + t.Fatal(ident, fmt.Sprintf("res = %d, err = %s", ok, err)) + } + } +} + +func TestTTL(t *testing.T) { + db := getTestDB() + m.Lock() + defer m.Unlock() + + k := []byte("ttl_a") + ek := []byte("ttl_b") + + dbEntrys := allAdaptors(db) + for _, entry := range dbEntrys { + ident := entry.showIdent() + + entry.set(k, []byte("1")) + entry.expire(k, 2) + + if tRemain, _ := entry.ttl(k); tRemain != 2 { + t.Fatal(ident, tRemain) + } + + // err - check ttl on an inexisting key + if tRemain, _ := entry.ttl(ek); tRemain != -1 { + t.Fatal(ident, tRemain) + } + + entry.del(k) + if tRemain, _ := entry.ttl(k); tRemain != -1 { + t.Fatal(ident, tRemain) + } + } +} + +func TestExpCompose(t *testing.T) { + db := getTestDB() + m.Lock() + defer m.Unlock() + + k0 := []byte("ttl_a") + k1 := []byte("ttl_b") + k2 := []byte("ttl_c") + + dbEntrys := allAdaptors(db) + + for _, entry := range dbEntrys { + ident := entry.showIdent() + + entry.set(k0, k0) + entry.set(k1, k1) + entry.set(k2, k2) + + 
entry.expire(k0, 5) + entry.expire(k1, 2) + entry.expire(k2, 60) + + if tRemain, _ := entry.ttl(k0); tRemain != 5 { + t.Fatal(ident, tRemain) + } + if tRemain, _ := entry.ttl(k1); tRemain != 2 { + t.Fatal(ident, tRemain) + } + if tRemain, _ := entry.ttl(k2); tRemain != 60 { + t.Fatal(ident, tRemain) + } + } + + // after 1 sec + time.Sleep(1 * time.Second) + + for _, entry := range dbEntrys { + ident := entry.showIdent() + + if tRemain, _ := entry.ttl(k0); tRemain != 4 { + t.Fatal(ident, tRemain) + } + if tRemain, _ := entry.ttl(k1); tRemain != 1 { + t.Fatal(ident, tRemain) + } + } + + // after 2 sec + time.Sleep(2 * time.Second) + + for _, entry := range dbEntrys { + ident := entry.showIdent() + + if tRemain, _ := entry.ttl(k1); tRemain != -1 { + t.Fatal(ident, tRemain) + } + if exist, _ := entry.exists(k1); exist > 0 { + t.Fatal(ident, false) + } + + if tRemain, _ := entry.ttl(k0); tRemain != 2 { + t.Fatal(ident, tRemain) + } + if exist, _ := entry.exists(k0); exist <= 0 { + t.Fatal(ident, false) + } + + // refresh the expiration of key + if tRemain, _ := entry.ttl(k2); !(0 < tRemain && tRemain < 60) { + t.Fatal(ident, tRemain) + } + + if ok, _ := entry.expire(k2, 100); ok != 1 { + t.Fatal(ident, false) + } + + if tRemain, _ := entry.ttl(k2); tRemain != 100 { + t.Fatal(ident, tRemain) + } + + // expire an inexisting key + if ok, _ := entry.expire(k1, 10); ok == 1 { + t.Fatal(ident, false) + } + if tRemain, _ := entry.ttl(k1); tRemain != -1 { + t.Fatal(ident, tRemain) + } + } + + return +} diff --git a/ledis/t_zset.go b/ledis/t_zset.go index 8be54b7..45f8cfa 100644 --- a/ledis/t_zset.go +++ b/ledis/t_zset.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "github.com/siddontang/go-leveldb/leveldb" + "time" ) const ( @@ -34,9 +35,9 @@ const ( func checkZSetKMSize(key []byte, member []byte) error { if len(key) > MaxKeySize || len(key) == 0 { - return ErrKeySize + return errKeySize } else if len(member) > MaxZSetMemberSize || len(member) == 0 { - return ErrZSetMemberSize + return errZSetMemberSize } return nil } @@ -191,13 +192,11 @@ func (db *DB) zDecodeScoreKey(ek []byte) (key []byte, member []byte, score int64 return } -func (db *DB) zSetItem(key []byte, score int64, member []byte) (int64, error) { +func (db *DB) zSetItem(t *tx, key []byte, score int64, member []byte) (int64, error) { if score <= MinScore || score >= MaxScore { return 0, errScoreOverflow } - t := db.zsetTx - var exists int64 = 0 ek := db.zEncodeSetKey(key, member) @@ -222,9 +221,7 @@ func (db *DB) zSetItem(key []byte, score int64, member []byte) (int64, error) { return exists, nil } -func (db *DB) zDelItem(key []byte, member []byte, skipDelScore bool) (int64, error) { - t := db.zsetTx - +func (db *DB) zDelItem(t *tx, key []byte, member []byte, skipDelScore bool) (int64, error) { ek := db.zEncodeSetKey(key, member) if v, err := db.db.Get(ek); err != nil { return 0, err @@ -245,6 +242,29 @@ func (db *DB) zDelItem(key []byte, member []byte, skipDelScore bool) (int64, err } t.Delete(ek) + + return 1, nil +} + +func (db *DB) zDelete(t *tx, key []byte) int64 { + delMembCnt, _ := db.zRemRange(t, key, MinScore, MaxScore, 0, -1) + // todo : log err + return delMembCnt +} + +func (db *DB) zExpireAt(key []byte, when int64) (int64, error) { + t := db.zsetTx + t.Lock() + defer t.Unlock() + + if zcnt, err := db.ZCard(key); err != nil || zcnt == 0 { + return 0, err + } else { + db.expireAt(t, zExpType, key, when) + if err := t.Commit(); err != nil { + return 0, err + } + } return 1, nil } @@ -266,7 +286,7 @@ func (db *DB) ZAdd(key 
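// A design note on this refactor: zSetItem and zDelItem now take the
// *tx explicitly instead of grabbing db.zsetTx themselves, so callers
// such as ZAdd and ZRem can take the lock once, batch several members,
// and commit a single write batch; zRemRange and its ZClear /
// ZRemRangeByRank / ZRemRangeByScore wrappers follow the same pattern.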
diff --git a/ledis/t_zset.go b/ledis/t_zset.go
index 8be54b7..45f8cfa 100644
--- a/ledis/t_zset.go
+++ b/ledis/t_zset.go
@@ -5,6 +5,7 @@ import (
     "encoding/binary"
     "errors"
     "github.com/siddontang/go-leveldb/leveldb"
+    "time"
 )
 
 const (
@@ -34,9 +35,9 @@ const (
 
 func checkZSetKMSize(key []byte, member []byte) error {
     if len(key) > MaxKeySize || len(key) == 0 {
-        return ErrKeySize
+        return errKeySize
     } else if len(member) > MaxZSetMemberSize || len(member) == 0 {
-        return ErrZSetMemberSize
+        return errZSetMemberSize
     }
     return nil
 }
@@ -191,13 +192,11 @@ func (db *DB) zDecodeScoreKey(ek []byte) (key []byte, member []byte, score int64
     return
 }
 
-func (db *DB) zSetItem(key []byte, score int64, member []byte) (int64, error) {
+func (db *DB) zSetItem(t *tx, key []byte, score int64, member []byte) (int64, error) {
     if score <= MinScore || score >= MaxScore {
         return 0, errScoreOverflow
     }
 
-    t := db.zsetTx
-
     var exists int64 = 0
     ek := db.zEncodeSetKey(key, member)
 
@@ -222,9 +221,7 @@ func (db *DB) zSetItem(key []byte, score int64, member []byte) (int64, error) {
     return exists, nil
 }
 
-func (db *DB) zDelItem(key []byte, member []byte, skipDelScore bool) (int64, error) {
-    t := db.zsetTx
-
+func (db *DB) zDelItem(t *tx, key []byte, member []byte, skipDelScore bool) (int64, error) {
     ek := db.zEncodeSetKey(key, member)
     if v, err := db.db.Get(ek); err != nil {
         return 0, err
@@ -245,6 +242,29 @@ func (db *DB) zDelItem(key []byte, member []byte, skipDelScore bool) (int64, err
     }
 
     t.Delete(ek)
+
+    return 1, nil
+}
+
+func (db *DB) zDelete(t *tx, key []byte) int64 {
+    delMembCnt, _ := db.zRemRange(t, key, MinScore, MaxScore, 0, -1)
+    // TODO: log err
+    return delMembCnt
+}
+
+func (db *DB) zExpireAt(key []byte, when int64) (int64, error) {
+    t := db.zsetTx
+    t.Lock()
+    defer t.Unlock()
+
+    if zcnt, err := db.ZCard(key); err != nil || zcnt == 0 {
+        return 0, err
+    } else {
+        db.expireAt(t, zExpType, key, when)
+        if err := t.Commit(); err != nil {
+            return 0, err
+        }
+    }
 
     return 1, nil
 }
@@ -266,7 +286,7 @@ func (db *DB) ZAdd(key []byte, args ...ScorePair) (int64, error) {
             return 0, err
         }
 
-        if n, err := db.zSetItem(key, score, member); err != nil {
+        if n, err := db.zSetItem(t, key, score, member); err != nil {
             return 0, err
         } else if n == 0 {
             //add new
@@ -274,7 +294,7 @@ func (db *DB) ZAdd(key []byte, args ...ScorePair) (int64, error) {
         }
     }
 
-    if _, err := db.zIncrSize(key, num); err != nil {
+    if _, err := db.zIncrSize(t, key, num); err != nil {
         return 0, err
     }
 
@@ -283,8 +303,7 @@ func (db *DB) ZAdd(key []byte, args ...ScorePair) (int64, error) {
     return num, err
 }
 
-func (db *DB) zIncrSize(key []byte, delta int64) (int64, error) {
-    t := db.zsetTx
+func (db *DB) zIncrSize(t *tx, key []byte, delta int64) (int64, error) {
     sk := db.zEncodeSizeKey(key)
 
     size, err := Int64(db.db.Get(sk))
@@ -295,6 +314,7 @@ func (db *DB) zIncrSize(key []byte, delta int64) (int64, error) {
         if size <= 0 {
             size = 0
             t.Delete(sk)
+            db.rmExpire(t, zExpType, key)
         } else {
             t.Put(sk, PutInt64(size))
         }
@@ -348,14 +368,14 @@ func (db *DB) ZRem(key []byte, members ...[]byte) (int64, error) {
             return 0, err
         }
 
-        if n, err := db.zDelItem(key, members[i], false); err != nil {
+        if n, err := db.zDelItem(t, key, members[i], false); err != nil {
             return 0, err
         } else if n == 1 {
             num++
         }
     }
 
-    if _, err := db.zIncrSize(key, -num); err != nil {
+    if _, err := db.zIncrSize(t, key, -num); err != nil {
         return 0, err
     }
 
@@ -374,34 +394,35 @@ func (db *DB) ZIncrBy(key []byte, delta int64, member []byte) (int64, error) {
 
     ek := db.zEncodeSetKey(key, member)
 
-    var score int64 = delta
-
+    var oldScore int64 = 0
     v, err := db.db.Get(ek)
     if err != nil {
         return InvalidScore, err
-    } else if v != nil {
-        if s, err := Int64(v, err); err != nil {
-            return InvalidScore, err
-        } else {
-            sk := db.zEncodeScoreKey(key, member, s)
-            t.Delete(sk)
-
-            score = s + delta
-            if score >= MaxScore || score <= MinScore {
-                return InvalidScore, errScoreOverflow
-            }
-        }
+    } else if v == nil {
+        db.zIncrSize(t, key, 1)
     } else {
-        db.zIncrSize(key, 1)
+        if oldScore, err = Int64(v, err); err != nil {
+            return InvalidScore, err
+        }
     }
 
-    t.Put(ek, PutInt64(score))
+    newScore := oldScore + delta
+    if newScore >= MaxScore || newScore <= MinScore {
+        return InvalidScore, errScoreOverflow
+    }
 
-    sk := db.zEncodeScoreKey(key, member, score)
+    sk := db.zEncodeScoreKey(key, member, newScore)
     t.Put(sk, []byte{})
+    t.Put(ek, PutInt64(newScore))
+
+    if v != nil {
+        // to update the score we must also delete the old score key
+        oldSk := db.zEncodeScoreKey(key, member, oldScore)
+        t.Delete(oldSk)
+    }
 
     err = t.Commit()
-    return score, err
+    return newScore, err
 }
 
 func (db *DB) ZCount(key []byte, min int64, max int64) (int64, error) {
@@ -482,41 +503,35 @@ func (db *DB) zIterator(key []byte, min int64, max int64, offset int, limit int,
     }
 }
 
-func (db *DB) zRemRange(key []byte, min int64, max int64, offset int, limit int) (int64, error) {
+func (db *DB) zRemRange(t *tx, key []byte, min int64, max int64, offset int, limit int) (int64, error) {
     if len(key) > MaxKeySize {
-        return 0, ErrKeySize
+        return 0, errKeySize
     }
 
-    t := db.zsetTx
-    t.Lock()
-    defer t.Unlock()
-
     it := db.zIterator(key, min, max, offset, limit, false)
     var num int64 = 0
     for ; it.Valid(); it.Next() {
-        k := it.Key()
-        _, m, _, err := db.zDecodeScoreKey(k)
+        sk := it.Key()
+        _, m, _, err := db.zDecodeScoreKey(sk)
         if err != nil {
             continue
         }
 
-        if n, err := db.zDelItem(key, m, true); err != nil {
+        if n, err := db.zDelItem(t, key, m, true); err != nil {
             return 0, err
         } else if n == 1 {
             num++
         }
 
-        t.Delete(k)
+        t.Delete(sk)
     }
+    it.Close()
 
-    if _, err := db.zIncrSize(key, -num); err != nil {
+    if _, err := db.zIncrSize(t, key, -num); err != nil {
         return 0, err
     }
 
-    //todo add binlog
-
-    err := t.Commit()
-    return num, err
+    return num, nil
 }
 
 func (db *DB) zReverse(s []interface{}, withScores bool) []interface{} {
@@ -536,7 +551,7 @@ func (db *DB) zReverse(s []interface{}, withScores bool) []interface{} {
 
 func (db *DB) zRange(key []byte, min int64, max int64, withScores bool, offset int, limit int, reverse bool) ([]interface{}, error) {
     if len(key) > MaxKeySize {
-        return nil, ErrKeySize
+        return nil, errKeySize
     }
 
     if offset < 0 {
@@ -575,6 +590,7 @@ func (db *DB) zRange(key []byte, min int64, max int64, withScores bool, offset i
             v = append(v, m)
         }
     }
+    it.Close()
 
     if reverse && (offset == 0 && limit < 0) {
         v = db.zReverse(v, withScores)
@@ -622,7 +638,16 @@ func (db *DB) zParseLimit(key []byte, start int, stop int) (offset int, limit in
 }
 
 func (db *DB) ZClear(key []byte) (int64, error) {
-    return db.zRemRange(key, MinScore, MaxScore, 0, -1)
+    t := db.zsetTx
+    t.Lock()
+    defer t.Unlock()
+
+    rmCnt, err := db.zRemRange(t, key, MinScore, MaxScore, 0, -1)
+    if err == nil {
+        err = t.Commit()
+    }
+
+    return rmCnt, err
 }
 
 func (db *DB) ZRange(key []byte, start int, stop int, withScores bool) ([]interface{}, error) {
@@ -645,12 +670,33 @@ func (db *DB) ZRemRangeByRank(key []byte, start int, stop int) (int64, error) {
     if err != nil {
         return 0, err
     }
-    return db.zRemRange(key, MinScore, MaxScore, offset, limit)
+
+    var rmCnt int64
+
+    t := db.zsetTx
+    t.Lock()
+    defer t.Unlock()
+
+    rmCnt, err = db.zRemRange(t, key, MinScore, MaxScore, offset, limit)
+    if err == nil {
+        err = t.Commit()
+    }
+
+    return rmCnt, err
 }
 
 //min and max must be inclusive
 func (db *DB) ZRemRangeByScore(key []byte, min int64, max int64) (int64, error) {
-    return db.zRemRange(key, min, max, 0, -1)
+    t := db.zsetTx
+    t.Lock()
+    defer t.Unlock()
+
+    rmCnt, err := db.zRemRange(t, key, min, max, 0, -1)
+    if err == nil {
+        err = t.Commit()
+    }
+
+    return rmCnt, err
 }
 
 func (db *DB) ZRevRange(key []byte, start int, stop int, withScores bool) ([]interface{}, error) {
@@ -686,7 +732,7 @@ func (db *DB) ZRangeByScoreGeneric(key []byte, min int64, max int64,
     return db.zRange(key, min, max, withScores, offset, count, reverse)
 }
 
-func (db *DB) ZFlush() (drop int64, err error) {
+func (db *DB) zFlush() (drop int64, err error) {
     t := db.zsetTx
     t.Lock()
     defer t.Unlock()
@@ -703,10 +749,17 @@ func (db *DB) ZFlush() (drop int64, err error) {
     for ; it.Valid(); it.Next() {
         t.Delete(it.Key())
         drop++
+        if drop&1023 == 0 {
+            if err = t.Commit(); err != nil {
+                return
+            }
+        }
     }
+    it.Close()
+
+    db.expFlush(t, zExpType)
 
     err = t.Commit()
-    // to do : binlog
     return
 }
@@ -744,6 +797,31 @@ func (db *DB) ZScan(key []byte, member []byte, count int, inclusive bool) ([]Sco
             v = append(v, ScorePair{Member: m, Score: score})
         }
     }
+    it.Close()
 
     return v, nil
 }
+
+func (db *DB) ZExpire(key []byte, duration int64) (int64, error) {
+    if duration <= 0 {
+        return 0, errExpireValue
+    }
+
+    return db.zExpireAt(key, time.Now().Unix()+duration)
+}
+
+func (db *DB) ZExpireAt(key []byte, when int64) (int64, error) {
+    if when <= time.Now().Unix() {
+        return 0, errExpireValue
+    }
+
+    return db.zExpireAt(key, when)
+}
+
+func (db *DB) ZTTL(key []byte) (int64, error) {
+    if err := checkKeySize(key); err != nil {
+        return -1, err
+    }
+
+    return db.ttl(zExpType, key)
+}
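The t_zset.go changes above all follow one refactor: the low-level helpers (zSetItem, zDelItem, zIncrSize, zRemRange) now receive the tx as a parameter and never lock or commit, while the exported entry points (ZAdd, ZRem, ZClear, ...) own the lock/commit cycle. Composite operations can therefore batch several helper calls into a single atomic commit. A hypothetical composite write following the same pattern as ZClear above (zReplace is illustrative, not part of this diff):

    func (db *DB) zReplace(key []byte, member []byte, score int64) error {
        t := db.zsetTx
        t.Lock()
        defer t.Unlock() // Unlock also rolls back anything left uncommitted

        if _, err := db.zDelItem(t, key, member, false); err != nil {
            return err
        }
        if _, err := db.zSetItem(t, key, score, member); err != nil {
            return err
        }

        // one commit: both mutations are applied (and binlogged) atomically
        return t.Commit()
    }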
diff --git a/ledis/t_zset_test.go b/ledis/t_zset_test.go
index 202da70..2b637e6 100644
--- a/ledis/t_zset_test.go
+++ b/ledis/t_zset_test.go
@@ -220,7 +220,7 @@ func TestZSetOrder(t *testing.T) {
 
 func TestDBZScan(t *testing.T) {
     db := getTestDB()
 
-    db.ZFlush()
+    db.zFlush()
 
     key := []byte("key")
     db.ZAdd(key, pair("a", 0), pair("b", 1), pair("c", 2))
diff --git a/ledis/tx.go b/ledis/tx.go
index 38b6e01..fa7379b 100644
--- a/ledis/tx.go
+++ b/ledis/tx.go
@@ -8,7 +8,22 @@ import (
 
 type tx struct {
     m sync.Mutex
+    l *Ledis
 
     wb *leveldb.WriteBatch
+
+    binlog *BinLog
+    batch  [][]byte
+}
+
+func newTx(l *Ledis) *tx {
+    t := new(tx)
+
+    t.l = l
+    t.wb = l.ldb.NewWriteBatch()
+
+    t.batch = make([][]byte, 0, 4)
+    t.binlog = l.binlog
+    return t
 }
 
 func (t *tx) Close() {
@@ -17,10 +32,20 @@ func (t *tx) Close() {
 
 func (t *tx) Put(key []byte, value []byte) {
     t.wb.Put(key, value)
+
+    if t.binlog != nil {
+        buf := encodeBinLogPut(key, value)
+        t.batch = append(t.batch, buf)
+    }
 }
 
 func (t *tx) Delete(key []byte) {
     t.wb.Delete(key)
+
+    if t.binlog != nil {
+        buf := encodeBinLogDelete(key)
+        t.batch = append(t.batch, buf)
+    }
 }
 
 func (t *tx) Lock() {
@@ -28,12 +53,29 @@ func (t *tx) Lock() {
 }
 
 func (t *tx) Unlock() {
+    t.batch = t.batch[0:0]
     t.wb.Rollback()
     t.m.Unlock()
 }
 
 func (t *tx) Commit() error {
-    err := t.wb.Commit()
+    var err error
+    if t.binlog != nil {
+        t.l.Lock()
+
+        err = t.wb.Commit()
+        if err != nil {
+            t.l.Unlock()
+            return err
+        }
+
+        err = t.binlog.Log(t.batch...)
+
+        t.l.Unlock()
+    } else {
+        t.l.Lock()
+        err = t.wb.Commit()
+        t.l.Unlock()
+    }
     return err
 }
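With the tx change above, every Put/Delete goes into the leveldb write batch and, when the binlog is enabled, is also queued in t.batch as an encoded event; Commit then applies the batch and appends the events while holding the Ledis-wide lock, so the binlog order always matches the apply order. The resulting write path, as a sketch:

    t := newTx(l) // l is a *Ledis with binlog enabled

    t.Lock()
    t.Put(key, value) // queued in the WriteBatch and in t.batch
    err := t.Commit() // under l.Lock(): batch -> leveldb, then events -> binlog
    t.Unlock()        // clears t.batch and rolls back anything uncommitted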
diff --git a/server/accesslog.go b/server/accesslog.go
new file mode 100644
index 0000000..58b34b7
--- /dev/null
+++ b/server/accesslog.go
@@ -0,0 +1,51 @@
+package server
+
+import (
+    "fmt"
+    "github.com/siddontang/go-log/log"
+    "strings"
+)
+
+const (
+    accessTimeFormat = "2006/01/02 15:04:05"
+)
+
+type accessLog struct {
+    l *log.Logger
+}
+
+func newAccessLog(baseName string) (*accessLog, error) {
+    l := new(accessLog)
+
+    h, err := log.NewTimeRotatingFileHandler(baseName, log.WhenDay, 1)
+    if err != nil {
+        return nil, err
+    }
+
+    l.l = log.New(h, log.Ltime)
+
+    return l, nil
+}
+
+func (l *accessLog) Close() {
+    l.l.Close()
+}
+
+func (l *accessLog) Log(remoteAddr string, usedTime int64, cmd string, args [][]byte, err error) {
+    argsFormat := make([]string, len(args))
+    argsI := make([]interface{}, len(args))
+    for i := range args {
+        argsFormat[i] = " %.24q"
+        argsI[i] = args[i]
+    }
+
+    argsStr := fmt.Sprintf(strings.Join(argsFormat, ""), argsI...)
+
+    format := `%s [%s%s] %d [%s]`
+
+    if err == nil {
+        l.l.Info(format, remoteAddr, cmd, argsStr, usedTime, "OK")
+    } else {
+        l.l.Info(format, remoteAddr, cmd, argsStr, usedTime, err.Error())
+    }
+}
diff --git a/server/app.go b/server/app.go
index 9e7e029..5a9bb87 100644
--- a/server/app.go
+++ b/server/app.go
@@ -1,8 +1,10 @@
 package server
 
 import (
+    "fmt"
     "github.com/siddontang/ledisdb/ledis"
     "net"
+    "path"
     "strings"
 )
 
@@ -14,11 +16,28 @@ type App struct {
     ldb *ledis.Ledis
 
     closed bool
+
+    quit chan struct{}
+
+    access *accessLog
+
+    //for slave replication
+    m *master
 }
 
 func NewApp(cfg *Config) (*App, error) {
+    if len(cfg.DataDir) == 0 {
+        return nil, fmt.Errorf("must set data_dir first")
+    }
+
+    if len(cfg.DB.DataDir) == 0 {
+        cfg.DB.DataDir = cfg.DataDir
+    }
+
     app := new(App)
 
+    app.quit = make(chan struct{})
+
     app.closed = false
 
     app.cfg = cfg
@@ -35,11 +54,24 @@ func NewApp(cfg *Config) (*App, error) {
         return nil, err
     }
 
-    app.ldb, err = ledis.OpenWithConfig(&cfg.DB)
-    if err != nil {
+    if len(cfg.AccessLog) > 0 {
+        if path.Dir(cfg.AccessLog) == "." {
+            app.access, err = newAccessLog(path.Join(cfg.DataDir, cfg.AccessLog))
+        } else {
+            app.access, err = newAccessLog(cfg.AccessLog)
+        }
+
+        if err != nil {
+            return nil, err
+        }
+    }
+
+    if app.ldb, err = ledis.OpenWithConfig(&cfg.DB); err != nil {
         return nil, err
     }
 
+    app.m = newMaster(app)
+
     return app, nil
 }
 
@@ -48,20 +80,36 @@ func (app *App) Close() {
         return
     }
 
+    app.closed = true
+
+    close(app.quit)
+
     app.listener.Close()
 
-    app.ldb.Close()
+    app.m.Close()
 
-    app.closed = true
+    if app.access != nil {
+        app.access.Close()
+    }
+
+    app.ldb.Close()
 }
 
 func (app *App) Run() {
+    if len(app.cfg.SlaveOf) > 0 {
+        app.slaveof(app.cfg.SlaveOf)
+    }
+
     for !app.closed {
         conn, err := app.listener.Accept()
         if err != nil {
             continue
         }
 
-        newClient(conn, app.ldb)
+        newClient(conn, app)
     }
 }
+
+func (app *App) Ledis() *ledis.Ledis {
+    return app.ldb
+}
diff --git a/server/app_test.go b/server/app_test.go
index 2293ce2..3357ddd 100644
--- a/server/app_test.go
+++ b/server/app_test.go
@@ -38,10 +38,10 @@ func startTestApp() {
 
         var d = []byte(`
         {
+            "data_dir" : "/tmp/testdb",
             "addr" : "127.0.0.1:16380",
             "db" : {
                 "data_db" : {
-                    "path" : "/tmp/testdb",
                     "compression":true,
                     "block_size" : 32768,
                     "write_buffer_size" : 2097152,
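Every relative path is now derived from data_dir: the DB falls back to it when the db section carries no path of its own, a bare access_log file name lands inside it, and replication state (master.info, dump files) is stored there as well. A config sketch exercising the new fields (paths and the slaveof address are illustrative):

    {
        "addr": "127.0.0.1:6380",
        "data_dir": "/var/lib/ledis",

        "db": {
            "data_db": {
                "compression": false
            }
        },

        "access_log": "access.log",
        "slaveof": "10.0.0.1:6380"
    }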
diff --git a/server/client.go b/server/client.go
index bd558fd..ca3d72d 100644
--- a/server/client.go
+++ b/server/client.go
@@ -2,6 +2,7 @@ package server
 
 import (
     "bufio"
+    "bytes"
     "errors"
     "github.com/siddontang/go-log/log"
     "github.com/siddontang/ledisdb/ledis"
@@ -10,11 +11,13 @@ import (
     "runtime"
     "strconv"
     "strings"
+    "time"
 )
 
 var errReadRequest = errors.New("invalid request protocol")
 
 type client struct {
+    app *App
     ldb *ledis.Ledis
 
     db *ledis.DB
@@ -27,13 +30,17 @@ type client struct {
     args [][]byte
 
     reqC chan error
+
+    syncBuf bytes.Buffer
 }
 
-func newClient(c net.Conn, ldb *ledis.Ledis) {
+func newClient(c net.Conn, app *App) {
     co := new(client)
-    co.ldb = ldb
+
+    co.app = app
+    co.ldb = app.ldb
     //use default db
-    co.db, _ = ldb.Select(0)
+    co.db, _ = app.ldb.Select(0)
     co.c = c
 
     co.rb = bufio.NewReaderSize(c, 256)
@@ -68,22 +75,7 @@ func (c *client) run() {
 }
 
 func (c *client) readLine() ([]byte, error) {
-    var line []byte
-    for {
-        l, more, err := c.rb.ReadLine()
-        if err != nil {
-            return nil, err
-        }
-
-        if line == nil && !more {
-            return l, nil
-        }
-        line = append(line, l...)
-        if !more {
-            break
-        }
-    }
-    return line, nil
+    return readLine(c.rb)
 }
 
 //A client sends to the Redis server a RESP Array consisting of just Bulk Strings.
@@ -118,15 +110,19 @@ func (c *client) readRequest() ([][]byte, error) {
             } else if n == -1 {
                 req = append(req, nil)
             } else {
-                buf := make([]byte, n+2)
+                buf := make([]byte, n)
                 if _, err = io.ReadFull(c.rb, buf); err != nil {
                     return nil, err
-                } else if buf[len(buf)-2] != '\r' || buf[len(buf)-1] != '\n' {
-                    return nil, errReadRequest
-
-                } else {
-                    req = append(req, buf[0:len(buf)-2])
                 }
+
+                if l, err = c.readLine(); err != nil {
+                    return nil, err
+                } else if len(l) != 0 {
+                    return nil, errors.New("bad bulk string format")
+                }
+
+                req = append(req, buf)
+
             }
         } else {
@@ -140,6 +136,8 @@ func (c *client) readRequest() ([][]byte, error) {
 func (c *client) handleRequest(req [][]byte) {
     var err error
 
+    start := time.Now()
+
     if len(req) == 0 {
         err = ErrEmptyCommand
     } else {
@@ -157,6 +155,12 @@ func (c *client) handleRequest(req [][]byte) {
         }
     }
 
+    duration := time.Since(start)
+
+    if c.app.access != nil {
+        c.app.access.Log(c.c.RemoteAddr().String(), duration.Nanoseconds()/1000000, c.cmd, c.args, err)
+    }
+
     if err != nil {
         c.writeError(err)
     }
@@ -223,3 +227,12 @@ func (c *client) writeArray(ay []interface{}) {
         }
     }
 }
+
+func (c *client) writeBulkFrom(n int64, rb io.Reader) {
+    c.wb.WriteByte('$')
+    c.wb.Write(ledis.Slice(strconv.FormatInt(n, 10)))
+    c.wb.Write(Delims)
+
+    io.Copy(c.wb, rb)
+    c.wb.Write(Delims)
+}
diff --git a/server/cmd_kv.go b/server/cmd_kv.go
index a6bdcbd..1c42cec 100644
--- a/server/cmd_kv.go
+++ b/server/cmd_kv.go
@@ -20,7 +20,7 @@ func getCommand(c *client) error {
 
 func setCommand(c *client) error {
     args := c.args
-    if len(args) < 2 {
+    if len(args) != 2 {
         return ErrCmdParams
     }
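The readRequest change reads exactly n payload bytes and then requires readLine to return an empty line, i.e. the CRLF terminator must sit exactly where the declared length says. For reference, the wire form such a request has on the socket (standard RESP, shown as a Go literal):

    // "set mykey hello" as a RESP array of bulk strings
    req := []byte("*3\r\n" + // array of 3 elements
        "$3\r\nset\r\n" + // bulk string: 3 payload bytes, then CRLF
        "$5\r\nmykey\r\n" +
        "$5\r\nhello\r\n")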
diff --git a/server/cmd_replication.go b/server/cmd_replication.go
new file mode 100644
index 0000000..82803dd
--- /dev/null
+++ b/server/cmd_replication.go
@@ -0,0 +1,116 @@
+package server
+
+import (
+    "encoding/binary"
+    "fmt"
+    "github.com/siddontang/ledisdb/ledis"
+    "io/ioutil"
+    "os"
+    "strconv"
+    "strings"
+)
+
+func slaveofCommand(c *client) error {
+    args := c.args
+
+    if len(args) != 2 {
+        return ErrCmdParams
+    }
+
+    masterAddr := ""
+
+    if strings.ToLower(ledis.String(args[0])) == "no" &&
+        strings.ToLower(ledis.String(args[1])) == "one" {
+        //stop replication, use master = ""
+    } else {
+        //the port must be a valid uint16
+        if _, err := strconv.ParseUint(ledis.String(args[1]), 10, 16); err != nil {
+            return err
+        }
+
+        masterAddr = fmt.Sprintf("%s:%s", args[0], args[1])
+    }
+
+    if err := c.app.slaveof(masterAddr); err != nil {
+        return err
+    }
+
+    c.writeStatus(OK)
+
+    return nil
+}
+
+func fullsyncCommand(c *client) error {
+    //TODO: concurrent fullsync commands may use the same dump file
+    dumpFile, err := ioutil.TempFile(c.app.cfg.DataDir, "dump_")
+    if err != nil {
+        return err
+    }
+
+    if err = c.app.ldb.Dump(dumpFile); err != nil {
+        dumpFile.Close()
+        os.Remove(dumpFile.Name())
+        return err
+    }
+
+    st, _ := dumpFile.Stat()
+    n := st.Size()
+
+    dumpFile.Seek(0, os.SEEK_SET)
+
+    c.writeBulkFrom(n, dumpFile)
+
+    name := dumpFile.Name()
+    dumpFile.Close()
+
+    os.Remove(name)
+
+    return nil
+}
+
+var reserveInfoSpace = make([]byte, 16)
+
+func syncCommand(c *client) error {
+    args := c.args
+    if len(args) != 2 {
+        return ErrCmdParams
+    }
+
+    var logIndex int64
+    var logPos int64
+    var err error
+    logIndex, err = ledis.StrInt64(args[0], nil)
+    if err != nil {
+        return ErrCmdParams
+    }
+
+    logPos, err = ledis.StrInt64(args[1], nil)
+    if err != nil {
+        return ErrCmdParams
+    }
+
+    c.syncBuf.Reset()
+
+    //reserve space to write master info
+    if _, err := c.syncBuf.Write(reserveInfoSpace); err != nil {
+        return err
+    }
+
+    m := &ledis.MasterInfo{LogFileIndex: logIndex, LogPos: logPos}
+
+    if _, err := c.app.ldb.ReadEventsTo(m, &c.syncBuf); err != nil {
+        return err
+    } else {
+        buf := c.syncBuf.Bytes()
+
+        binary.BigEndian.PutUint64(buf[0:], uint64(m.LogFileIndex))
+        binary.BigEndian.PutUint64(buf[8:], uint64(m.LogPos))
+
+        c.writeBulk(buf)
+    }
+
+    return nil
+}
+
+func init() {
+    register("slaveof", slaveofCommand)
+    register("fullsync", fullsyncCommand)
+    register("sync", syncCommand)
+}
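syncCommand answers with a single bulk string: a 16-byte big-endian header holding the master's current log file index and log position, followed by the raw binlog events. A slave-side decoding sketch (reply is a hypothetical variable holding the bulk payload; master.sync below does the same via binary.Read):

    logFileIndex := int64(binary.BigEndian.Uint64(reply[0:8]))
    logPos := int64(binary.BigEndian.Uint64(reply[8:16]))
    events := reply[16:] // binlog events, fed to ReplicateFromReader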
diff --git a/server/cmd_replication_test.go b/server/cmd_replication_test.go
new file mode 100644
index 0000000..645f4f7
--- /dev/null
+++ b/server/cmd_replication_test.go
@@ -0,0 +1,107 @@
+package server
+
+import (
+    "bytes"
+    "fmt"
+    "github.com/siddontang/go-leveldb/leveldb"
+    "os"
+    "testing"
+    "time"
+)
+
+func checkDataEqual(master *App, slave *App) error {
+    it := master.ldb.DataDB().Iterator(nil, nil, leveldb.RangeClose, 0, -1)
+    for ; it.Valid(); it.Next() {
+        key := it.Key()
+        value := it.Value()
+
+        if v, err := slave.ldb.DataDB().Get(key); err != nil {
+            return err
+        } else if !bytes.Equal(v, value) {
+            return fmt.Errorf("replication error %d != %d", len(v), len(value))
+        }
+    }
+    it.Close()
+
+    return nil
+}
+
+func TestReplication(t *testing.T) {
+    dataDir := "/tmp/test_replication"
+    os.RemoveAll(dataDir)
+
+    masterCfg := new(Config)
+    masterCfg.DataDir = fmt.Sprintf("%s/master", dataDir)
+    masterCfg.Addr = "127.0.0.1:11182"
+    masterCfg.DB.UseBinLog = true
+
+    var master *App
+    var slave *App
+    var err error
+    master, err = NewApp(masterCfg)
+    if err != nil {
+        t.Fatal(err)
+    }
+
+    slaveCfg := new(Config)
+    slaveCfg.DataDir = fmt.Sprintf("%s/slave", dataDir)
+    slaveCfg.Addr = "127.0.0.1:11183"
+    slaveCfg.SlaveOf = masterCfg.Addr
+
+    slave, err = NewApp(slaveCfg)
+    if err != nil {
+        t.Fatal(err)
+    }
+
+    go master.Run()
+
+    db, _ := master.ldb.Select(0)
+
+    value := make([]byte, 10)
+
+    db.Set([]byte("a"), value)
+    db.Set([]byte("b"), value)
+    db.HSet([]byte("a"), []byte("1"), value)
+    db.HSet([]byte("b"), []byte("2"), value)
+
+    go slave.Run()
+
+    time.Sleep(1 * time.Second)
+
+    if err = checkDataEqual(master, slave); err != nil {
+        t.Fatal(err)
+    }
+
+    db.Set([]byte("a1"), value)
+    db.Set([]byte("b1"), value)
+    db.HSet([]byte("a1"), []byte("1"), value)
+    db.HSet([]byte("b1"), []byte("2"), value)
+
+    time.Sleep(1 * time.Second)
+    if err = checkDataEqual(master, slave); err != nil {
+        t.Fatal(err)
+    }
+
+    slave.slaveof("")
+
+    db.Set([]byte("a2"), value)
+    db.Set([]byte("b2"), value)
+    db.HSet([]byte("a2"), []byte("1"), value)
+    db.HSet([]byte("b2"), []byte("2"), value)
+
+    db.Set([]byte("a3"), value)
+    db.Set([]byte("b3"), value)
+    db.HSet([]byte("a3"), []byte("1"), value)
+    db.HSet([]byte("b3"), []byte("2"), value)
+
+    if err = checkDataEqual(master, slave); err == nil {
+        t.Fatal("must error")
+    }
+
+    slave.slaveof(masterCfg.Addr)
+    time.Sleep(1 * time.Second)
+
+    if err = checkDataEqual(master, slave); err != nil {
+        t.Fatal(err)
+    }
+}
diff --git a/server/config.go b/server/config.go
index 5e98694..89aba2c 100644
--- a/server/config.go
+++ b/server/config.go
@@ -9,7 +9,16 @@ import (
 
 type Config struct {
     Addr string `json:"addr"`
 
+    DataDir string `json:"data_dir"`
+
+    //if the db path is not set, data_dir is used
     DB ledis.Config `json:"db"`
+
+    //set slaveof to enable replication from a master;
+    //empty means no replication
+    SlaveOf string `json:"slaveof"`
+
+    AccessLog string `json:"access_log"`
 }
 
 func NewConfig(data json.RawMessage) (*Config, error) {
diff --git a/server/replication.go b/server/replication.go
new file mode 100644
index 0000000..2ee0af4
--- /dev/null
+++ b/server/replication.go
@@ -0,0 +1,325 @@
+package server
+
+import (
+    "bufio"
+    "bytes"
+    "encoding/binary"
+    "encoding/json"
+    "errors"
+    "fmt"
+    "github.com/siddontang/go-log/log"
+    "github.com/siddontang/ledisdb/ledis"
+    "io/ioutil"
+    "net"
+    "os"
+    "path"
+    "strconv"
+    "sync"
+    "time"
+)
+
+var (
+    errConnectMaster = errors.New("connect master error")
+)
+
+type master struct {
+    sync.Mutex
+
+    //these fields are persisted to master.info as JSON; they must be
+    //exported, otherwise encoding/json would silently skip them
+    Addr         string `json:"addr"`
+    LogFileIndex int64  `json:"log_file_index"`
+    LogPos       int64  `json:"log_pos"`
+
+    c  net.Conn
+    rb *bufio.Reader
+
+    app *App
+
+    quit chan struct{}
+
+    infoName    string
+    infoNameBak string
+
+    wg sync.WaitGroup
+
+    syncBuf bytes.Buffer
+}
+
+func newMaster(app *App) *master {
+    m := new(master)
+    m.app = app
+
+    m.infoName = path.Join(m.app.cfg.DataDir, "master.info")
+    m.infoNameBak = fmt.Sprintf("%s.bak", m.infoName)
+
+    m.quit = make(chan struct{}, 1)
+
+    //if loading fails, we will start a full sync later
+    m.loadInfo()
+
+    return m
+}
+
+func (m *master) Close() {
+    select {
+    case m.quit <- struct{}{}:
+    default:
+    }
+
+    if m.c != nil {
+        m.c.Close()
+        m.c = nil
+    }
+
+    m.wg.Wait()
+}
+
+func (m *master) loadInfo() error {
+    data, err := ioutil.ReadFile(m.infoName)
+    if err != nil {
+        if os.IsNotExist(err) {
+            return nil
+        } else {
+            return err
+        }
+    }
+
+    if err = json.Unmarshal(data, m); err != nil {
+        return err
+    }
+
+    return nil
+}
+
+func (m *master) saveInfo() error {
+    data, err := json.Marshal(m)
+    if err != nil {
+        return err
+    }
+
+    var fd *os.File
+    fd, err = os.OpenFile(m.infoNameBak, os.O_CREATE|os.O_WRONLY, os.ModePerm)
+    if err != nil {
+        return err
+    }
+
+    if _, err = fd.Write(data); err != nil {
+        fd.Close()
+        return err
+    }
+
+    fd.Close()
+    return os.Rename(m.infoNameBak, m.infoName)
+}
+
+func (m *master) connect() error {
+    if len(m.Addr) == 0 {
+        return fmt.Errorf("no master addr assigned")
+    }
+
+    if m.c != nil {
+        m.c.Close()
+        m.c = nil
+    }
+
+    if c, err := net.Dial("tcp", m.Addr); err != nil {
+        return err
+    } else {
+        m.c = c
+
+        m.rb = bufio.NewReaderSize(m.c, 4096)
+    }
+    return nil
+}
+
+func (m *master) resetInfo(addr string) {
+    m.Addr = addr
+    m.LogFileIndex = 0
+    m.LogPos = 0
+}
+
+func (m *master) stopReplication() error {
+    m.Close()
+
+    if err := m.saveInfo(); err != nil {
+        log.Error("save master info error %s", err.Error())
+        return err
+    }
+
+    return nil
+}
+
+func (m *master) startReplication(masterAddr string) error {
+    //stop the last replication first, if available
+    m.Close()
+
+    if masterAddr != m.Addr {
+        m.resetInfo(masterAddr)
+        if err := m.saveInfo(); err != nil {
+            log.Error("save master info error %s", err.Error())
+            return err
+        }
+    }
+
+    m.quit = make(chan struct{}, 1)
+
+    go m.runReplication()
+    return nil
+}
+
+func (m *master) runReplication() {
+    m.wg.Add(1)
+    defer m.wg.Done()
+
+    for {
+        select {
+        case <-m.quit:
+            return
+        default:
+            if err := m.connect(); err != nil {
+                log.Error("connect master %s error %s, retry in 2s", m.Addr, err.Error())
+                time.Sleep(2 * time.Second)
+                continue
+            }
+        }
+
+        if m.LogFileIndex == 0 {
+            //try a full sync first
+            if err := m.fullSync(); err != nil {
+                log.Warn("full sync error %s", err.Error())
+                return
+            }
+
+            if m.LogFileIndex == 0 {
+                //master does not support binlog, we cannot sync, so stop replication
+                m.stopReplication()
+                return
+            }
+        }
+
+        for {
+            for {
+                lastIndex := m.LogFileIndex
+                lastPos := m.LogPos
+                if err := m.sync(); err != nil {
+                    log.Warn("sync error %s", err.Error())
+                    return
+                }
+
+                if m.LogFileIndex == lastIndex && m.LogPos == lastPos {
+                    //no new data was synced, wait 1s and retry
+                    break
+                }
+            }
+
+            select {
+            case <-m.quit:
+                return
+            case <-time.After(1 * time.Second):
+            }
+        }
+    }
+}
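+
+//the log file index returned by the master drives the slave's next step:
+//  0  -> the master runs without binlog, incremental sync is impossible,
+//        so replication stops (see sync below)
+//  -1 -> the requested binlog position no longer exists, fall back to fullSync
+//  >0 -> a valid position, keep pulling events from (LogFileIndex, LogPos)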
+
+var (
+    fullSyncCmd   = []byte("*1\r\n$8\r\nfullsync\r\n")               //fullsync
+    syncCmdFormat = "*3\r\n$4\r\nsync\r\n$%d\r\n%s\r\n$%d\r\n%s\r\n" //sync index pos
+)
+
+func (m *master) fullSync() error {
+    if _, err := m.c.Write(fullSyncCmd); err != nil {
+        return err
+    }
+
+    dumpPath := path.Join(m.app.cfg.DataDir, "master.dump")
+    f, err := os.OpenFile(dumpPath, os.O_CREATE|os.O_WRONLY, os.ModePerm)
+    if err != nil {
+        return err
+    }
+
+    defer os.Remove(dumpPath)
+
+    err = readBulkTo(m.rb, f)
+    f.Close()
+    if err != nil {
+        log.Error("read dump data error %s", err.Error())
+        return err
+    }
+
+    if err = m.app.ldb.FlushAll(); err != nil {
+        return err
+    }
+
+    var head *ledis.MasterInfo
+    head, err = m.app.ldb.LoadDumpFile(dumpPath)
+
+    if err != nil {
+        log.Error("load dump file error %s", err.Error())
+        return err
+    }
+
+    m.LogFileIndex = head.LogFileIndex
+    m.LogPos = head.LogPos
+
+    return nil
+}
+
+func (m *master) sync() error {
+    logIndexStr := strconv.FormatInt(m.LogFileIndex, 10)
+    logPosStr := strconv.FormatInt(m.LogPos, 10)
+
+    cmd := ledis.Slice(fmt.Sprintf(syncCmdFormat, len(logIndexStr),
+        logIndexStr, len(logPosStr), logPosStr))
+    if _, err := m.c.Write(cmd); err != nil {
+        return err
+    }
+
+    m.syncBuf.Reset()
+
+    err := readBulkTo(m.rb, &m.syncBuf)
+    if err != nil {
+        return err
+    }
+
+    err = binary.Read(&m.syncBuf, binary.BigEndian, &m.LogFileIndex)
+    if err != nil {
+        return err
+    }
+
+    err = binary.Read(&m.syncBuf, binary.BigEndian, &m.LogPos)
+    if err != nil {
+        return err
+    }
+
+    if m.LogFileIndex == 0 {
+        //the master does not support binlog now, stop replication
+        m.stopReplication()
+        return nil
+    } else if m.LogFileIndex == -1 {
+        //-1 means the binlog index and pos are lost, we must start a full sync instead
+        return m.fullSync()
+    }
+
+    err = m.app.ldb.ReplicateFromReader(&m.syncBuf)
+    if err != nil {
+        return err
+    }
+
+    return nil
+}
+
+func (app *App) slaveof(masterAddr string) error {
+    app.m.Lock()
+    defer app.m.Unlock()
+
+    if len(masterAddr) == 0 {
+        return app.m.stopReplication()
+    } else {
+        return app.m.startReplication(masterAddr)
+    }
+}
diff --git a/server/util.go b/server/util.go
new file mode 100644
index 0000000..9afee4e
--- /dev/null
+++ b/server/util.go
@@ -0,0 +1,57 @@
+package server
+
+import (
+    "bufio"
+    "errors"
+    "github.com/siddontang/ledisdb/ledis"
+    "io"
+    "strconv"
+)
+
+var (
+    errArrayFormat = errors.New("bad array format")
+    errBulkFormat  = errors.New("bad bulk string format")
+    errLineFormat  = errors.New("bad response line format")
+)
+
+//readLine reads one CRLF-terminated line and returns it without the CRLF
+func readLine(rb *bufio.Reader) ([]byte, error) {
+    p, err := rb.ReadSlice('\n')
+    if err != nil {
+        return nil, err
+    }
+
+    i := len(p) - 2
+    if i < 0 || p[i] != '\r' {
+        return nil, errLineFormat
+    }
+    return p[:i], nil
+}
+
+//readBulkTo streams the payload of a RESP bulk string ($N\r\n<N bytes>\r\n) into w
+func readBulkTo(rb *bufio.Reader, w io.Writer) error {
+    l, err := readLine(rb)
+    if err != nil {
+        return err
+    }
+
+    if len(l) == 0 || l[0] != '$' {
+        return errBulkFormat
+    }
+
+    var n int
+    if n, err = strconv.Atoi(ledis.String(l[1:])); err != nil {
+        return err
+    } else if n == -1 {
+        //null bulk string, nothing to copy
+        return nil
+    } else {
+        if _, err = io.CopyN(w, rb, int64(n)); err != nil {
+            return err
+        }
+
+        if l, err = readLine(rb); err != nil {
+            return err
+        } else if len(l) != 0 {
+            return errBulkFormat
+        }
+    }
+
+    return nil
+}
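util.go gives the slave a minimal RESP reader: readLine enforces the CRLF ending, and readBulkTo streams a length-prefixed bulk reply straight into any writer (a dump file for fullsync, a buffer for sync) instead of materializing the payload. A usage sketch mirroring master.fullSync (address and path are hypothetical; error handling elided):

    c, _ := net.Dial("tcp", "127.0.0.1:6380")
    rb := bufio.NewReaderSize(c, 4096)

    // ask the master for a full dump and stream it to disk
    c.Write([]byte("*1\r\n$8\r\nfullsync\r\n"))

    f, _ := os.Create("/tmp/master.dump")
    if err := readBulkTo(rb, f); err != nil {
        // bad framing or a short read: retry or abort
    }
    f.Close()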