diff --git a/util.go b/util.go index 72b47fb..2f974d8 100644 --- a/util.go +++ b/util.go @@ -19,8 +19,10 @@ import ( "path/filepath" "runtime" "strings" + "unicode" "github.com/BurntSushi/toml" + "github.com/spf13/cast" jww "github.com/spf13/jwalterweatherman" "gopkg.in/yaml.v2" ) @@ -139,3 +141,46 @@ func marshallConfigReader(in io.Reader, c map[string]interface{}, configType str insensitiviseMap(c) } + +func safeMul(a, b uint) uint { + c := a * b + if a > 1 && b > 1 && c/b != a { + return 0 + } + return c +} + +// parseSizeInBytes converts strings like 1GB or 12 mb into an unsigned integer number of bytes +func parseSizeInBytes(sizeStr string) uint { + sizeStr = strings.TrimSpace(sizeStr) + lastChar := len(sizeStr) - 1 + multiplier := uint(1) + + if lastChar > 0 { + if sizeStr[lastChar] == 'b' || sizeStr[lastChar] == 'B' { + if lastChar > 1 { + switch unicode.ToLower(rune(sizeStr[lastChar-1])) { + case 'k': + multiplier = 1 << 10 + sizeStr = strings.TrimSpace(sizeStr[:lastChar-1]) + case 'm': + multiplier = 1 << 20 + sizeStr = strings.TrimSpace(sizeStr[:lastChar-1]) + case 'g': + multiplier = 1 << 30 + sizeStr = strings.TrimSpace(sizeStr[:lastChar-1]) + default: + multiplier = 1 + sizeStr = strings.TrimSpace(sizeStr[:lastChar]) + } + } + } + } + + size := cast.ToInt(sizeStr) + if size < 0 { + size = 0 + } + + return safeMul(uint(size), multiplier) +} diff --git a/viper.go b/viper.go index 9b7489f..dde1cd6 100644 --- a/viper.go +++ b/viper.go @@ -335,6 +335,12 @@ func (v *viper) GetStringMapString(key string) map[string]string { return cast.ToStringMapString(v.Get(key)) } +func GetSizeInBytes(key string) uint { return v.GetSizeInBytes(key) } +func (v *viper) GetSizeInBytes(key string) uint { + sizeStr := cast.ToString(v.Get(key)) + return parseSizeInBytes(sizeStr) +} + // Takes a single key and marshals it into a Struct func MarshalKey(key string, rawVal interface{}) error { return v.MarshalKey(key, rawVal) } func (v *viper) MarshalKey(key string, rawVal interface{}) error { diff --git a/viper_test.go b/viper_test.go index a9ba70f..78b107b 100644 --- a/viper_test.go +++ b/viper_test.go @@ -376,3 +376,20 @@ func TestBoundCaseSensitivity(t *testing.T) { assert.Equal(t, "green", Get("eyes")) } + +func TestSizeInBytes(t *testing.T) { + input := map[string]uint{ + "": 0, + "b": 0, + "12 bytes": 0, + "200000000000gb": 0, + "12 b": 12, + "43 MB": 43 * (1 << 20), + "10mb": 10 * (1 << 20), + "1gb": 1 << 30, + } + + for str, expected := range input { + assert.Equal(t, expected, parseSizeInBytes(str), str) + } +}