From c898f59d3345a8317651876020fe4eb7322a00aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dawid=20Urba=C5=84ski?= <2959818+UrbanskiDawid@users.noreply.github.com> Date: Sun, 5 Feb 2023 18:30:05 +0100 Subject: [PATCH] fix: time.Duration slice type conversion (#1498) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix DurationSlice * Fix typo --------- Co-authored-by: Márk Sági-Kazár --- viper.go | 13 +++++++++++++ viper_test.go | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/viper.go b/viper.go index 06610fc..45a2d5c 100644 --- a/viper.go +++ b/viper.go @@ -928,6 +928,8 @@ func (v *Viper) Get(key string) interface{} { return cast.ToStringSlice(val) case []int: return cast.ToIntSlice(val) + case []time.Duration: + return cast.ToDurationSlice(val) } } @@ -1274,8 +1276,14 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { s = strings.TrimSuffix(s, "]") res, _ := readAsCSV(s) return cast.ToIntSlice(res) + case "durationSlice": + s := strings.TrimPrefix(flag.ValueString(), "[") + s = strings.TrimSuffix(s, "]") + slice := strings.Split(s, ",") + return cast.ToDurationSlice(slice) case "stringToString": return stringToStringConv(flag.ValueString()) + default: return flag.ValueString() } @@ -1355,6 +1363,11 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { return cast.ToIntSlice(res) case "stringToString": return stringToStringConv(flag.ValueString()) + case "durationSlice": + s := strings.TrimPrefix(flag.ValueString(), "[") + s = strings.TrimSuffix(s, "]") + slice := strings.Split(s, ",") + return cast.ToDurationSlice(slice) default: return flag.ValueString() } diff --git a/viper_test.go b/viper_test.go index 926ffc2..9363da8 100644 --- a/viper_test.go +++ b/viper_test.go @@ -1094,6 +1094,28 @@ func TestBindPFlagsStringArray(t *testing.T) { } } +func TestSliceFlagsReturnCorrectType(t *testing.T) { + flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) + flagSet.IntSlice("int", []int{1, 2}, "") + flagSet.StringSlice("str", []string{"3", "4"}, "") + flagSet.DurationSlice("duration", []time.Duration{5 * time.Second}, "") + + v := New() + v.BindPFlags(flagSet) + + all := v.AllSettings() + + if _, ok := all["int"].([]int); !ok { + t.Errorf("unexpected type %T expected []int", all["int"]) + } + if _, ok := all["str"].([]string); !ok { + t.Errorf("unexpected type %T expected []string", all["str"]) + } + if _, ok := all["duration"].([]time.Duration); !ok { + t.Errorf("unexpected type %T expected []time.Duration", all["duration"]) + } +} + //nolint:dupl func TestBindPFlagsIntSlice(t *testing.T) { tests := []struct {