diff --git a/basepath.go b/basepath.go index 4f98328..70a1d91 100644 --- a/basepath.go +++ b/basepath.go @@ -1,6 +1,7 @@ package afero import ( + "io/fs" "os" "path/filepath" "runtime" @@ -8,7 +9,10 @@ import ( "time" ) -var _ Lstater = (*BasePathFs)(nil) +var ( + _ Lstater = (*BasePathFs)(nil) + _ fs.ReadDirFile = (*BasePathFile)(nil) +) // The BasePathFs restricts all operations to a given path within an Fs. // The given file name to the operations on this Fs will be prepended with @@ -33,6 +37,14 @@ func (f *BasePathFile) Name() string { return strings.TrimPrefix(sourcename, filepath.Clean(f.path)) } +func (f *BasePathFile) ReadDir(n int) ([]fs.DirEntry, error) { + if rdf, ok := f.File.(fs.ReadDirFile); ok { + return rdf.ReadDir(n) + + } + return readDirFile{f.File}.ReadDir(n) +} + func NewBasePathFs(source Fs, path string) Fs { return &BasePathFs{source: source, path: path} } diff --git a/iofs.go b/iofs.go index 0135703..8bc9735 100644 --- a/iofs.go +++ b/iofs.go @@ -8,6 +8,7 @@ import ( "io/fs" "os" "path" + "sort" "time" ) @@ -67,11 +68,23 @@ func (iofs IOFS) Glob(pattern string) ([]string, error) { } func (iofs IOFS) ReadDir(name string) ([]fs.DirEntry, error) { - items, err := ReadDir(iofs.Fs, name) + f, err := iofs.Fs.Open(name) if err != nil { return nil, iofs.wrapError("readdir", name, err) } + defer f.Close() + + if rdf, ok := f.(fs.ReadDirFile); ok { + return rdf.ReadDir(-1) + } + + items, err := f.Readdir(-1) + if err != nil { + return nil, iofs.wrapError("readdir", name, err) + } + sort.Sort(byName(items)) + ret := make([]fs.DirEntry, len(items)) for i := range items { ret[i] = dirEntry{items[i]} diff --git a/iofs_test.go b/iofs_test.go index cb86eb4..32f0d9d 100644 --- a/iofs_test.go +++ b/iofs_test.go @@ -6,9 +6,11 @@ package afero import ( "bytes" "errors" + "fmt" "io" "io/fs" "os" + "path/filepath" "runtime" "testing" "testing/fstest" @@ -61,6 +63,77 @@ func TestIOFS(t *testing.T) { t.Error(err) } }) + +} + +func TestIOFSNativeDirEntryWhenPossible(t *testing.T) { + t.Parallel() + + osfs := NewBasePathFs(NewOsFs(), t.TempDir()) + + err := osfs.MkdirAll("dir1/dir2", os.ModePerm) + if err != nil { + t.Fatal(err) + } + + for i := 1; i <= 2; i++ { + f, err := osfs.Create(fmt.Sprintf("dir1/dir2/test%d.txt", i)) + if err != nil { + t.Fatal(err) + } + f.Close() + } + + dir2, err := osfs.Open("dir1/dir2") + if err != nil { + t.Fatal(err) + } + + assertDirEntries := func(entries []fs.DirEntry) { + if len(entries) != 2 { + t.Fatalf("expected 2, got %d", len(entries)) + } + for _, entry := range entries { + if _, ok := entry.(dirEntry); ok { + t.Fatal("DirEntry not native") + } + } + } + + dirEntries, err := dir2.(fs.ReadDirFile).ReadDir(-1) + if err != nil { + t.Fatal(err) + } + assertDirEntries(dirEntries) + + iofs := NewIOFS(osfs) + + fileCount := 0 + err = fs.WalkDir(iofs, "", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if !d.IsDir() { + fileCount++ + } + + if _, ok := d.(dirEntry); ok { + t.Fatal("DirEntry not native") + } + + return nil + + }) + + if err != nil { + t.Fatal(err) + } + + if fileCount != 2 { + t.Fatalf("expected 2, got %d", fileCount) + } + } func TestFromIOFS(t *testing.T) { @@ -416,3 +489,48 @@ func assertPermissionError(t *testing.T, err error) { t.Errorf("Expected (*fs.PathError).Err == fs.ErrPermisson, got %[1]T (%[1]v)", err) } } + +func BenchmarkWalkDir(b *testing.B) { + osfs := NewBasePathFs(NewOsFs(), b.TempDir()) + + createSomeFiles := func(dirname string) { + for i := 0; i < 10; i++ { + f, err := osfs.Create(filepath.Join(dirname, fmt.Sprintf("test%d.txt", i))) + if err != nil { + b.Fatal(err) + } + f.Close() + } + } + + depth := 10 + for level := depth; level > 0; level-- { + dirname := "" + for i := 0; i < level; i++ { + dirname = filepath.Join(dirname, fmt.Sprintf("dir%d", i)) + err := osfs.MkdirAll(dirname, 0755) + if err != nil && !os.IsExist(err) { + b.Fatal(err) + } + } + createSomeFiles(dirname) + } + + iofs := NewIOFS(osfs) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := fs.WalkDir(iofs, "", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + return nil + + }) + + if err != nil { + b.Fatal(err) + } + } + +}