diff --git a/util.go b/util.go index 84856c4..e757867 100644 --- a/util.go +++ b/util.go @@ -173,7 +173,7 @@ func (a Afero) FileContainsBytes(filename string, subslice []byte) (bool, error) return FileContainsBytes(a.Fs, filename, subslice) } -// Check if a file contains a specified string. +// Check if a file contains a specified byte slice. func FileContainsBytes(fs Fs, filename string, subslice []byte) (bool, error) { f, err := fs.Open(filename) if err != nil { @@ -181,17 +181,44 @@ func FileContainsBytes(fs Fs, filename string, subslice []byte) (bool, error) { } defer f.Close() - return readerContains(f, subslice), nil + return readerContainsAny(f, subslice), nil } -// readerContains reports whether subslice is within r. -func readerContains(r io.Reader, subslice []byte) bool { +func (a Afero) FileContainsAnyBytes(filename string, subslices [][]byte) (bool, error) { + return FileContainsAnyBytes(a.Fs, filename, subslices) +} - if r == nil || len(subslice) == 0 { +// Check if a file contains any of the specified byte slices. +func FileContainsAnyBytes(fs Fs, filename string, subslices [][]byte) (bool, error) { + f, err := fs.Open(filename) + if err != nil { + return false, err + } + defer f.Close() + + return readerContainsAny(f, subslices...), nil +} + +// readerContains reports whether any of the subslices is within r. +func readerContainsAny(r io.Reader, subslices ...[]byte) bool { + + if r == nil || len(subslices) == 0 { return false } - bufflen := len(subslice) * 4 + largestSlice := 0 + + for _, sl := range subslices { + if len(sl) > largestSlice { + largestSlice = len(sl) + } + } + + if largestSlice == 0 { + return false + } + + bufflen := largestSlice * 4 halflen := bufflen / 2 buff := make([]byte, bufflen) var err error @@ -209,8 +236,12 @@ func readerContains(r io.Reader, subslice []byte) bool { n, err = io.ReadAtLeast(r, buff[halflen:], halflen) } - if n > 0 && bytes.Contains(buff, subslice) { - return true + if n > 0 { + for _, sl := range subslices { + if bytes.Contains(buff, sl) { + return true + } + } } if err != nil { diff --git a/util_test.go b/util_test.go index d3dcdd1..50763df 100644 --- a/util_test.go +++ b/util_test.go @@ -142,28 +142,31 @@ func TestIsEmpty(t *testing.T) { func TestReaderContains(t *testing.T) { for i, this := range []struct { v1 string - v2 []byte + v2 [][]byte expect bool }{ - {"abc", []byte("a"), true}, - {"abc", []byte("b"), true}, - {"abcdefg", []byte("efg"), true}, - {"abc", []byte("d"), false}, + {"abc", [][]byte{[]byte("a")}, true}, + {"abc", [][]byte{[]byte("b")}, true}, + {"abcdefg", [][]byte{[]byte("efg")}, true}, + {"abc", [][]byte{[]byte("d")}, false}, + {"abc", [][]byte{[]byte("d"), []byte("e")}, false}, + {"abc", [][]byte{[]byte("d"), []byte("a")}, true}, + {"abc", [][]byte{[]byte("b"), []byte("e")}, true}, {"", nil, false}, - {"", []byte("a"), false}, - {"a", []byte(""), false}, - {"", []byte(""), false}} { - result := readerContains(strings.NewReader(this.v1), this.v2) + {"", [][]byte{[]byte("a")}, false}, + {"a", [][]byte{[]byte("")}, false}, + {"", [][]byte{[]byte("")}, false}} { + result := readerContainsAny(strings.NewReader(this.v1), this.v2...) if result != this.expect { t.Errorf("[%d] readerContains: got %t but expected %t", i, result, this.expect) } } - if readerContains(nil, []byte("a")) { + if readerContainsAny(nil, []byte("a")) { t.Error("readerContains with nil reader") } - if readerContains(nil, nil) { + if readerContainsAny(nil, nil) { t.Error("readerContains with nil arguments") } }