package viper

import (
	"fmt"
	"strings"
	"testing"

	"github.com/spf13/cast"
	"github.com/stretchr/testify/assert"
)

// layer selects which of Viper's internal configuration maps a test helper
// writes to and inspects. The zero value is deliberately not a valid layer.
type layer int

const (
	// defaultLayer addresses values registered with SetDefault (v.defaults).
	defaultLayer layer = 1
	// overrideLayer addresses values registered with Set (v.override).
	overrideLayer layer = 2
)

// TestNestedOverrides checks how a value written at one (possibly nested)
// path, first through SetDefault or Set, interacts with a later Set at the
// same path, a parent path, or a sibling path.
func TestNestedOverrides(t *testing.T) {
	assert := assert.New(t)
	var v *Viper

	// Case 0: value overridden by a value.
	// "tom" is first given 10 (as a default, then as an override value),
	// then overridden by 20.
	overrideDefault(assert, "tom", 10, "tom", 20)
	override(assert, "tom", 10, "tom", 20)
	overrideDefault(assert, "tom.age", 10, "tom.age", 20)
	override(assert, "tom.age", 10, "tom.age", 20)
	overrideDefault(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
	override(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)

	// Case 1: key:value overridden by a value.
	// "tom.age" is first given 10, then "tom" is overridden by "boy", so
	// "tom.age" must no longer resolve.
	v = overrideDefault(assert, "tom.age", 10, "tom", "boy")
	assert.Nil(v.Get("tom.age"))
	v = override(assert, "tom.age", 10, "tom", "boy")
	assert.Nil(v.Get("tom.age"))

	// Case 2: value overridden by a key:value.
	// "tom" is first given "boy", then "tom.age" is set to 10, turning "tom"
	// into map{"age": 10}. Both layers must exercise this same direction.
	overrideDefault(assert, "tom", "boy", "tom.age", 10)
	override(assert, "tom", "boy", "tom.age", 10)

	// Case 3: key:value overridden by a key:value.
	// Setting the sibling "tom.age" must leave "tom.size" reachable.
	v = overrideDefault(assert, "tom.size", 4, "tom.age", 10)
	assert.Equal(4, v.Get("tom.size"))
	v = override(assert, "tom.size", 4, "tom.age", 10)
	assert.Equal(4, v.Get("tom.size"))
	deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)

	// Case 4: key:value overridden by a map.
	// "tom.size" is first given 4, then "tom" is overridden by map{"age": 10}.
	// When 4 is a default, the map is written on a different layer, so
	// "tom.size" stays reachable; when 4 was written with Set, the map
	// replaces it on the same layer and "tom.size" is no longer reachable.
	v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
	assert.Equal(4, v.Get("tom.size"))
	assert.Equal(10, v.Get("tom.age"))
	deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)
	v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
	assert.Nil(v.Get("tom.size"))
	assert.Equal(10, v.Get("tom.age"))
	deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)

	// Case 5: array overridden by a value.
	overrideDefault(assert, "tom", []int{10, 20}, "tom", 30)
	override(assert, "tom", []int{10, 20}, "tom", 30)
	overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", 30)
	override(assert, "tom.age", []int{10, 20}, "tom.age", 30)

	// Case 6: array overridden by an array.
	// Arrays are replaced wholesale, not merged.
	overrideDefault(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
	override(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
	overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
	v = override(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
	// Explicit array merge: read the current slice, append, and Set it back.
	s, ok := v.Get("tom.age").([]int)
	if assert.True(ok, "tom[\"age\"] is not a slice") {
		v.Set("tom.age", append(s, []int{50, 60}...))
		assert.Equal([]int{30, 40, 50, 60}, v.Get("tom.age"))
		deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, []int{30, 40, 50, 60})
	}
}

// overrideDefault registers firstValue at firstPath with SetDefault, then
// overwrites secondPath with secondValue via Set, checking both writes.
// The resulting *Viper is returned so callers can make further assertions.
func overrideDefault(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
	v := overrideFromLayer(defaultLayer, assert, firstPath, firstValue, secondPath, secondValue)
	return v
}

// override writes firstValue at firstPath with Set, then overwrites
// secondPath with secondValue, also via Set, checking both writes.
// The resulting *Viper is returned so callers can make further assertions.
func override(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
	v := overrideFromLayer(overrideLayer, assert, firstPath, firstValue, secondPath, secondValue)
	return v
}

// overrideFromLayer performs the sequential override and low-level checks.
//
// The first write stores firstValue at firstPath on layer l (SetDefault for
// defaultLayer, Set for overrideLayer). The second write always stores
// secondValue at secondPath through Set, i.e. on the override layer. Either
// path may contain any number of key delimiters to address nested elements.
//
// After each write, the value is read back twice: through v.Get on the full
// path, and by walking the layer's nested maps key by key (deepCheckValue).
//
// A fresh *Viper is always returned. The remaining steps are skipped early in
// three situations: assert is nil, a path's leading key is empty, or l is not
// a known layer.
func overrideFromLayer(l layer, assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
	v := New()

	// A split path is unusable if it has no keys or its leading key is empty.
	blankHead := func(keys []string) bool {
		return len(keys) == 0 || keys[0] == ""
	}

	firstKeys := strings.Split(firstPath, v.keyDelim)
	if assert == nil || blankHead(firstKeys) {
		return v
	}

	var assign func(string, interface{})
	switch l {
	case defaultLayer:
		assign = v.SetDefault
	case overrideLayer:
		assign = v.Set
	default:
		return v
	}

	// Step 1: seed the value on the requested layer and verify it.
	assign(firstPath, firstValue)
	assert.Equal(firstValue, v.Get(firstPath))
	deepCheckValue(assert, v, l, firstKeys, firstValue)

	// Step 2: overwrite through Set and verify the new value.
	secondKeys := strings.Split(secondPath, v.keyDelim)
	if blankHead(secondKeys) {
		return v
	}
	v.Set(secondPath, secondValue)
	assert.Equal(secondValue, v.Get(secondPath))
	deepCheckValue(assert, v, overrideLayer, secondKeys, secondValue)

	return v
}

// deepCheckValue checks that all given keys correspond to a valid path in the
// configuration map of the given layer, and that the final value equals the
// one given.
//
// Starting from the layer's root map (v.defaults or v.override), every value
// that is indexed must be a map[string]interface{} or a
// map[interface{}]interface{}; the latter is converted with
// cast.ToStringMap. The value found under the last key is compared with
// value. A broken link in the chain is reported with assert.Fail. The message
// names the offending path, e.g. v.override["tom"].
//
// Nothing is checked when assert or v is nil, or when keys is empty or starts
// with an empty key. An unknown layer is reported as a failure.
func deepCheckValue(assert *assert.Assertions, v *Viper, l layer, keys []string, value interface{}) {
	if assert == nil || v == nil ||
		len(keys) == 0 || len(keys[0]) == 0 {
		return
	}

	// Select the root map of the requested layer.
	var val interface{}
	var ms string
	switch l {
	case defaultLayer:
		val = v.defaults
		ms = "v.defaults"
	case overrideLayer:
		val = v.override
		ms = "v.override"
	default:
		// Without this, val stays nil and ms empty, producing a misleading
		// "is not a map" failure with no path in it.
		assert.Fail(fmt.Sprintf("unknown layer %d", l))
		return
	}

	// Walk down the nested maps, one key at a time.
	for _, k := range keys {
		var m map[string]interface{}
		switch typed := val.(type) {
		case map[string]interface{}:
			m = typed
		case map[interface{}]interface{}:
			m = cast.ToStringMap(typed)
		case nil:
			// The previous key (or the layer root) holds nothing at all.
			assert.Fail(fmt.Sprintf("%s is nil, not a map[string]interface{}", ms))
			return
		default:
			assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms))
			return
		}
		ms += "[\"" + k + "\"]"
		val = m[k]
	}
	assert.Equal(value, val)
}