stream/mts/meta/meta_test.go: improved ReadFrom by checking for valid header

This commit is contained in:
saxon 2019-02-04 22:48:51 +10:30
parent 953d363b3a
commit ad25785b77
2 changed files with 13 additions and 4 deletions

View File

@ -28,6 +28,7 @@ LICENSE
package meta package meta
import ( import (
"encoding/binary"
"errors" "errors"
"strings" "strings"
"sync" "sync"
@ -49,7 +50,9 @@ const (
) )
var ( var (
errKeyAbsent = errors.New("Key does not exist in map") errKeyAbsent = errors.New("Key does not exist in map")
errNoHeader = errors.New("Metadata string does not contain header")
errInvalidHeader = errors.New("Metadata string does not contain valid header")
) )
// Metadata provides functionality for the storage and encoding of metadata // Metadata provides functionality for the storage and encoding of metadata
@ -143,6 +146,12 @@ func (m *Metadata) Encode() []byte {
// key is not present in the metadata string, an error is returned. If the // key is not present in the metadata string, an error is returned. If the
// metadata header is not present in the string, an error is returned. // metadata header is not present in the string, an error is returned.
func ReadFrom(d []byte, key string) (string, error) { func ReadFrom(d []byte, key string) (string, error) {
if d[0] != 0 {
return "", errNoHeader
} else if d[0] == 0 && binary.BigEndian.Uint16(d[2:headSize]) != uint16(len(d[headSize:])) {
return "", errInvalidHeader
}
d = d[headSize:]
entries := strings.Split(string(d), "\t") entries := strings.Split(string(d), "\t")
for _, entry := range entries { for _, entry := range entries {
kv := strings.Split(entry, "=") kv := strings.Split(entry, "=")

View File

@ -156,8 +156,8 @@ func TestEncode(t *testing.T) {
} }
func TestReadFrom(t *testing.T) { func TestReadFrom(t *testing.T) {
tstStr := "loc=a,b,c\tts=12345" tstMeta := append([]byte{0x00, 0x10, 0x00, 0x12}, "loc=a,b,c\tts=12345"...)
got, err := ReadFrom([]byte(tstStr), "loc") got, err := ReadFrom([]byte(tstMeta), "loc")
if err != nil { if err != nil {
t.Errorf(errUnexpectedErr, err.Error()) t.Errorf(errUnexpectedErr, err.Error())
} }
@ -166,7 +166,7 @@ func TestReadFrom(t *testing.T) {
t.Errorf(errNotExpectedOut, got, want) t.Errorf(errNotExpectedOut, got, want)
} }
if got, err = ReadFrom([]byte(tstStr), "ts"); err != nil { if got, err = ReadFrom([]byte(tstMeta), "ts"); err != nil {
t.Errorf(errUnexpectedErr, err.Error()) t.Errorf(errUnexpectedErr, err.Error())
} }
want = "12345" want = "12345"