From 242f4890f5f6239217fa5c45bb6de3b4b5509d3b Mon Sep 17 00:00:00 2001 From: Xavier Coulon Date: Thu, 24 May 2018 10:09:29 +0200 Subject: [PATCH] Refactor with WaitGroup and check channel is open Signed-off-by: Xavier Coulon --- viper.go | 37 ++++++++++++++++++++++--------------- viper_test.go | 3 ++- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/viper.go b/viper.go index 148d5cc..72230f5 100644 --- a/viper.go +++ b/viper.go @@ -261,8 +261,8 @@ func (v *Viper) OnConfigChange(run func(in fsnotify.Event)) { func WatchConfig() { v.WatchConfig() } func (v *Viper) WatchConfig() { - wg := sync.WaitGroup{} - wg.Add(1) + initWG := sync.WaitGroup{} + initWG.Add(1) go func() { watcher, err := fsnotify.NewWatcher() if err != nil { @@ -272,7 +272,7 @@ func (v *Viper) WatchConfig() { // we have to watch the entire directory to pick up renames/atomic saves in a cross-platform way filename, err := v.getConfigFile() if err != nil { - log.Println("error:", err) + log.Printf("error: %v\n", err) return } @@ -280,12 +280,16 @@ func (v *Viper) WatchConfig() { configDir, _ := filepath.Split(configFile) realConfigFile, _ := filepath.EvalSymlinks(filename) - done := make(chan bool) + eventsWG := sync.WaitGroup{} + eventsWG.Add(1) go func() { - loop: for { select { - case event := <-watcher.Events: + case event, ok := <-watcher.Events: + if !ok { // 'Events' channel is closed + eventsWG.Done() + return + } currentConfigFile, _ := filepath.EvalSymlinks(filename) // we only care about the config file with the following cases: // 1 - if the config file was modified or created @@ -296,28 +300,31 @@ func (v *Viper) WatchConfig() { realConfigFile = currentConfigFile err := v.ReadInConfig() if err != nil { - log.Println("error reading file:", err.Error()) + log.Printf("error reading file: %v\n", err) } if v.onConfigChange != nil { v.onConfigChange(event) } } else if filepath.Clean(event.Name) == configFile && event.Op&fsnotify.Remove == fsnotify.Remove { - done <- true - break loop + eventsWG.Done() + return } - case err := <-watcher.Errors: - log.Printf("watcher error: %v\n", err) + case err, ok := <-watcher.Errors: + if ok { // 'Errors' channel is not closed + log.Printf("watcher error: %v\n", err) + } + eventsWG.Done() + return } } }() watcher.Add(configDir) - wg.Done() // done initalizing the watch in this go routine, so the parent routine can move on... - <-done // block until the watched file is removed... + initWG.Done() // done initalizing the watch in this go routine, so the parent routine can move on... + eventsWG.Wait() // now, wait for event loop to end in this go-routine... }() - // make sure that the go routine above fully started before returning - wg.Wait() + initWG.Wait() // make sure that the go routine above fully ended before returning } // SetConfigFile explicitly defines the path, name and extension of the config file. diff --git a/viper_test.go b/viper_test.go index 7e2140e..9961970 100644 --- a/viper_test.go +++ b/viper_test.go @@ -1419,9 +1419,11 @@ func newViperWithSymlinkedConfigFile(t *testing.T) (*Viper, string, string, func } func TestWatchFile(t *testing.T) { + t.Run("file content changed", func(t *testing.T) { // given a `config.yaml` file being watched v, configFile, cleanup := newViperWithConfigFile(t) + fmt.Printf("test config file: %s\n", configFile) defer cleanup() wg := sync.WaitGroup{} v.WatchConfig() @@ -1466,7 +1468,6 @@ func TestWatchFile(t *testing.T) { // then require.Nil(t, err) assert.Equal(t, "baz", v.Get("foo")) - }) }