diff --git a/server/snapshot.go b/server/snapshot.go index 4068f48..1643160 100644 --- a/server/snapshot.go +++ b/server/snapshot.go @@ -28,12 +28,12 @@ type snapshotStore struct { } func snapshotName(t time.Time) string { - return fmt.Sprintf("snap-%s.dump", t.Format(snapshotTimeFormat)) + return fmt.Sprintf("snap-%s.dmp", t.Format(snapshotTimeFormat)) } func parseSnapshotName(name string) (time.Time, error) { var timeString string - if _, err := fmt.Sscanf(name, "snap-%s.dump", &timeString); err != nil { + if _, err := fmt.Sscanf(name, "snap-%s.dmp", &timeString); err != nil { return time.Time{}, err } when, err := time.Parse(snapshotTimeFormat, timeString) @@ -58,23 +58,10 @@ func newSnapshotStore(cfg *config.Config) (*snapshotStore, error) { s.quit = make(chan struct{}) - snapshots, err := ioutil.ReadDir(cfg.Snapshot.Path) - if err != nil { + if err := s.checkSnapshots(); err != nil { return nil, err } - for _, info := range snapshots { - if _, err := parseSnapshotName(info.Name()); err != nil { - log.Error("invalid snapshot file name %s, err: %s", info.Name(), err.Error()) - continue - } - - s.names = append(s.names, info.Name()) - } - - //from old to new - sort.Strings(s.names) - go s.run() return s, nil @@ -84,15 +71,51 @@ func (s *snapshotStore) Close() { close(s.quit) } +func (s *snapshotStore) checkSnapshots() error { + cfg := s.cfg + snapshots, err := ioutil.ReadDir(cfg.Snapshot.Path) + if err != nil { + log.Error("read %s error: %s", cfg.Snapshot.Path, err.Error()) + return err + } + + names := []string{} + for _, info := range snapshots { + if path.Ext(info.Name()) == ".tmp" { + log.Error("temp snapshot file name %s, try remove", info.Name()) + os.Remove(path.Join(cfg.Snapshot.Path, info.Name())) + continue + } + + if _, err := parseSnapshotName(info.Name()); err != nil { + log.Error("invalid snapshot file name %s, err: %s, try remove", info.Name(), err.Error()) + continue + } + + names = append(names, info.Name()) + } + + //from old to new + sort.Strings(names) + + s.names = names + + s.purge(false) + + return nil +} + func (s *snapshotStore) run() { - t := time.NewTicker(1 * time.Minute) + t := time.NewTicker(60 * time.Minute) defer t.Stop() for { select { case <-t.C: s.Lock() - s.purge(false) + if err := s.checkSnapshots(); err != nil { + log.Error("check snapshots error %s", err.Error()) + } s.Unlock() case <-s.quit: return @@ -138,6 +161,8 @@ type snapshot struct { io.ReadCloser f *os.File + + temp bool } func (st *snapshot) Read(b []byte) (int, error) { @@ -145,7 +170,15 @@ func (st *snapshot) Read(b []byte) (int, error) { } func (st *snapshot) Close() error { - return st.f.Close() + if st.temp { + name := st.f.Name() + if err := st.f.Close(); err != nil { + return err + } + return os.Remove(name) + } else { + return st.f.Close() + } } func (st *snapshot) Size() int64 { @@ -153,37 +186,58 @@ func (st *snapshot) Size() int64 { return s.Size() } -func (s *snapshotStore) Create(d snapshotDumper) (*snapshot, time.Time, error) { +func (s *snapshotStore) Create(d snapshotDumper, temp bool) (*snapshot, time.Time, error) { s.Lock() defer s.Unlock() - s.purge(true) + if !temp { + s.purge(true) + } now := time.Now() name := snapshotName(now) - if len(s.names) > 0 && s.names[len(s.names)-1] >= name { - return nil, time.Time{}, fmt.Errorf("create snapshot file time %s is behind %s ", name, s.names[len(s.names)-1]) + tmpName := name + ".tmp" + + if len(s.names) > 0 && !temp { + lastTime, _ := parseSnapshotName(s.names[len(s.names)-1]) + if !now.After(lastTime) { + return nil, time.Time{}, fmt.Errorf("create snapshot file time %s is behind %s ", + now.Format(snapshotTimeFormat), lastTime.Format(snapshotTimeFormat)) + } } - f, err := os.OpenFile(s.snapshotPath(name), os.O_RDWR|os.O_CREATE, 0644) + f, err := os.OpenFile(s.snapshotPath(tmpName), os.O_RDWR|os.O_CREATE, 0644) if err != nil { return nil, time.Time{}, err } if err := d.Dump(f); err != nil { f.Close() - os.Remove(s.snapshotPath(name)) + os.Remove(s.snapshotPath(tmpName)) return nil, time.Time{}, err } - f.Sync() + if temp { + if err := f.Sync(); err != nil { + f.Close() + return nil, time.Time{}, err + } - s.names = append(s.names, name) + f.Seek(0, os.SEEK_SET) + } else { + f.Close() + if err := os.Rename(s.snapshotPath(tmpName), s.snapshotPath(name)); err != nil { + return nil, time.Time{}, err + } - f.Seek(0, os.SEEK_SET) + if f, err = os.Open(s.snapshotPath(name)); err != nil { + return nil, time.Time{}, err + } + s.names = append(s.names, name) + } - return &snapshot{f: f}, now, nil + return &snapshot{f: f, temp: temp}, now, nil } func (s *snapshotStore) OpenLatest() (*snapshot, time.Time, error) { @@ -202,5 +256,5 @@ func (s *snapshotStore) OpenLatest() (*snapshot, time.Time, error) { return nil, time.Time{}, err } - return &snapshot{f: f}, t, err + return &snapshot{f: f, temp: false}, t, err } diff --git a/server/snapshot_test.go b/server/snapshot_test.go index c55e435..11051a8 100644 --- a/server/snapshot_test.go +++ b/server/snapshot_test.go @@ -30,7 +30,7 @@ func TestSnapshot(t *testing.T) { t.Fatal(err) } - if f, _, err := s.Create(d); err != nil { + if f, _, err := s.Create(d, false); err != nil { t.Fatal(err) } else { defer f.Close() @@ -43,7 +43,7 @@ func TestSnapshot(t *testing.T) { } } - if f, _, err := s.Create(d); err != nil { + if f, _, err := s.Create(d, false); err != nil { t.Fatal(err) } else { defer f.Close() @@ -55,7 +55,7 @@ func TestSnapshot(t *testing.T) { } } - if f, _, err := s.Create(d); err != nil { + if f, _, err := s.Create(d, false); err != nil { t.Fatal(err) } else { defer f.Close() @@ -73,5 +73,29 @@ func TestSnapshot(t *testing.T) { t.Fatal("must 2 snapshot") } + if f, _, err := s.Create(d, true); err != nil { + t.Fatal(err) + } else { + if b, _ := ioutil.ReadAll(f); string(b) != "hello world" { + t.Fatal("invalid read snapshot") + } + + if len(s.names) != 2 { + t.Fatal("must 2 snapshot") + } + + fs, _ = ioutil.ReadDir(cfg.Snapshot.Path) + if len(fs) != 3 { + t.Fatal("must 3 snapshot") + } + + f.Close() + } + + fs, _ = ioutil.ReadDir(cfg.Snapshot.Path) + if len(fs) != 2 { + t.Fatal("must 2 snapshot") + } + s.Close() }