forked from mirror/client_golang
add label value validation to GetMetricWith and friends
This commit is contained in:
parent
94ff84a9a6
commit
957bba6f68
|
@ -14,6 +14,7 @@
|
||||||
package prometheus
|
package prometheus
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -56,3 +57,63 @@ func decreaseCounter(c *counter) (err error) {
|
||||||
c.Add(-1)
|
c.Add(-1)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCounterVecGetMetricWithInvalidLabelValues(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
labels Labels
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "non utf8 label value",
|
||||||
|
labels: Labels{"a": "\xFF"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "not enough label values",
|
||||||
|
labels: Labels{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "too many label values",
|
||||||
|
labels: Labels{"a": "1", "b": "2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
counterVec := NewCounterVec(CounterOpts{
|
||||||
|
Name: "test",
|
||||||
|
}, []string{"a"})
|
||||||
|
|
||||||
|
labelValues := make([]string, len(test.labels))
|
||||||
|
for _, val := range test.labels {
|
||||||
|
labelValues = append(labelValues, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectPanic(t, func() {
|
||||||
|
counterVec.WithLabelValues(labelValues...)
|
||||||
|
}, fmt.Sprintf("WithLabelValues: expected panic because: %s", test.desc))
|
||||||
|
expectPanic(t, func() {
|
||||||
|
counterVec.With(test.labels)
|
||||||
|
}, fmt.Sprintf("WithLabelValues: expected panic because: %s", test.desc))
|
||||||
|
|
||||||
|
if _, err := counterVec.GetMetricWithLabelValues(labelValues...); err == nil {
|
||||||
|
t.Errorf("GetMetricWithLabelValues: expected error because: %s", test.desc)
|
||||||
|
}
|
||||||
|
if _, err := counterVec.GetMetricWith(test.labels); err == nil {
|
||||||
|
t.Errorf("GetMetricWith: expected error because: %s", test.desc)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectPanic(t *testing.T, op func(), errorMsg string) {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err == nil {
|
||||||
|
t.Error(errorMsg)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
op()
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
package prometheus
|
|
@ -14,8 +14,10 @@
|
||||||
package prometheus
|
package prometheus
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/prometheus/common/model"
|
"github.com/prometheus/common/model"
|
||||||
)
|
)
|
||||||
|
@ -206,10 +208,24 @@ func (m *metricVec) Reset() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *metricVec) hashLabelValues(vals []string) (uint64, error) {
|
func (m *metricVec) validateLabelValues(vals []string) error {
|
||||||
if len(vals) != len(m.desc.variableLabels) {
|
if len(vals) != len(m.desc.variableLabels) {
|
||||||
return 0, errInconsistentCardinality
|
return errInconsistentCardinality
|
||||||
}
|
}
|
||||||
|
for _, val := range vals {
|
||||||
|
if !utf8.ValidString(val) {
|
||||||
|
return errors.New(fmt.Sprintf("label %q is not valid utf8", val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *metricVec) hashLabelValues(vals []string) (uint64, error) {
|
||||||
|
if err := m.validateLabelValues(vals); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
h := hashNew()
|
h := hashNew()
|
||||||
for _, val := range vals {
|
for _, val := range vals {
|
||||||
h = m.hashAdd(h, val)
|
h = m.hashAdd(h, val)
|
||||||
|
@ -218,10 +234,25 @@ func (m *metricVec) hashLabelValues(vals []string) (uint64, error) {
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *metricVec) hashLabels(labels Labels) (uint64, error) {
|
func (m *metricVec) validateLabels(labels Labels) error {
|
||||||
if len(labels) != len(m.desc.variableLabels) {
|
if len(labels) != len(m.desc.variableLabels) {
|
||||||
return 0, errInconsistentCardinality
|
return errInconsistentCardinality
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for name, val := range labels {
|
||||||
|
if !utf8.ValidString(val) {
|
||||||
|
return errors.New(fmt.Sprintf("label %s: %q is not valid utf8", name, val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *metricVec) hashLabels(labels Labels) (uint64, error) {
|
||||||
|
if err := m.validateLabels(labels); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
h := hashNew()
|
h := hashNew()
|
||||||
for _, label := range m.desc.variableLabels {
|
for _, label := range m.desc.variableLabels {
|
||||||
val, ok := labels[label]
|
val, ok := labels[label]
|
||||||
|
|
Loading…
Reference in New Issue