diff --git a/brotli_test.go b/brotli_test.go
new file mode 100644
index 0000000..006fe6e
--- /dev/null
+++ b/brotli_test.go
@@ -0,0 +1,396 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Distributed under MIT license.
+// See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
+
+package brotli
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "math"
+ "math/rand"
+ "testing"
+ "time"
+)
+
+func checkCompressedData(compressedData, wantOriginalData []byte) error {
+ uncompressed, err := Decode(compressedData)
+ if err != nil {
+ return fmt.Errorf("brotli decompress failed: %v", err)
+ }
+ if !bytes.Equal(uncompressed, wantOriginalData) {
+ if len(wantOriginalData) != len(uncompressed) {
+ return fmt.Errorf(""+
+ "Data doesn't uncompress to the original value.\n"+
+ "Length of original: %v\n"+
+ "Length of uncompressed: %v",
+ len(wantOriginalData), len(uncompressed))
+ }
+ for i := range wantOriginalData {
+ if wantOriginalData[i] != uncompressed[i] {
+ return fmt.Errorf(""+
+ "Data doesn't uncompress to the original value.\n"+
+ "Original at %v is %v\n"+
+ "Uncompressed at %v is %v",
+ i, wantOriginalData[i], i, uncompressed[i])
+ }
+ }
+ }
+ return nil
+}
+
+func TestEncoderNoWrite(t *testing.T) {
+ out := bytes.Buffer{}
+ e := NewWriter(&out, WriterOptions{Quality: 5})
+ if err := e.Close(); err != nil {
+ t.Errorf("Close()=%v, want nil", err)
+ }
+ // Check Write after close.
+ if _, err := e.Write([]byte("hi")); err == nil {
+ t.Errorf("No error after Close() + Write()")
+ }
+}
+
+func TestEncoderEmptyWrite(t *testing.T) {
+ out := bytes.Buffer{}
+ e := NewWriter(&out, WriterOptions{Quality: 5})
+ n, err := e.Write([]byte(""))
+ if n != 0 || err != nil {
+ t.Errorf("Write()=%v,%v, want 0, nil", n, err)
+ }
+ if err := e.Close(); err != nil {
+ t.Errorf("Close()=%v, want nil", err)
+ }
+}
+
+func TestWriter(t *testing.T) {
+ // Test basic encoder usage.
+ input := []byte("
Hello world
")
+ out := bytes.Buffer{}
+ e := NewWriter(&out, WriterOptions{Quality: 1})
+ in := bytes.NewReader([]byte(input))
+ n, err := io.Copy(e, in)
+ if err != nil {
+ t.Errorf("Copy Error: %v", err)
+ }
+ if int(n) != len(input) {
+ t.Errorf("Copy() n=%v, want %v", n, len(input))
+ }
+ if err := e.Close(); err != nil {
+ t.Errorf("Close Error after copied %d bytes: %v", n, err)
+ }
+ if err := checkCompressedData(out.Bytes(), input); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestEncoderStreams(t *testing.T) {
+ // Test that output is streamed.
+ // Adjust window size to ensure the encoder outputs at least enough bytes
+ // to fill the window.
+ const lgWin = 16
+ windowSize := int(math.Pow(2, lgWin))
+ input := make([]byte, 8*windowSize)
+ rand.Read(input)
+ out := bytes.Buffer{}
+ e := NewWriter(&out, WriterOptions{Quality: 11, LGWin: lgWin})
+ halfInput := input[:len(input)/2]
+ in := bytes.NewReader(halfInput)
+
+ n, err := io.Copy(e, in)
+ if err != nil {
+ t.Errorf("Copy Error: %v", err)
+ }
+
+ // We've fed more data than the sliding window size. Check that some
+ // compressed data has been output.
+ if out.Len() == 0 {
+ t.Errorf("Output length is 0 after %d bytes written", n)
+ }
+ if err := e.Close(); err != nil {
+ t.Errorf("Close Error after copied %d bytes: %v", n, err)
+ }
+ if err := checkCompressedData(out.Bytes(), halfInput); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestEncoderLargeInput(t *testing.T) {
+ input := make([]byte, 1000000)
+ rand.Read(input)
+ out := bytes.Buffer{}
+ e := NewWriter(&out, WriterOptions{Quality: 5})
+ in := bytes.NewReader(input)
+
+ n, err := io.Copy(e, in)
+ if err != nil {
+ t.Errorf("Copy Error: %v", err)
+ }
+ if int(n) != len(input) {
+ t.Errorf("Copy() n=%v, want %v", n, len(input))
+ }
+ if err := e.Close(); err != nil {
+ t.Errorf("Close Error after copied %d bytes: %v", n, err)
+ }
+ if err := checkCompressedData(out.Bytes(), input); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestEncoderFlush(t *testing.T) {
+ input := make([]byte, 1000)
+ rand.Read(input)
+ out := bytes.Buffer{}
+ e := NewWriter(&out, WriterOptions{Quality: 5})
+ in := bytes.NewReader(input)
+ _, err := io.Copy(e, in)
+ if err != nil {
+ t.Fatalf("Copy Error: %v", err)
+ }
+ if err := e.Flush(); err != nil {
+ t.Fatalf("Flush(): %v", err)
+ }
+ if out.Len() == 0 {
+ t.Fatalf("0 bytes written after Flush()")
+ }
+ decompressed := make([]byte, 1000)
+ reader := NewReader(bytes.NewReader(out.Bytes()))
+ n, err := reader.Read(decompressed)
+ if n != len(decompressed) || err != nil {
+ t.Errorf("Expected <%v, nil>, but <%v, %v>", len(decompressed), n, err)
+ }
+ if !bytes.Equal(decompressed, input) {
+ t.Errorf(""+
+ "Decompress after flush: %v\n"+
+ "%q\n"+
+ "want:\n%q",
+ err, decompressed, input)
+ }
+ if err := e.Close(); err != nil {
+ t.Errorf("Close(): %v", err)
+ }
+}
+
+type readerWithTimeout struct {
+ io.Reader
+}
+
+func (r readerWithTimeout) Read(p []byte) (int, error) {
+ type result struct {
+ n int
+ err error
+ }
+ ch := make(chan result)
+ go func() {
+ n, err := r.Reader.Read(p)
+ ch <- result{n, err}
+ }()
+ select {
+ case result := <-ch:
+ return result.n, result.err
+ case <-time.After(5 * time.Second):
+ return 0, fmt.Errorf("read timed out")
+ }
+}
+
+func TestDecoderStreaming(t *testing.T) {
+ pr, pw := io.Pipe()
+ writer := NewWriter(pw, WriterOptions{Quality: 5, LGWin: 20})
+ reader := readerWithTimeout{NewReader(pr)}
+ defer func() {
+ go ioutil.ReadAll(pr) // swallow the "EOF" token from writer.Close
+ if err := writer.Close(); err != nil {
+ t.Errorf("writer.Close: %v", err)
+ }
+ }()
+
+ ch := make(chan []byte)
+ errch := make(chan error)
+ go func() {
+ for {
+ segment, ok := <-ch
+ if !ok {
+ return
+ }
+ if n, err := writer.Write(segment); err != nil || n != len(segment) {
+ errch <- fmt.Errorf("write=%v,%v, want %v,%v", n, err, len(segment), nil)
+ return
+ }
+ if err := writer.Flush(); err != nil {
+ errch <- fmt.Errorf("flush: %v", err)
+ return
+ }
+ }
+ }()
+ defer close(ch)
+
+ segments := [...][]byte{
+ []byte("first"),
+ []byte("second"),
+ []byte("third"),
+ }
+ for k, segment := range segments {
+ t.Run(fmt.Sprintf("Segment%d", k), func(t *testing.T) {
+ select {
+ case ch <- segment:
+ case err := <-errch:
+ t.Fatalf("write: %v", err)
+ case <-time.After(5 * time.Second):
+ t.Fatalf("timed out")
+ }
+ wantLen := len(segment)
+ got := make([]byte, wantLen)
+ if n, err := reader.Read(got); err != nil || n != wantLen || !bytes.Equal(got, segment) {
+ t.Fatalf("read[%d]=%q,%v,%v, want %q,%v,%v", k, got, n, err, segment, wantLen, nil)
+ }
+ })
+ }
+}
+
+func TestReader(t *testing.T) {
+ content := bytes.Repeat([]byte("hello world!"), 10000)
+ encoded, _ := Encode(content, WriterOptions{Quality: 5})
+ r := NewReader(bytes.NewReader(encoded))
+ var decodedOutput bytes.Buffer
+ n, err := io.Copy(&decodedOutput, r)
+ if err != nil {
+ t.Fatalf("Copy(): n=%v, err=%v", n, err)
+ }
+ if got := decodedOutput.Bytes(); !bytes.Equal(got, content) {
+ t.Errorf(""+
+ "Reader output:\n"+
+ "%q\n"+
+ "want:\n"+
+ "<%d bytes>",
+ got, len(content))
+ }
+}
+
+func TestDecode(t *testing.T) {
+ content := bytes.Repeat([]byte("hello world!"), 10000)
+ encoded, _ := Encode(content, WriterOptions{Quality: 5})
+ decoded, err := Decode(encoded)
+ if err != nil {
+ t.Errorf("Decode: %v", err)
+ }
+ if !bytes.Equal(decoded, content) {
+ t.Errorf(""+
+ "Decode content:\n"+
+ "%q\n"+
+ "want:\n"+
+ "<%d bytes>",
+ decoded, len(content))
+ }
+}
+
+func TestQuality(t *testing.T) {
+ content := bytes.Repeat([]byte("hello world!"), 10000)
+ for q := 0; q < 12; q++ {
+ encoded, _ := Encode(content, WriterOptions{Quality: q})
+ decoded, err := Decode(encoded)
+ if err != nil {
+ t.Errorf("Decode: %v", err)
+ }
+ if !bytes.Equal(decoded, content) {
+ t.Errorf(""+
+ "Decode content:\n"+
+ "%q\n"+
+ "want:\n"+
+ "<%d bytes>",
+ decoded, len(content))
+ }
+ }
+}
+
+func TestDecodeFuzz(t *testing.T) {
+ // Test that the decoder terminates with corrupted input.
+ content := bytes.Repeat([]byte("hello world!"), 100)
+ src := rand.NewSource(0)
+ encoded, err := Encode(content, WriterOptions{Quality: 5})
+ if err != nil {
+ t.Fatalf("Encode(<%d bytes>, _) = _, %s", len(content), err)
+ }
+ if len(encoded) == 0 {
+ t.Fatalf("Encode(<%d bytes>, _) produced empty output", len(content))
+ }
+ for i := 0; i < 100; i++ {
+ enc := append([]byte{}, encoded...)
+ for j := 0; j < 5; j++ {
+ enc[int(src.Int63())%len(enc)] = byte(src.Int63() % 256)
+ }
+ Decode(enc)
+ }
+}
+
+func TestDecodeTrailingData(t *testing.T) {
+ content := bytes.Repeat([]byte("hello world!"), 100)
+ encoded, _ := Encode(content, WriterOptions{Quality: 5})
+ _, err := Decode(append(encoded, 0))
+ if err == nil {
+ t.Errorf("Expected 'excessive input' error")
+ }
+}
+
+func TestEncodeDecode(t *testing.T) {
+ for _, test := range []struct {
+ data []byte
+ repeats int
+ }{
+ {nil, 0},
+ {[]byte("A"), 1},
+ {[]byte("Hello world
"), 10},
+ {[]byte("Hello world
"), 1000},
+ } {
+ t.Logf("case %q x %d", test.data, test.repeats)
+ input := bytes.Repeat(test.data, test.repeats)
+ encoded, err := Encode(input, WriterOptions{Quality: 5})
+ if err != nil {
+ t.Errorf("Encode: %v", err)
+ }
+ // Inputs are compressible, but may be too small to compress.
+ if maxSize := len(input)/2 + 20; len(encoded) >= maxSize {
+ t.Errorf(""+
+ "Encode returned %d bytes, want <%d\n"+
+ "Encoded=%q",
+ len(encoded), maxSize, encoded)
+ }
+ decoded, err := Decode(encoded)
+ if err != nil {
+ t.Errorf("Decode: %v", err)
+ }
+ if !bytes.Equal(decoded, input) {
+ var want string
+ if len(input) > 320 {
+ want = fmt.Sprintf("<%d bytes>", len(input))
+ } else {
+ want = fmt.Sprintf("%q", input)
+ }
+ t.Errorf(""+
+ "Decode content:\n"+
+ "%q\n"+
+ "want:\n"+
+ "%s",
+ decoded, want)
+ }
+ }
+}
+
+// Encode returns content encoded with Brotli.
+func Encode(content []byte, options WriterOptions) ([]byte, error) {
+ var buf bytes.Buffer
+ writer := NewWriter(&buf, options)
+ _, err := writer.Write(content)
+ if closeErr := writer.Close(); err == nil {
+ err = closeErr
+ }
+ return buf.Bytes(), err
+}
+
+// Decode decodes Brotli encoded data.
+func Decode(encodedData []byte) ([]byte, error) {
+ r := NewReader(bytes.NewReader(encodedData))
+ return ioutil.ReadAll(r)
+}
diff --git a/decode.go b/decode.go
index 38b7153..3ef969d 100644
--- a/decode.go
+++ b/decode.go
@@ -114,7 +114,7 @@ var kCodeLengthPrefixLength = [16]byte{2, 2, 2, 3, 2, 2, 2, 4, 2, 2, 2, 3, 2, 2,
var kCodeLengthPrefixValue = [16]byte{0, 4, 3, 2, 0, 4, 3, 1, 0, 4, 3, 2, 0, 4, 3, 5}
-func BrotliDecoderSetParameter(state *BrotliDecoderState, p int, value uint32) bool {
+func BrotliDecoderSetParameter(state *Reader, p int, value uint32) bool {
if state.state != BROTLI_STATE_UNINITED {
return false
}
@@ -136,9 +136,9 @@ func BrotliDecoderSetParameter(state *BrotliDecoderState, p int, value uint32) b
}
}
-func BrotliDecoderCreateInstance() *BrotliDecoderState {
- var state *BrotliDecoderState
- state = new(BrotliDecoderState)
+func BrotliDecoderCreateInstance() *Reader {
+ var state *Reader
+ state = new(Reader)
if state == nil {
return nil
}
@@ -151,7 +151,7 @@ func BrotliDecoderCreateInstance() *BrotliDecoderState {
}
/* Deinitializes and frees BrotliDecoderState instance. */
-func BrotliDecoderDestroyInstance(state *BrotliDecoderState) {
+func BrotliDecoderDestroyInstance(state *Reader) {
if state == nil {
return
} else {
@@ -160,7 +160,7 @@ func BrotliDecoderDestroyInstance(state *BrotliDecoderState) {
}
/* Saves error code and converts it to BrotliDecoderResult. */
-func SaveErrorCode(s *BrotliDecoderState, e int) int {
+func SaveErrorCode(s *Reader, e int) int {
s.error_code = int(e)
switch e {
case BROTLI_DECODER_SUCCESS:
@@ -179,7 +179,7 @@ func SaveErrorCode(s *BrotliDecoderState, e int) int {
/* Decodes WBITS by reading 1 - 7 bits, or 0x11 for "Large Window Brotli".
Precondition: bit-reader accumulator has at least 8 bits. */
-func DecodeWindowBits(s *BrotliDecoderState, br *BrotliBitReader) int {
+func DecodeWindowBits(s *Reader, br *BrotliBitReader) int {
var n uint32
var large_window bool = s.large_window
s.large_window = false
@@ -220,7 +220,7 @@ func DecodeWindowBits(s *BrotliDecoderState, br *BrotliBitReader) int {
}
/* Decodes a number in the range [0..255], by reading 1 - 11 bits. */
-func DecodeVarLenUint8(s *BrotliDecoderState, br *BrotliBitReader, value *uint32) int {
+func DecodeVarLenUint8(s *Reader, br *BrotliBitReader, value *uint32) int {
var bits uint32
switch s.substate_decode_uint8 {
case BROTLI_STATE_DECODE_UINT8_NONE:
@@ -268,7 +268,7 @@ func DecodeVarLenUint8(s *BrotliDecoderState, br *BrotliBitReader, value *uint32
}
/* Decodes a metablock length and flags by reading 2 - 31 bits. */
-func DecodeMetaBlockLength(s *BrotliDecoderState, br *BrotliBitReader) int {
+func DecodeMetaBlockLength(s *Reader, br *BrotliBitReader) int {
var bits uint32
var i int
for {
@@ -538,7 +538,7 @@ func Log2Floor(x uint32) uint32 {
/* Reads (s->symbol + 1) symbols.
Totally 1..4 symbols are read, 1..11 bits each.
The list of symbols MUST NOT contain duplicates. */
-func ReadSimpleHuffmanSymbols(alphabet_size uint32, max_symbol uint32, s *BrotliDecoderState) int {
+func ReadSimpleHuffmanSymbols(alphabet_size uint32, max_symbol uint32, s *Reader) int {
var br *BrotliBitReader = &s.br
var max_bits uint32 = Log2Floor(alphabet_size - 1)
var i uint32 = s.sub_loop_counter
@@ -651,7 +651,7 @@ func ProcessRepeatedCodeLength(code_len uint32, repeat_delta uint32, alphabet_si
}
/* Reads and decodes symbol codelengths. */
-func ReadSymbolCodeLengths(alphabet_size uint32, s *BrotliDecoderState) int {
+func ReadSymbolCodeLengths(alphabet_size uint32, s *Reader) int {
var br *BrotliBitReader = &s.br
var symbol uint32 = s.symbol
var repeat uint32 = s.repeat
@@ -700,7 +700,7 @@ func ReadSymbolCodeLengths(alphabet_size uint32, s *BrotliDecoderState) int {
return BROTLI_DECODER_SUCCESS
}
-func SafeReadSymbolCodeLengths(alphabet_size uint32, s *BrotliDecoderState) int {
+func SafeReadSymbolCodeLengths(alphabet_size uint32, s *Reader) int {
var br *BrotliBitReader = &s.br
var get_byte bool = false
var p []HuffmanCode
@@ -746,7 +746,7 @@ func SafeReadSymbolCodeLengths(alphabet_size uint32, s *BrotliDecoderState) int
/* Reads and decodes 15..18 codes using static prefix code.
Each code is 2..4 bits long. In total 30..72 bits are used. */
-func ReadCodeLengthCodeLengths(s *BrotliDecoderState) int {
+func ReadCodeLengthCodeLengths(s *Reader) int {
var br *BrotliBitReader = &s.br
var num_codes uint32 = s.repeat
var space uint32 = s.space
@@ -804,7 +804,7 @@ func ReadCodeLengthCodeLengths(s *BrotliDecoderState) int {
encoded with predefined entropy code. 32 - 74 bits are used.
B.2) Decoded table is used to decode code lengths of symbols in resulting
Huffman table. In worst case 3520 bits are read. */
-func ReadHuffmanCode(alphabet_size uint32, max_symbol uint32, table []HuffmanCode, opt_table_size *uint32, s *BrotliDecoderState) int {
+func ReadHuffmanCode(alphabet_size uint32, max_symbol uint32, table []HuffmanCode, opt_table_size *uint32, s *Reader) int {
var br *BrotliBitReader = &s.br
/* Unnecessary masking, but might be good for safety. */
@@ -954,7 +954,7 @@ func ReadBlockLength(table []HuffmanCode, br *BrotliBitReader) uint32 {
/* WARNING: if state is not BROTLI_STATE_READ_BLOCK_LENGTH_NONE, then
reading can't be continued with ReadBlockLength. */
-func SafeReadBlockLength(s *BrotliDecoderState, result *uint32, table []HuffmanCode, br *BrotliBitReader) bool {
+func SafeReadBlockLength(s *Reader, result *uint32, table []HuffmanCode, br *BrotliBitReader) bool {
var index uint32
if s.substate_read_block_length == BROTLI_STATE_READ_BLOCK_LENGTH_NONE {
if !SafeReadSymbol(table, br, &index) {
@@ -992,7 +992,7 @@ func SafeReadBlockLength(s *BrotliDecoderState, result *uint32, table []HuffmanC
Most of input values are 0 and 1. To reduce number of branches, we replace
inner for loop with do-while. */
-func InverseMoveToFrontTransform(v []byte, v_len uint32, state *BrotliDecoderState) {
+func InverseMoveToFrontTransform(v []byte, v_len uint32, state *Reader) {
var mtf [256]byte
var i int
for i = 1; i < 256; i++ {
@@ -1016,7 +1016,7 @@ func InverseMoveToFrontTransform(v []byte, v_len uint32, state *BrotliDecoderSta
}
/* Decodes a series of Huffman table using ReadHuffmanCode function. */
-func HuffmanTreeGroupDecode(group *HuffmanTreeGroup, s *BrotliDecoderState) int {
+func HuffmanTreeGroupDecode(group *HuffmanTreeGroup, s *Reader) int {
if s.substate_tree_group != BROTLI_STATE_TREE_GROUP_LOOP {
s.next = group.codes
s.htree_index = 0
@@ -1046,7 +1046,7 @@ func HuffmanTreeGroupDecode(group *HuffmanTreeGroup, s *BrotliDecoderState) int
This table will be used for reading context map items.
3) Read context map items; "0" values could be run-length encoded.
4) Optionally, apply InverseMoveToFront transform to the resulting map. */
-func DecodeContextMap(context_map_size uint32, num_htrees *uint32, context_map_arg *[]byte, s *BrotliDecoderState) int {
+func DecodeContextMap(context_map_size uint32, num_htrees *uint32, context_map_arg *[]byte, s *Reader) int {
var br *BrotliBitReader = &s.br
var result int = BROTLI_DECODER_SUCCESS
@@ -1192,7 +1192,7 @@ func DecodeContextMap(context_map_size uint32, num_htrees *uint32, context_map_a
/* Decodes a command or literal and updates block type ring-buffer.
Reads 3..54 bits. */
-func DecodeBlockTypeAndLength(safe int, s *BrotliDecoderState, tree_type int) bool {
+func DecodeBlockTypeAndLength(safe int, s *Reader, tree_type int) bool {
var max_block_type uint32 = s.num_block_types[tree_type]
var type_tree []HuffmanCode
type_tree = s.block_type_trees[tree_type*BROTLI_HUFFMAN_MAX_SIZE_258:]
@@ -1239,7 +1239,7 @@ func DecodeBlockTypeAndLength(safe int, s *BrotliDecoderState, tree_type int) bo
return true
}
-func DetectTrivialLiteralBlockTypes(s *BrotliDecoderState) {
+func DetectTrivialLiteralBlockTypes(s *Reader) {
var i uint
for i = 0; i < 8; i++ {
s.trivial_literal_contexts[i] = 0
@@ -1263,7 +1263,7 @@ func DetectTrivialLiteralBlockTypes(s *BrotliDecoderState) {
}
}
-func PrepareLiteralDecoding(s *BrotliDecoderState) {
+func PrepareLiteralDecoding(s *Reader) {
var context_mode byte
var trivial uint
var block_type uint32 = s.block_type_rb[1]
@@ -1278,7 +1278,7 @@ func PrepareLiteralDecoding(s *BrotliDecoderState) {
/* Decodes the block type and updates the state for literal context.
Reads 3..54 bits. */
-func DecodeLiteralBlockSwitchInternal(safe int, s *BrotliDecoderState) bool {
+func DecodeLiteralBlockSwitchInternal(safe int, s *Reader) bool {
if !DecodeBlockTypeAndLength(safe, s, 0) {
return false
}
@@ -1287,17 +1287,17 @@ func DecodeLiteralBlockSwitchInternal(safe int, s *BrotliDecoderState) bool {
return true
}
-func DecodeLiteralBlockSwitch(s *BrotliDecoderState) {
+func DecodeLiteralBlockSwitch(s *Reader) {
DecodeLiteralBlockSwitchInternal(0, s)
}
-func SafeDecodeLiteralBlockSwitch(s *BrotliDecoderState) bool {
+func SafeDecodeLiteralBlockSwitch(s *Reader) bool {
return DecodeLiteralBlockSwitchInternal(1, s)
}
/* Block switch for insert/copy length.
Reads 3..54 bits. */
-func DecodeCommandBlockSwitchInternal(safe int, s *BrotliDecoderState) bool {
+func DecodeCommandBlockSwitchInternal(safe int, s *Reader) bool {
if !DecodeBlockTypeAndLength(safe, s, 1) {
return false
}
@@ -1306,17 +1306,17 @@ func DecodeCommandBlockSwitchInternal(safe int, s *BrotliDecoderState) bool {
return true
}
-func DecodeCommandBlockSwitch(s *BrotliDecoderState) {
+func DecodeCommandBlockSwitch(s *Reader) {
DecodeCommandBlockSwitchInternal(0, s)
}
-func SafeDecodeCommandBlockSwitch(s *BrotliDecoderState) bool {
+func SafeDecodeCommandBlockSwitch(s *Reader) bool {
return DecodeCommandBlockSwitchInternal(1, s)
}
/* Block switch for distance codes.
Reads 3..54 bits. */
-func DecodeDistanceBlockSwitchInternal(safe int, s *BrotliDecoderState) bool {
+func DecodeDistanceBlockSwitchInternal(safe int, s *Reader) bool {
if !DecodeBlockTypeAndLength(safe, s, 2) {
return false
}
@@ -1326,15 +1326,15 @@ func DecodeDistanceBlockSwitchInternal(safe int, s *BrotliDecoderState) bool {
return true
}
-func DecodeDistanceBlockSwitch(s *BrotliDecoderState) {
+func DecodeDistanceBlockSwitch(s *Reader) {
DecodeDistanceBlockSwitchInternal(0, s)
}
-func SafeDecodeDistanceBlockSwitch(s *BrotliDecoderState) bool {
+func SafeDecodeDistanceBlockSwitch(s *Reader) bool {
return DecodeDistanceBlockSwitchInternal(1, s)
}
-func UnwrittenBytes(s *BrotliDecoderState, wrap bool) uint {
+func UnwrittenBytes(s *Reader, wrap bool) uint {
var pos uint
if wrap && s.pos > s.ringbuffer_size {
pos = uint(s.ringbuffer_size)
@@ -1348,7 +1348,7 @@ func UnwrittenBytes(s *BrotliDecoderState, wrap bool) uint {
/* Dumps output.
Returns BROTLI_DECODER_NEEDS_MORE_OUTPUT only if there is more output to push
and either ring-buffer is as big as window size, or |force| is true. */
-func WriteRingBuffer(s *BrotliDecoderState, available_out *uint, next_out *[]byte, total_out *uint, force bool) int {
+func WriteRingBuffer(s *Reader, available_out *uint, next_out *[]byte, total_out *uint, force bool) int {
var start []byte
start = s.ringbuffer[s.partial_pos_out&uint(s.ringbuffer_mask):]
var to_write uint = UnwrittenBytes(s, true)
@@ -1398,7 +1398,7 @@ func WriteRingBuffer(s *BrotliDecoderState, available_out *uint, next_out *[]byt
return BROTLI_DECODER_SUCCESS
}
-func WrapRingBuffer(s *BrotliDecoderState) {
+func WrapRingBuffer(s *Reader) {
if s.should_wrap_ringbuffer != 0 {
copy(s.ringbuffer, s.ringbuffer_end[:uint(s.pos)])
s.should_wrap_ringbuffer = 0
@@ -1412,7 +1412,7 @@ func WrapRingBuffer(s *BrotliDecoderState) {
Last two bytes of ring-buffer are initialized to 0, so context calculation
could be done uniformly for the first two and all other positions. */
-func BrotliEnsureRingBuffer(s *BrotliDecoderState) bool {
+func BrotliEnsureRingBuffer(s *Reader) bool {
var old_ringbuffer []byte = s.ringbuffer
if s.ringbuffer_size == s.new_ringbuffer_size {
return true
@@ -1442,7 +1442,7 @@ func BrotliEnsureRingBuffer(s *BrotliDecoderState) bool {
return true
}
-func CopyUncompressedBlockToOutput(available_out *uint, next_out *[]byte, total_out *uint, s *BrotliDecoderState) int {
+func CopyUncompressedBlockToOutput(available_out *uint, next_out *[]byte, total_out *uint, s *Reader) int {
/* TODO: avoid allocation for single uncompressed block. */
if !BrotliEnsureRingBuffer(s) {
return BROTLI_DECODER_ERROR_ALLOC_RING_BUFFER_1
@@ -1508,7 +1508,7 @@ func CopyUncompressedBlockToOutput(available_out *uint, next_out *[]byte, total_
size than needed to reduce memory usage.
When this method is called, metablock size and flags MUST be decoded. */
-func BrotliCalculateRingBufferSize(s *BrotliDecoderState) {
+func BrotliCalculateRingBufferSize(s *Reader) {
var window_size int = 1 << s.window_bits
var new_ringbuffer_size int = window_size
var min_size int
@@ -1557,7 +1557,7 @@ func BrotliCalculateRingBufferSize(s *BrotliDecoderState) {
}
/* Reads 1..256 2-bit context modes. */
-func ReadContextModes(s *BrotliDecoderState) int {
+func ReadContextModes(s *Reader) int {
var br *BrotliBitReader = &s.br
var i int = s.loop_counter
@@ -1575,7 +1575,7 @@ func ReadContextModes(s *BrotliDecoderState) int {
return BROTLI_DECODER_SUCCESS
}
-func TakeDistanceFromRingBuffer(s *BrotliDecoderState) {
+func TakeDistanceFromRingBuffer(s *Reader) {
if s.distance_code == 0 {
s.dist_rb_idx--
s.distance_code = s.dist_rb[s.dist_rb_idx&3]
@@ -1618,7 +1618,7 @@ func SafeReadBits(br *BrotliBitReader, n_bits uint32, val *uint32) bool {
}
/* Precondition: s->distance_code < 0. */
-func ReadDistanceInternal(safe int, s *BrotliDecoderState, br *BrotliBitReader) bool {
+func ReadDistanceInternal(safe int, s *Reader, br *BrotliBitReader) bool {
var distval int
var memento BrotliBitReaderState
var distance_tree []HuffmanCode = []HuffmanCode(s.distance_hgroup.htrees[s.dist_htree_index])
@@ -1679,15 +1679,15 @@ func ReadDistanceInternal(safe int, s *BrotliDecoderState, br *BrotliBitReader)
return true
}
-func ReadDistance(s *BrotliDecoderState, br *BrotliBitReader) {
+func ReadDistance(s *Reader, br *BrotliBitReader) {
ReadDistanceInternal(0, s, br)
}
-func SafeReadDistance(s *BrotliDecoderState, br *BrotliBitReader) bool {
+func SafeReadDistance(s *Reader, br *BrotliBitReader) bool {
return ReadDistanceInternal(1, s, br)
}
-func ReadCommandInternal(safe int, s *BrotliDecoderState, br *BrotliBitReader, insert_length *int) bool {
+func ReadCommandInternal(safe int, s *Reader, br *BrotliBitReader, insert_length *int) bool {
var cmd_code uint32
var insert_len_extra uint32 = 0
var copy_length uint32
@@ -1726,11 +1726,11 @@ func ReadCommandInternal(safe int, s *BrotliDecoderState, br *BrotliBitReader, i
return true
}
-func ReadCommand(s *BrotliDecoderState, br *BrotliBitReader, insert_length *int) {
+func ReadCommand(s *Reader, br *BrotliBitReader, insert_length *int) {
ReadCommandInternal(0, s, br, insert_length)
}
-func SafeReadCommand(s *BrotliDecoderState, br *BrotliBitReader, insert_length *int) bool {
+func SafeReadCommand(s *Reader, br *BrotliBitReader, insert_length *int) bool {
return ReadCommandInternal(1, s, br, insert_length)
}
@@ -1742,7 +1742,7 @@ func CheckInputAmount(safe int, br *BrotliBitReader, num uint) bool {
return BrotliCheckInputAmount(br, num)
}
-func ProcessCommandsInternal(safe int, s *BrotliDecoderState) int {
+func ProcessCommandsInternal(safe int, s *Reader) int {
var pos int = s.pos
var i int = s.loop_counter
var result int = BROTLI_DECODER_SUCCESS
@@ -2110,11 +2110,11 @@ saveStateAndReturn:
return result
}
-func ProcessCommands(s *BrotliDecoderState) int {
+func ProcessCommands(s *Reader) int {
return ProcessCommandsInternal(0, s)
}
-func SafeProcessCommands(s *BrotliDecoderState) int {
+func SafeProcessCommands(s *Reader) int {
return ProcessCommandsInternal(1, s)
}
@@ -2136,7 +2136,7 @@ func BrotliMaxDistanceSymbol(ndirect uint32, npostfix uint32) uint32 {
}
func BrotliDecoderDecompress(encoded_size uint, encoded_buffer []byte, decoded_size *uint, decoded_buffer []byte) int {
- var s BrotliDecoderState
+ var s Reader
var result int
var total_out uint = 0
var available_in uint = encoded_size
@@ -2168,7 +2168,7 @@ func BrotliDecoderDecompress(encoded_size uint, encoded_buffer []byte, decoded_s
buffer ahead of time
- when result is "success" decoder MUST return all unused data back to input
buffer; this is possible because the invariant is held on enter */
-func BrotliDecoderDecompressStream(s *BrotliDecoderState, available_in *uint, next_in *[]byte, available_out *uint, next_out *[]byte, total_out *uint) int {
+func BrotliDecoderDecompressStream(s *Reader, available_in *uint, next_in *[]byte, available_out *uint, next_out *[]byte, total_out *uint) int {
var result int = BROTLI_DECODER_SUCCESS
var br *BrotliBitReader = &s.br
@@ -2687,7 +2687,7 @@ func BrotliDecoderDecompressStream(s *BrotliDecoderState, available_in *uint, ne
return SaveErrorCode(s, result)
}
-func BrotliDecoderHasMoreOutput(s *BrotliDecoderState) bool {
+func BrotliDecoderHasMoreOutput(s *Reader) bool {
/* After unrecoverable error remaining output is considered nonsensical. */
if int(s.error_code) < 0 {
return false
@@ -2696,7 +2696,7 @@ func BrotliDecoderHasMoreOutput(s *BrotliDecoderState) bool {
return s.ringbuffer != nil && UnwrittenBytes(s, false) != 0
}
-func BrotliDecoderTakeOutput(s *BrotliDecoderState, size *uint) []byte {
+func BrotliDecoderTakeOutput(s *Reader, size *uint) []byte {
var result []byte = nil
var available_out uint
if *size != 0 {
@@ -2730,15 +2730,15 @@ func BrotliDecoderTakeOutput(s *BrotliDecoderState, size *uint) []byte {
return result
}
-func BrotliDecoderIsUsed(s *BrotliDecoderState) bool {
+func BrotliDecoderIsUsed(s *Reader) bool {
return s.state != BROTLI_STATE_UNINITED || BrotliGetAvailableBits(&s.br) != 0
}
-func BrotliDecoderIsFinished(s *BrotliDecoderState) bool {
+func BrotliDecoderIsFinished(s *Reader) bool {
return (s.state == BROTLI_STATE_DONE) && !BrotliDecoderHasMoreOutput(s)
}
-func BrotliDecoderGetErrorCode(s *BrotliDecoderState) int {
+func BrotliDecoderGetErrorCode(s *Reader) int {
return int(s.error_code)
}
diff --git a/reader.go b/reader.go
new file mode 100644
index 0000000..d97cde6
--- /dev/null
+++ b/reader.go
@@ -0,0 +1,94 @@
+package brotli
+
+import (
+ "errors"
+ "io"
+)
+
+type decodeError int
+
+func (err decodeError) Error() string {
+ return "brotli: " + string(BrotliDecoderErrorString(int(err)))
+}
+
+var errExcessiveInput = errors.New("brotli: excessive input")
+var errInvalidState = errors.New("brotli: invalid state")
+var errReaderClosed = errors.New("brotli: Reader is closed")
+
+// readBufSize is a "good" buffer size that avoids excessive round-trips
+// between C and Go but doesn't waste too much memory on buffering.
+// It is arbitrarily chosen to be equal to the constant used in io.Copy.
+const readBufSize = 32 * 1024
+
+// NewReader initializes new Reader instance.
+func NewReader(src io.Reader) *Reader {
+ r := new(Reader)
+ BrotliDecoderStateInit(r)
+ r.src = src
+ r.buf = make([]byte, readBufSize)
+ return r
+}
+
+func (r *Reader) Read(p []byte) (n int, err error) {
+ if !BrotliDecoderHasMoreOutput(r) && len(r.in) == 0 {
+ m, readErr := r.src.Read(r.buf)
+ if m == 0 {
+ // If readErr is `nil`, we just proxy underlying stream behavior.
+ return 0, readErr
+ }
+ r.in = r.buf[:m]
+ }
+
+ if len(p) == 0 {
+ return 0, nil
+ }
+
+ for {
+ var written uint
+ in_len := uint(len(r.in))
+ out_len := uint(len(p))
+ in_remaining := in_len
+ out_remaining := out_len
+ result := BrotliDecoderDecompressStream(r, &in_remaining, &r.in, &out_remaining, &p, nil)
+ written = out_len - out_remaining
+ n = int(written)
+
+ switch result {
+ case BROTLI_DECODER_RESULT_SUCCESS:
+ if len(r.in) > 0 {
+ return n, errExcessiveInput
+ }
+ return n, nil
+ case BROTLI_DECODER_RESULT_ERROR:
+ return n, decodeError(BrotliDecoderGetErrorCode(r))
+ case BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT:
+ if n == 0 {
+ return 0, io.ErrShortBuffer
+ }
+ return n, nil
+ case BROTLI_DECODER_NEEDS_MORE_INPUT:
+ }
+
+ if len(r.in) != 0 {
+ return 0, errInvalidState
+ }
+
+ // Calling r.src.Read may block. Don't block if we have data to return.
+ if n > 0 {
+ return n, nil
+ }
+
+ // Top off the buffer.
+ encN, err := r.src.Read(r.buf)
+ if encN == 0 {
+ // Not enough data to complete decoding.
+ if err == io.EOF {
+ return 0, io.ErrUnexpectedEOF
+ }
+ return 0, err
+ }
+ r.in = r.buf[:encN]
+ }
+
+ return n, nil
+}
diff --git a/state.go b/state.go
index 9fbe66e..4b06601 100644
--- a/state.go
+++ b/state.go
@@ -1,5 +1,7 @@
package brotli
+import "io"
+
/* Copyright 2015 Google Inc. All Rights Reserved.
Distributed under MIT license.
@@ -89,7 +91,11 @@ const (
BROTLI_STATE_READ_BLOCK_LENGTH_SUFFIX
)
-type BrotliDecoderState struct {
+type Reader struct {
+ src io.Reader
+ buf []byte // scratch space for reading from src
+ in []byte // current chunk to decode; usually aliases buf
+
state int
loop_counter int
br BrotliBitReader
@@ -177,7 +183,7 @@ type BrotliDecoderState struct {
trivial_literal_contexts [8]uint32
}
-func BrotliDecoderStateInit(s *BrotliDecoderState) bool {
+func BrotliDecoderStateInit(s *Reader) bool {
s.error_code = 0 /* BROTLI_DECODER_NO_ERROR */
BrotliInitBitReader(&s.br)
@@ -244,7 +250,7 @@ func BrotliDecoderStateInit(s *BrotliDecoderState) bool {
return true
}
-func BrotliDecoderStateMetablockBegin(s *BrotliDecoderState) {
+func BrotliDecoderStateMetablockBegin(s *Reader) {
s.meta_block_remaining_len = 0
s.block_length[0] = 1 << 24
s.block_length[1] = 1 << 24
@@ -274,7 +280,7 @@ func BrotliDecoderStateMetablockBegin(s *BrotliDecoderState) {
s.distance_hgroup.htrees = nil
}
-func BrotliDecoderStateCleanupAfterMetablock(s *BrotliDecoderState) {
+func BrotliDecoderStateCleanupAfterMetablock(s *Reader) {
s.context_modes = nil
s.context_map = nil
s.dist_context_map = nil
@@ -283,14 +289,14 @@ func BrotliDecoderStateCleanupAfterMetablock(s *BrotliDecoderState) {
s.distance_hgroup.htrees = nil
}
-func BrotliDecoderStateCleanup(s *BrotliDecoderState) {
+func BrotliDecoderStateCleanup(s *Reader) {
BrotliDecoderStateCleanupAfterMetablock(s)
s.ringbuffer = nil
s.block_type_trees = nil
}
-func BrotliDecoderHuffmanTreeGroupInit(s *BrotliDecoderState, group *HuffmanTreeGroup, alphabet_size uint32, max_symbol uint32, ntrees uint32) bool {
+func BrotliDecoderHuffmanTreeGroupInit(s *Reader, group *HuffmanTreeGroup, alphabet_size uint32, max_symbol uint32, ntrees uint32) bool {
var max_table_size uint = uint(kMaxHuffmanTableSize[(alphabet_size+31)>>5])
group.alphabet_size = uint16(alphabet_size)
group.max_symbol = uint16(max_symbol)