add label value validation to GetMetricWith and friends

This commit is contained in:
Marco Jantke 2017-08-19 22:12:52 +02:00
parent 94ff84a9a6
commit 957bba6f68
3 changed files with 97 additions and 4 deletions

View File

@ -14,6 +14,7 @@
package prometheus package prometheus
import ( import (
"fmt"
"math" "math"
"testing" "testing"
@ -56,3 +57,63 @@ func decreaseCounter(c *counter) (err error) {
c.Add(-1) c.Add(-1)
return nil return nil
} }
func TestCounterVecGetMetricWithInvalidLabelValues(t *testing.T) {
testCases := []struct {
desc string
labels Labels
}{
{
desc: "non utf8 label value",
labels: Labels{"a": "\xFF"},
},
{
desc: "not enough label values",
labels: Labels{},
},
{
desc: "too many label values",
labels: Labels{"a": "1", "b": "2"},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
counterVec := NewCounterVec(CounterOpts{
Name: "test",
}, []string{"a"})
labelValues := make([]string, len(test.labels))
for _, val := range test.labels {
labelValues = append(labelValues, val)
}
expectPanic(t, func() {
counterVec.WithLabelValues(labelValues...)
}, fmt.Sprintf("WithLabelValues: expected panic because: %s", test.desc))
expectPanic(t, func() {
counterVec.With(test.labels)
}, fmt.Sprintf("WithLabelValues: expected panic because: %s", test.desc))
if _, err := counterVec.GetMetricWithLabelValues(labelValues...); err == nil {
t.Errorf("GetMetricWithLabelValues: expected error because: %s", test.desc)
}
if _, err := counterVec.GetMetricWith(test.labels); err == nil {
t.Errorf("GetMetricWith: expected error because: %s", test.desc)
}
})
}
}
func expectPanic(t *testing.T, op func(), errorMsg string) {
defer func() {
if err := recover(); err == nil {
t.Error(errorMsg)
}
}()
op()
}

View File

@ -0,0 +1 @@
package prometheus

View File

@ -14,8 +14,10 @@
package prometheus package prometheus
import ( import (
"errors"
"fmt" "fmt"
"sync" "sync"
"unicode/utf8"
"github.com/prometheus/common/model" "github.com/prometheus/common/model"
) )
@ -206,10 +208,24 @@ func (m *metricVec) Reset() {
} }
} }
func (m *metricVec) hashLabelValues(vals []string) (uint64, error) { func (m *metricVec) validateLabelValues(vals []string) error {
if len(vals) != len(m.desc.variableLabels) { if len(vals) != len(m.desc.variableLabels) {
return 0, errInconsistentCardinality return errInconsistentCardinality
} }
for _, val := range vals {
if !utf8.ValidString(val) {
return errors.New(fmt.Sprintf("label %q is not valid utf8", val))
}
}
return nil
}
func (m *metricVec) hashLabelValues(vals []string) (uint64, error) {
if err := m.validateLabelValues(vals); err != nil {
return 0, err
}
h := hashNew() h := hashNew()
for _, val := range vals { for _, val := range vals {
h = m.hashAdd(h, val) h = m.hashAdd(h, val)
@ -218,10 +234,25 @@ func (m *metricVec) hashLabelValues(vals []string) (uint64, error) {
return h, nil return h, nil
} }
func (m *metricVec) hashLabels(labels Labels) (uint64, error) { func (m *metricVec) validateLabels(labels Labels) error {
if len(labels) != len(m.desc.variableLabels) { if len(labels) != len(m.desc.variableLabels) {
return 0, errInconsistentCardinality return errInconsistentCardinality
} }
for name, val := range labels {
if !utf8.ValidString(val) {
return errors.New(fmt.Sprintf("label %s: %q is not valid utf8", name, val))
}
}
return nil
}
func (m *metricVec) hashLabels(labels Labels) (uint64, error) {
if err := m.validateLabels(labels); err != nil {
return 0, err
}
h := hashNew() h := hashNew()
for _, label := range m.desc.variableLabels { for _, label := range m.desc.variableLabels {
val, ok := labels[label] val, ok := labels[label]