From e54aac2c6a0e639c31cbfe95944a1e9c01c05606 Mon Sep 17 00:00:00 2001 From: Jamie Wilkinson Date: Sun, 22 Mar 2015 11:24:08 +1100 Subject: [PATCH] Apply locking in InMemoryFile --- README.md | 1 + fs_test.go | 21 +++++++++++++++++++++ memfile.go | 14 +++++++++++++- 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index e7d93b2..256dffe 100644 --- a/README.md +++ b/README.md @@ -166,6 +166,7 @@ Any Afero FileSystem can be used as an httpFs. Names in no particular order: * [spf13](https://github.com/spf13) +* [jaqx0r](https://github.com/jaqx0r) ## License diff --git a/fs_test.go b/fs_test.go index 4fe6afe..225100e 100644 --- a/fs_test.go +++ b/fs_test.go @@ -69,6 +69,27 @@ func TestRead0(t *testing.T) { } } +func TestMemFileRead(t *testing.T) { + f := MemFileCreate("testfile") + f.WriteString("abcd") + f.Seek(0, 0) + b := make([]byte, 8) + n, err := f.Read(b) + if n != 4 { + t.Errorf("didn't read all bytes: %v %v %v", n, err, b) + } + if err != nil { + t.Errorf("err is not nil: %v %v %v", n, err, b) + } + n, err = f.Read(b) + if n != 0 { + t.Errorf("read more bytes: %v %v %v", n, err, b) + } + if err != io.EOF { + t.Errorf("error is not EOF: %v %v %v", n, err, b) + } +} + func TestRename(t *testing.T) { for _, fs := range Fss { from, to := testDir+"/renamefrom", testDir+"/renameto" diff --git a/memfile.go b/memfile.go index 4815c6b..2e6751f 100644 --- a/memfile.go +++ b/memfile.go @@ -18,6 +18,7 @@ import ( "bytes" "io" "os" + "sync" "sync/atomic" ) @@ -32,6 +33,7 @@ type MemDir interface { } type InMemoryFile struct { + sync.Mutex at int64 name string data []byte @@ -48,13 +50,17 @@ func MemFileCreate(name string) *InMemoryFile { func (f *InMemoryFile) Open() error { atomic.StoreInt64(&f.at, 0) + f.Lock() f.closed = false + f.Unlock() return nil } func (f *InMemoryFile) Close() error { atomic.StoreInt64(&f.at, 0) + f.Lock() f.closed = true + f.Unlock() return nil } @@ -102,14 +108,18 @@ func (f *InMemoryFile) Readdirnames(n int) (names []string, err error) { } func (f *InMemoryFile) Read(b []byte) (n int, err error) { + f.Lock() + defer f.Unlock() if f.closed == true { return 0, ErrFileClosed } + if len(b) > 0 && int(f.at) == len(f.data) { + return 0, io.EOF + } if len(f.data)-int(f.at) >= len(b) { n = len(b) } else { n = len(f.data) - int(f.at) - err = io.EOF } copy(b, f.data[f.at:f.at+int64(n)]) atomic.AddInt64(&f.at, int64(n)) @@ -155,6 +165,8 @@ func (f *InMemoryFile) Seek(offset int64, whence int) (int64, error) { func (f *InMemoryFile) Write(b []byte) (n int, err error) { n = len(b) cur := atomic.LoadInt64(&f.at) + f.Lock() + defer f.Unlock() diff := cur - int64(len(f.data)) var tail []byte if n+int(cur) < len(f.data) {