From ad25785b77f6b06a83bbe34ad8d8fa60b3c9aaf5 Mon Sep 17 00:00:00 2001 From: saxon Date: Mon, 4 Feb 2019 22:48:51 +1030 Subject: [PATCH] stream/mts/meta/meta_test.go: improved ReadFrom by checking for valid header --- stream/mts/meta/meta.go | 11 ++++++++++- stream/mts/meta/meta_test.go | 6 +++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/stream/mts/meta/meta.go b/stream/mts/meta/meta.go index 1e4c72b8..7a863d0c 100644 --- a/stream/mts/meta/meta.go +++ b/stream/mts/meta/meta.go @@ -28,6 +28,7 @@ LICENSE package meta import ( + "encoding/binary" "errors" "strings" "sync" @@ -49,7 +50,9 @@ const ( ) var ( - errKeyAbsent = errors.New("Key does not exist in map") + errKeyAbsent = errors.New("Key does not exist in map") + errNoHeader = errors.New("Metadata string does not contain header") + errInvalidHeader = errors.New("Metadata string does not contain valid header") ) // Metadata provides functionality for the storage and encoding of metadata @@ -143,6 +146,12 @@ func (m *Metadata) Encode() []byte { // key is not present in the metadata string, an error is returned. If the // metadata header is not present in the string, an error is returned. func ReadFrom(d []byte, key string) (string, error) { + if d[0] != 0 { + return "", errNoHeader + } else if d[0] == 0 && binary.BigEndian.Uint16(d[2:headSize]) != uint16(len(d[headSize:])) { + return "", errInvalidHeader + } + d = d[headSize:] entries := strings.Split(string(d), "\t") for _, entry := range entries { kv := strings.Split(entry, "=") diff --git a/stream/mts/meta/meta_test.go b/stream/mts/meta/meta_test.go index 0fcda13b..15789ebf 100644 --- a/stream/mts/meta/meta_test.go +++ b/stream/mts/meta/meta_test.go @@ -156,8 +156,8 @@ func TestEncode(t *testing.T) { } func TestReadFrom(t *testing.T) { - tstStr := "loc=a,b,c\tts=12345" - got, err := ReadFrom([]byte(tstStr), "loc") + tstMeta := append([]byte{0x00, 0x10, 0x00, 0x12}, "loc=a,b,c\tts=12345"...) + got, err := ReadFrom([]byte(tstMeta), "loc") if err != nil { t.Errorf(errUnexpectedErr, err.Error()) } @@ -166,7 +166,7 @@ func TestReadFrom(t *testing.T) { t.Errorf(errNotExpectedOut, got, want) } - if got, err = ReadFrom([]byte(tstStr), "ts"); err != nil { + if got, err = ReadFrom([]byte(tstMeta), "ts"); err != nil { t.Errorf(errUnexpectedErr, err.Error()) } want = "12345"