fix #1106 UnmarshalKey looses some data if a subitem is overridden

This commit is contained in:
yushouqiu 2023-01-30 15:11:14 +08:00 committed by yushouqiu
parent c44f929787
commit cac6e7be53
2 changed files with 34 additions and 4 deletions

View File

@ -45,10 +45,10 @@ func TestNestedOverrides(t *testing.T) {
deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4) deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
// Case 4: key:value overridden by a map // Case 4: key:value overridden by a map
v = overrideDefault(assert, "tom.size", 4, "tom", map[string]any{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10} v = overrideDefault(assert, "tom.size", 4, "tom", map[string]any{"age": 10, "size": 4}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
assert.Equal(10, v.Get("tom.age")) // new value should be there assert.Equal(10, v.Get("tom.age")) // new value should be there
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there
v = override(assert, "tom.size", 4, "tom", map[string]any{"age": 10}) v = override(assert, "tom.size", 4, "tom", map[string]any{"age": 10})
assert.Nil(v.Get("tom.size")) assert.Nil(v.Get("tom.size"))
assert.Equal(10, v.Get("tom.age")) assert.Equal(10, v.Get("tom.age"))

View File

@ -899,6 +899,13 @@ func GetViper() *Viper {
// Get returns an interface. For a specific value use one of the Get____ methods. // Get returns an interface. For a specific value use one of the Get____ methods.
func Get(key string) any { return v.Get(key) } func Get(key string) any { return v.Get(key) }
func isStringMapInterface(val any) bool {
vt := reflect.TypeOf(val)
return vt.Kind() == reflect.Map &&
vt.Key().Kind() == reflect.String &&
vt.Elem().Kind() == reflect.Interface
}
func (v *Viper) Get(key string) any { func (v *Viper) Get(key string) any {
lcaseKey := strings.ToLower(key) lcaseKey := strings.ToLower(key)
val := v.find(lcaseKey, true) val := v.find(lcaseKey, true)
@ -906,6 +913,29 @@ func (v *Viper) Get(key string) any {
return nil return nil
} }
// when section is partially overridden,
// make sure to return the complete map.
if isStringMapInterface(val) {
val := val.(map[string]interface{})
prefix := lcaseKey + v.keyDelim
keys := v.AllKeys()
for _, key := range keys {
if !strings.HasPrefix(key, prefix) {
continue
}
mk := strings.TrimPrefix(key, prefix)
mk = strings.Split(mk, v.keyDelim)[0]
if _, exists := val[mk]; exists {
continue
}
mv := v.Get(lcaseKey + v.keyDelim + mk)
if mv == nil {
continue
}
val[mk] = mv
}
}
if v.typeByDefValue { if v.typeByDefValue {
// TODO(bep) this branch isn't covered by a single test. // TODO(bep) this branch isn't covered by a single test.
valType := val valType := val