mirror of https://github.com/spf13/viper.git
Refactor with WaitGroup and check channel is open
Signed-off-by: Xavier Coulon <xcoulon@redhat.com>
This commit is contained in:
parent
e0f7631cf3
commit
242f4890f5
35
viper.go
35
viper.go
|
@ -261,8 +261,8 @@ func (v *Viper) OnConfigChange(run func(in fsnotify.Event)) {
|
||||||
|
|
||||||
func WatchConfig() { v.WatchConfig() }
|
func WatchConfig() { v.WatchConfig() }
|
||||||
func (v *Viper) WatchConfig() {
|
func (v *Viper) WatchConfig() {
|
||||||
wg := sync.WaitGroup{}
|
initWG := sync.WaitGroup{}
|
||||||
wg.Add(1)
|
initWG.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
watcher, err := fsnotify.NewWatcher()
|
watcher, err := fsnotify.NewWatcher()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -272,7 +272,7 @@ func (v *Viper) WatchConfig() {
|
||||||
// we have to watch the entire directory to pick up renames/atomic saves in a cross-platform way
|
// we have to watch the entire directory to pick up renames/atomic saves in a cross-platform way
|
||||||
filename, err := v.getConfigFile()
|
filename, err := v.getConfigFile()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("error:", err)
|
log.Printf("error: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -280,12 +280,16 @@ func (v *Viper) WatchConfig() {
|
||||||
configDir, _ := filepath.Split(configFile)
|
configDir, _ := filepath.Split(configFile)
|
||||||
realConfigFile, _ := filepath.EvalSymlinks(filename)
|
realConfigFile, _ := filepath.EvalSymlinks(filename)
|
||||||
|
|
||||||
done := make(chan bool)
|
eventsWG := sync.WaitGroup{}
|
||||||
|
eventsWG.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
loop:
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case event := <-watcher.Events:
|
case event, ok := <-watcher.Events:
|
||||||
|
if !ok { // 'Events' channel is closed
|
||||||
|
eventsWG.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
currentConfigFile, _ := filepath.EvalSymlinks(filename)
|
currentConfigFile, _ := filepath.EvalSymlinks(filename)
|
||||||
// we only care about the config file with the following cases:
|
// we only care about the config file with the following cases:
|
||||||
// 1 - if the config file was modified or created
|
// 1 - if the config file was modified or created
|
||||||
|
@ -296,28 +300,31 @@ func (v *Viper) WatchConfig() {
|
||||||
realConfigFile = currentConfigFile
|
realConfigFile = currentConfigFile
|
||||||
err := v.ReadInConfig()
|
err := v.ReadInConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("error reading file:", err.Error())
|
log.Printf("error reading file: %v\n", err)
|
||||||
}
|
}
|
||||||
if v.onConfigChange != nil {
|
if v.onConfigChange != nil {
|
||||||
v.onConfigChange(event)
|
v.onConfigChange(event)
|
||||||
}
|
}
|
||||||
} else if filepath.Clean(event.Name) == configFile &&
|
} else if filepath.Clean(event.Name) == configFile &&
|
||||||
event.Op&fsnotify.Remove == fsnotify.Remove {
|
event.Op&fsnotify.Remove == fsnotify.Remove {
|
||||||
done <- true
|
eventsWG.Done()
|
||||||
break loop
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
case err := <-watcher.Errors:
|
case err, ok := <-watcher.Errors:
|
||||||
|
if ok { // 'Errors' channel is not closed
|
||||||
log.Printf("watcher error: %v\n", err)
|
log.Printf("watcher error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
eventsWG.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
watcher.Add(configDir)
|
watcher.Add(configDir)
|
||||||
wg.Done() // done initalizing the watch in this go routine, so the parent routine can move on...
|
initWG.Done() // done initalizing the watch in this go routine, so the parent routine can move on...
|
||||||
<-done // block until the watched file is removed...
|
eventsWG.Wait() // now, wait for event loop to end in this go-routine...
|
||||||
}()
|
}()
|
||||||
// make sure that the go routine above fully started before returning
|
initWG.Wait() // make sure that the go routine above fully ended before returning
|
||||||
wg.Wait()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetConfigFile explicitly defines the path, name and extension of the config file.
|
// SetConfigFile explicitly defines the path, name and extension of the config file.
|
||||||
|
|
|
@ -1419,9 +1419,11 @@ func newViperWithSymlinkedConfigFile(t *testing.T) (*Viper, string, string, func
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWatchFile(t *testing.T) {
|
func TestWatchFile(t *testing.T) {
|
||||||
|
|
||||||
t.Run("file content changed", func(t *testing.T) {
|
t.Run("file content changed", func(t *testing.T) {
|
||||||
// given a `config.yaml` file being watched
|
// given a `config.yaml` file being watched
|
||||||
v, configFile, cleanup := newViperWithConfigFile(t)
|
v, configFile, cleanup := newViperWithConfigFile(t)
|
||||||
|
fmt.Printf("test config file: %s\n", configFile)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
v.WatchConfig()
|
v.WatchConfig()
|
||||||
|
@ -1466,7 +1468,6 @@ func TestWatchFile(t *testing.T) {
|
||||||
// then
|
// then
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, "baz", v.Get("foo"))
|
assert.Equal(t, "baz", v.Get("foo"))
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue