mirror of https://github.com/spf13/viper.git
fix #1106 UnmarshalKey looses some data if a subitem is overridden
This commit is contained in:
parent
c44f929787
commit
cac6e7be53
|
@ -45,7 +45,7 @@ func TestNestedOverrides(t *testing.T) {
|
||||||
deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
|
deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
|
||||||
|
|
||||||
// Case 4: key:value overridden by a map
|
// Case 4: key:value overridden by a map
|
||||||
v = overrideDefault(assert, "tom.size", 4, "tom", map[string]any{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
|
v = overrideDefault(assert, "tom.size", 4, "tom", map[string]any{"age": 10, "size": 4}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
|
||||||
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
|
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
|
||||||
assert.Equal(10, v.Get("tom.age")) // new value should be there
|
assert.Equal(10, v.Get("tom.age")) // new value should be there
|
||||||
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there
|
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there
|
||||||
|
|
30
viper.go
30
viper.go
|
@ -899,6 +899,13 @@ func GetViper() *Viper {
|
||||||
// Get returns an interface. For a specific value use one of the Get____ methods.
|
// Get returns an interface. For a specific value use one of the Get____ methods.
|
||||||
func Get(key string) any { return v.Get(key) }
|
func Get(key string) any { return v.Get(key) }
|
||||||
|
|
||||||
|
func isStringMapInterface(val any) bool {
|
||||||
|
vt := reflect.TypeOf(val)
|
||||||
|
return vt.Kind() == reflect.Map &&
|
||||||
|
vt.Key().Kind() == reflect.String &&
|
||||||
|
vt.Elem().Kind() == reflect.Interface
|
||||||
|
}
|
||||||
|
|
||||||
func (v *Viper) Get(key string) any {
|
func (v *Viper) Get(key string) any {
|
||||||
lcaseKey := strings.ToLower(key)
|
lcaseKey := strings.ToLower(key)
|
||||||
val := v.find(lcaseKey, true)
|
val := v.find(lcaseKey, true)
|
||||||
|
@ -906,6 +913,29 @@ func (v *Viper) Get(key string) any {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// when section is partially overridden,
|
||||||
|
// make sure to return the complete map.
|
||||||
|
if isStringMapInterface(val) {
|
||||||
|
val := val.(map[string]interface{})
|
||||||
|
prefix := lcaseKey + v.keyDelim
|
||||||
|
keys := v.AllKeys()
|
||||||
|
for _, key := range keys {
|
||||||
|
if !strings.HasPrefix(key, prefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mk := strings.TrimPrefix(key, prefix)
|
||||||
|
mk = strings.Split(mk, v.keyDelim)[0]
|
||||||
|
if _, exists := val[mk]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mv := v.Get(lcaseKey + v.keyDelim + mk)
|
||||||
|
if mv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
val[mk] = mv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if v.typeByDefValue {
|
if v.typeByDefValue {
|
||||||
// TODO(bep) this branch isn't covered by a single test.
|
// TODO(bep) this branch isn't covered by a single test.
|
||||||
valType := val
|
valType := val
|
||||||
|
|
Loading…
Reference in New Issue