From 9765fa1da7b09653a45bce52cef29fd755273dcd Mon Sep 17 00:00:00 2001 From: luhuaei Date: Mon, 11 Apr 2022 20:49:28 +0800 Subject: [PATCH] sftpfs: implement Symlinker and RemoveAll --- sftpfs/sftp.go | 51 ++++++++++++++++++++++++++++++++++++++--- sftpfs/sftp_test.go | 56 +++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 100 insertions(+), 7 deletions(-) diff --git a/sftpfs/sftp.go b/sftpfs/sftp.go index eadf1e0..c2eeb65 100644 --- a/sftpfs/sftp.go +++ b/sftpfs/sftp.go @@ -15,12 +15,15 @@ package sftpfs import ( "os" + "path/filepath" "time" "github.com/pkg/sftp" "github.com/spf13/afero" ) +var _ afero.Symlinker = (*Fs)(nil) + // Fs is a afero.Fs implementation that uses functions provided by the sftp package. // // For details in any method, check the documentation of the sftp package @@ -110,9 +113,34 @@ func (s Fs) Remove(name string) error { } func (s Fs) RemoveAll(path string) error { - // TODO have a look at os.RemoveAll - // https://github.com/golang/go/blob/master/src/os/path.go#L66 - return nil + if path == "" { + return nil + } + + info, _, err := s.LstatIfPossible(path) + if err != nil { + return err + } + if !info.IsDir() { + return s.client.Remove(path) + } + + files, err := s.client.ReadDir(path) + if err != nil { + return err + } + for _, file := range files { + fp := filepath.Join(path, file.Name()) + if file.IsDir() { + err = s.RemoveAll(fp) + } else { + err = s.client.Remove(fp) + } + if err != nil { + return err + } + } + return s.client.RemoveDirectory(path) } func (s Fs) Rename(oldname, newname string) error { @@ -138,3 +166,20 @@ func (s Fs) Chown(name string, uid, gid int) error { func (s Fs) Chtimes(name string, atime time.Time, mtime time.Time) error { return s.client.Chtimes(name, atime, mtime) } + +func (s Fs) LstatIfPossible(name string) (os.FileInfo, bool, error) { + fi, err := s.client.Lstat(name) + if err == nil { + return fi, true, err + } + fi, err = s.client.Stat(name) + return fi, false, err +} + +func (s Fs) SymlinkIfPossible(oldname, newname string) error { + return s.client.Symlink(oldname, newname) +} + +func (s Fs) ReadlinkIfPossible(name string) (string, error) { + return s.client.ReadLink(name) +} diff --git a/sftpfs/sftp_test.go b/sftpfs/sftp_test.go index 4dba7fc..576f35b 100644 --- a/sftpfs/sftp_test.go +++ b/sftpfs/sftp_test.go @@ -28,6 +28,7 @@ import ( "time" "github.com/pkg/sftp" + "github.com/spf13/afero" "golang.org/x/crypto/ssh" ) @@ -234,9 +235,10 @@ func MakeSSHKeyPair(bits int, pubKeyPath, privateKeyPath string) error { return ioutil.WriteFile(pubKeyPath, ssh.MarshalAuthorizedKey(pub), 0655) } -func TestSftpCreate(t *testing.T) { +func TestSftp(t *testing.T) { os.Mkdir("./test", 0777) MakeSSHKeyPair(1024, "./test/id_rsa.pub", "./test/id_rsa") + defer os.RemoveAll("./test") go RunSftpServer("./test/") time.Sleep(5 * time.Second) @@ -254,16 +256,16 @@ func TestSftpCreate(t *testing.T) { fs.Chmod("test/foo", os.FileMode(0700)) fs.Mkdir("test/bar", os.FileMode(0777)) - file, err := fs.Create("file1") + file, err := fs.Create("./test/file1") if err != nil { t.Error(err) } - defer file.Close() file.Write([]byte("hello ")) file.WriteString("world!\n") + file.Close() - f1, err := fs.Open("file1") + f1, err := fs.Open("./test/file1") if err != nil { log.Fatalf("open: %v", err) } @@ -274,6 +276,52 @@ func TestSftpCreate(t *testing.T) { _, _ = f1.Read(b) fmt.Println(string(b)) + fs.MkdirAll("test/testdir1/testdir2", os.FileMode(0755)) + linker, ok := fs.(afero.Symlinker) + if !ok { + t.Fatal("not implement symlinker") + } + + err = linker.SymlinkIfPossible("./test/file1", "test/testdir1/testdir2/file1") + if err != nil { + t.Fatal(err) + } + _, success, err := linker.LstatIfPossible("test/testdir1/testdir2/file1") + if !success { + t.Fatal("link stat failed") + } + if err != nil { + t.Fatal(err) + } + + linkPath, err := linker.ReadlinkIfPossible("test/testdir1/testdir2/file1") + if err != nil { + t.Fatal(err) + } + if linkPath != "./test/file1" { + t.Fatal("linkpath error") + } + + err = fs.RemoveAll("test/testdir1/testdir2/file1") + if err != nil { + t.Fatal(err) + } + + _, err = fs.Stat("test/testdir1/testdir2/file1") + if !os.IsNotExist(err) { + t.Fatal("remove all failed") + } + + err = fs.RemoveAll("test/testdir1") + if err != nil { + t.Fatal(err) + } + + _, err = fs.Stat("test/testdir1") + if !os.IsNotExist(err) { + t.Fatal("remove all failed") + } + fmt.Println("done") // TODO check here if "hello\tworld\n" is in buffer b }