diff --git a/ledis/ledis.go b/ledis/ledis.go index e33ad74..50e9f16 100644 --- a/ledis/ledis.go +++ b/ledis/ledis.go @@ -58,7 +58,7 @@ func Open2(cfg *config.Config, flags int) (*Ledis, error) { return nil, err } - l.rc = make(chan struct{}) + l.rc = make(chan struct{}, 1) l.rbatch = l.ldb.NewWriteBatch() go l.onReplication() diff --git a/ledis/replication.go b/ledis/replication.go index 3473e6b..8763574 100644 --- a/ledis/replication.go +++ b/ledis/replication.go @@ -19,8 +19,11 @@ var ( ) func (l *Ledis) handleReplication() { + l.commitLock.Lock() + defer l.commitLock.Unlock() + l.rwg.Add(1) - var rl *rpl.Log + rl := &rpl.Log{} for { if err := l.r.NextCommitLog(rl); err != nil { if err != rpl.ErrNoBehindLog { @@ -59,33 +62,37 @@ func (l *Ledis) onReplication() { } func (l *Ledis) WaitReplication() error { - l.rwg.Wait() + b, err := l.r.CommitIDBehind() + if err != nil { + return err + } else if b { + l.rc <- struct{}{} + l.rwg.Wait() + } return nil } -func (l *Ledis) StoreLogsFromReader(rb io.Reader) (uint64, error) { +func (l *Ledis) StoreLogsFromReader(rb io.Reader) error { if l.r == nil { - return 0, fmt.Errorf("replication not enable") + return fmt.Errorf("replication not enable") } - var log *rpl.Log - var n uint64 + log := &rpl.Log{} for { if err := log.Decode(rb); err != nil { if err == io.EOF { break } else { - return 0, err + return err } } if err := l.r.StoreLog(log); err != nil { - return 0, err + return err } - n = log.ID } select { @@ -94,10 +101,10 @@ func (l *Ledis) StoreLogsFromReader(rb io.Reader) (uint64, error) { break } - return n, nil + return nil } -func (l *Ledis) StoreLogsFromData(data []byte) (uint64, error) { +func (l *Ledis) StoreLogsFromData(data []byte) error { rb := bytes.NewReader(data) return l.StoreLogsFromReader(rb) @@ -127,7 +134,7 @@ func (l *Ledis) ReadLogsTo(startLogID uint64, w io.Writer) (n int, nextLogID uin return } - var log *rpl.Log + log := &rpl.Log{} for i := startLogID; i <= lastID; i++ { if err = l.r.GetLog(i, log); err != nil { return diff --git a/ledis/replication_test.go b/ledis/replication_test.go index cc3a392..4cef10d 100644 --- a/ledis/replication_test.go +++ b/ledis/replication_test.go @@ -81,17 +81,14 @@ func TestReplication(t *testing.T) { var buf bytes.Buffer var n int var id uint64 = 1 - var nid uint64 for { buf.Reset() n, id, err = master.ReadLogsTo(id, &buf) if err != nil { t.Fatal(err) } else if n != 0 { - if nid, err = slave.StoreLogsFromReader(&buf); err != nil { + if err = slave.StoreLogsFromReader(&buf); err != nil { t.Fatal(err) - } else if nid != id { - t.Fatal(nid, id) } } else if n == 0 { break diff --git a/rpl/log.go b/rpl/log.go index ad637ca..775ea5d 100644 --- a/rpl/log.go +++ b/rpl/log.go @@ -90,7 +90,13 @@ func (l *Log) Decode(r io.Reader) error { length := binary.BigEndian.Uint32(buf[pos:]) - l.Data = make([]byte, length) + l.Data = l.Data[0:0] + + if cap(l.Data) >= int(length) { + l.Data = l.Data[0:length] + } else { + l.Data = make([]byte, length) + } if _, err := io.ReadFull(r, l.Data); err != nil { return err }