diff --git a/brotli_test.go b/brotli_test.go index 8f761bb..45b989e 100644 --- a/brotli_test.go +++ b/brotli_test.go @@ -472,6 +472,48 @@ func TestEncodeDecode(t *testing.T) { } } +func TestErrorReset(t *testing.T) { + compress := func(input []byte) []byte { + var buf bytes.Buffer + writer := new(Writer) + writer.Reset(&buf) + writer.Write(input) + writer.Close() + + return buf.Bytes() + } + + corruptReader := func(reader *Reader) { + buf := bytes.NewBuffer([]byte("trash")) + reader.Reset(buf) + _, err := io.ReadAll(reader) + if err == nil { + t.Fatalf("successively decompressed invalid input") + } + } + + decompress := func(input []byte, reader *Reader) []byte { + buf := bytes.NewBuffer(input) + reader.Reset(buf) + output, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("failed to decompress data %s", err.Error()) + } + + return output + } + + source := []byte("text") + + compressed := compress(source) + reader := new(Reader) + corruptReader(reader) + decompressed := decompress(compressed, reader) + if string(source) != string(decompressed) { + t.Fatalf("decompressed data does not match original state") + } +} + // Encode returns content encoded with Brotli. func Encode(content []byte, options WriterOptions) ([]byte, error) { var buf bytes.Buffer diff --git a/reader.go b/reader.go index b392a2f..9419c79 100644 --- a/reader.go +++ b/reader.go @@ -31,6 +31,12 @@ func NewReader(src io.Reader) *Reader { // This permits reusing a Reader rather than allocating a new one. // Error is always nil func (r *Reader) Reset(src io.Reader) error { + if r.error_code < 0 { + // There was an unrecoverable error, leaving the Reader's state + // undefined. Clear out everything but the buffer. + *r = Reader{buf: r.buf} + } + decoderStateInit(r) r.src = src if r.buf == nil {