From a7a5948b15ddd0aff0da29a4c4b0011c2968ccdb Mon Sep 17 00:00:00 2001 From: Harley Laue Date: Mon, 23 Apr 2018 10:24:03 -0700 Subject: [PATCH] Check for nil before binding pflag(s) * When passing nil to BindPFlag or BindPFlags, the value is set to a struct and passed as an interface. That struct never checks for the flag(set) being nil. Thus, it makes sense to check before it's set to the struct. * fixes #422 --- viper.go | 6 ++++++ viper_test.go | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/viper.go b/viper.go index e9966ba..c166e9f 100644 --- a/viper.go +++ b/viper.go @@ -811,6 +811,9 @@ func (v *Viper) UnmarshalExact(rawVal interface{}) error { // name as the config key. func BindPFlags(flags *pflag.FlagSet) error { return v.BindPFlags(flags) } func (v *Viper) BindPFlags(flags *pflag.FlagSet) error { + if flags == nil { + return fmt.Errorf("FlagSet cannot be nil") + } return v.BindFlagValues(pflagValueSet{flags}) } @@ -822,6 +825,9 @@ func (v *Viper) BindPFlags(flags *pflag.FlagSet) error { // func BindPFlag(key string, flag *pflag.Flag) error { return v.BindPFlag(key, flag) } func (v *Viper) BindPFlag(key string, flag *pflag.Flag) error { + if flag == nil { + return fmt.Errorf("flag for %q is nil", key) + } return v.BindFlagValue(key, pflagValue{flag}) } diff --git a/viper_test.go b/viper_test.go index c93480e..443345e 100644 --- a/viper_test.go +++ b/viper_test.go @@ -577,6 +577,14 @@ func TestBindPFlagsStringSlice(t *testing.T) { } } +func TestBindPFlagsNil(t *testing.T) { + v := New() + err := v.BindPFlags(nil) + if err == nil { + t.Fatalf("expected error when passing nil to BindPFlags") + } +} + func TestBindPFlag(t *testing.T) { var testString = "testing" var testValue = newStringValue(testString, &testString) @@ -598,6 +606,14 @@ func TestBindPFlag(t *testing.T) { } +func TestBindPFlagNil(t *testing.T) { + v := New() + err := v.BindPFlag("any", nil) + if err == nil { + t.Fatalf("expected error when passing nil to BindPFlag") + } +} + func TestBoundCaseSensitivity(t *testing.T) { assert.Equal(t, "brown", Get("eyes"))