diff --git a/viper.go b/viper.go index e9966ba..c166e9f 100644 --- a/viper.go +++ b/viper.go @@ -811,6 +811,9 @@ func (v *Viper) UnmarshalExact(rawVal interface{}) error { // name as the config key. func BindPFlags(flags *pflag.FlagSet) error { return v.BindPFlags(flags) } func (v *Viper) BindPFlags(flags *pflag.FlagSet) error { + if flags == nil { + return fmt.Errorf("FlagSet cannot be nil") + } return v.BindFlagValues(pflagValueSet{flags}) } @@ -822,6 +825,9 @@ func (v *Viper) BindPFlags(flags *pflag.FlagSet) error { // func BindPFlag(key string, flag *pflag.Flag) error { return v.BindPFlag(key, flag) } func (v *Viper) BindPFlag(key string, flag *pflag.Flag) error { + if flag == nil { + return fmt.Errorf("flag for %q is nil", key) + } return v.BindFlagValue(key, pflagValue{flag}) } diff --git a/viper_test.go b/viper_test.go index c93480e..443345e 100644 --- a/viper_test.go +++ b/viper_test.go @@ -577,6 +577,14 @@ func TestBindPFlagsStringSlice(t *testing.T) { } } +func TestBindPFlagsNil(t *testing.T) { + v := New() + err := v.BindPFlags(nil) + if err == nil { + t.Fatalf("expected error when passing nil to BindPFlags") + } +} + func TestBindPFlag(t *testing.T) { var testString = "testing" var testValue = newStringValue(testString, &testString) @@ -598,6 +606,14 @@ func TestBindPFlag(t *testing.T) { } +func TestBindPFlagNil(t *testing.T) { + v := New() + err := v.BindPFlag("any", nil) + if err == nil { + t.Fatalf("expected error when passing nil to BindPFlag") + } +} + func TestBoundCaseSensitivity(t *testing.T) { assert.Equal(t, "brown", Get("eyes"))