diff --git a/predicatefs.go b/predicatefs.go index 12f28e7..29332c9 100644 --- a/predicatefs.go +++ b/predicatefs.go @@ -8,10 +8,12 @@ import ( "time" ) +type predFn func(bool, string) bool + // FilePredicateFs filters files (not directories) by predicate, // which takes file path as an arg. type FilePredicateFs struct { - pred func(string) bool + pred predFn source Fs } @@ -19,49 +21,43 @@ var ( _ fs.ReadDirFile = (*PredicateFile)(nil) ) -func NewFilePredicateFs(source Fs, pred func(string) bool) Fs { +func NewFilePredicateFs(source Fs, pred predFn) Fs { return &FilePredicateFs{source: source, pred: pred} } type PredicateFile struct { f File - pred func(string) bool + pred predFn } func (p *FilePredicateFs) validate(path string) error { - if p.pred(path) { + dir, err := IsDir(p.source, path) + if err != nil { + return err + } + + if p.pred(dir, path) { return nil } return syscall.ENOENT } -func (p *FilePredicateFs) dirOrValidPath(path string) error { - dir, err := IsDir(p.source, path) - if err != nil { - return err - } - if dir { - return nil - } - return p.validate(path) -} - func (p *FilePredicateFs) Chtimes(path string, a, m time.Time) error { - if err := p.dirOrValidPath(path); err != nil { + if err := p.validate(path); err != nil { return err } return p.source.Chtimes(path, a, m) } func (p *FilePredicateFs) Chmod(path string, mode os.FileMode) error { - if err := p.dirOrValidPath(path); err != nil { + if err := p.validate(path); err != nil { return err } return p.source.Chmod(path, mode) } func (p *FilePredicateFs) Chown(path string, uid, gid int) error { - if err := p.dirOrValidPath(path); err != nil { + if err := p.validate(path); err != nil { return err } return p.source.Chown(path, uid, gid) @@ -72,7 +68,7 @@ func (p *FilePredicateFs) Name() string { } func (p *FilePredicateFs) Stat(path string) (os.FileInfo, error) { - if err := p.dirOrValidPath(path); err != nil { + if err := p.validate(path); err != nil { return nil, err } return p.source.Stat(path) @@ -96,42 +92,30 @@ func (p *FilePredicateFs) Rename(oldname, newname string) error { } func (p *FilePredicateFs) RemoveAll(path string) error { - dir, err := IsDir(p.source, path) - if err != nil { + if err := p.validate(path); err != nil { return err } - if !dir { - if err := p.validate(path); err != nil { - return err - } - } return p.source.RemoveAll(path) } func (p *FilePredicateFs) Remove(path string) error { - if err := p.dirOrValidPath(path); err != nil { + if err := p.validate(path); err != nil { return err } return p.source.Remove(path) } func (p *FilePredicateFs) OpenFile(path string, flag int, perm os.FileMode) (File, error) { - if err := p.dirOrValidPath(path); err != nil { + if err := p.validate(path); err != nil { return nil, err } return p.source.OpenFile(path, flag, perm) } func (p *FilePredicateFs) Open(path string) (File, error) { - dir, err := IsDir(p.source, path) - if err != nil { + if err := p.validate(path); err != nil { return nil, err } - if !dir { - if err := p.validate(path); err != nil { - return nil, err - } - } f, err := p.source.Open(path) if err != nil { return nil, err @@ -182,18 +166,18 @@ func (f *PredicateFile) Name() string { return f.f.Name() } -func (f *PredicateFile) Readdir(c int) (fi []os.FileInfo, err error) { - var pfi []os.FileInfo - pfi, err = f.f.Readdir(c) +func (f *PredicateFile) Readdir(c int) (filtered []os.FileInfo, err error) { + var infos []os.FileInfo + infos, err = f.f.Readdir(c) if err != nil { return nil, err } - for _, i := range pfi { - if i.IsDir() || f.pred(filepath.Join(f.f.Name(), i.Name())) { - fi = append(fi, i) + for _, i := range infos { + if f.pred(i.IsDir(), filepath.Join(f.f.Name(), i.Name())) { + filtered = append(filtered, i) } } - return fi, nil + return filtered, nil } func (f *PredicateFile) ReadDir(n int) (filtered []fs.DirEntry, err error) { @@ -207,12 +191,11 @@ func (f *PredicateFile) ReadDir(n int) (filtered []fs.DirEntry, err error) { return nil, err } for _, e := range entreis { - if e.IsDir() || f.pred(filepath.Join(f.f.Name(), e.Name())) { + if f.pred(e.IsDir(), filepath.Join(f.f.Name(), e.Name())) { filtered = append(filtered, e) } } return filtered, nil - } func (f *PredicateFile) Readdirnames(c int) (n []string, err error) { diff --git a/predicatefs_test.go b/predicatefs_test.go index 74040cd..03cad18 100644 --- a/predicatefs_test.go +++ b/predicatefs_test.go @@ -26,8 +26,8 @@ func TestFilePredicateFs(t *testing.T) { return strings.HasSuffix(filepath.Dir(path), ".hidden") } - pred := func(path string) bool { - return nonEmpty(path) && txtExts(path) && !inHiddenDir(path) + pred := func(isDir bool, path string) bool { + return isDir || (nonEmpty(path) && txtExts(path) && !inHiddenDir(path)) } fs := &FilePredicateFs{pred: pred, source: mfs}