mirror of https://github.com/spf13/viper.git
feat: add a new API StopWatching() to stop the groutine created by WatchConfig() manually
This commit is contained in:
parent
e40c5633a5
commit
59e479f6b8
30
viper.go
30
viper.go
|
@ -216,7 +216,8 @@ type Viper struct {
|
|||
aliases map[string]string
|
||||
typeByDefValue bool
|
||||
|
||||
onConfigChange func(fsnotify.Event)
|
||||
onConfigChange func(fsnotify.Event)
|
||||
stopWatchingFunc func()
|
||||
|
||||
logger Logger
|
||||
|
||||
|
@ -432,13 +433,10 @@ func (v *Viper) OnConfigChange(run func(in fsnotify.Event)) {
|
|||
}
|
||||
|
||||
// WatchConfig starts watching a config file for changes.
|
||||
// The function returned for stop watching manually.
|
||||
func WatchConfig() func() { return v.WatchConfig() }
|
||||
func WatchConfig() { v.WatchConfig() }
|
||||
|
||||
// WatchConfig starts watching a config file for changes.
|
||||
// The function returned for stop watching manually.
|
||||
func (v *Viper) WatchConfig() func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
func (v *Viper) WatchConfig() {
|
||||
initWG := sync.WaitGroup{}
|
||||
initWG.Add(1)
|
||||
go func() {
|
||||
|
@ -460,6 +458,10 @@ func (v *Viper) WatchConfig() func() {
|
|||
configDir, _ := filepath.Split(configFile)
|
||||
realConfigFile, _ := filepath.EvalSymlinks(filename)
|
||||
|
||||
// init the stopWatchingFunc
|
||||
watchingCtx, cancel := context.WithCancel(context.Background())
|
||||
v.stopWatchingFunc = cancel
|
||||
|
||||
eventsWG := sync.WaitGroup{}
|
||||
eventsWG.Add(1)
|
||||
go func() {
|
||||
|
@ -496,8 +498,9 @@ func (v *Viper) WatchConfig() func() {
|
|||
}
|
||||
eventsWG.Done()
|
||||
return
|
||||
case <-ctx.Done(): // cancel function called
|
||||
watcher.Close()
|
||||
case <-watchingCtx.Done(): // StopWatching function called
|
||||
eventsWG.Done()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -506,7 +509,16 @@ func (v *Viper) WatchConfig() func() {
|
|||
eventsWG.Wait() // now, wait for event loop to end in this go-routine...
|
||||
}()
|
||||
initWG.Wait() // make sure that the go routine above fully ended before returning
|
||||
return cancel
|
||||
}
|
||||
|
||||
// StopWatching stop watching a config file for changes.
|
||||
func StopWatching() { v.StopWatching() }
|
||||
|
||||
// StopWatching stop watching a config file for changes.
|
||||
func (v *Viper) StopWatching() {
|
||||
if v.stopWatchingFunc != nil {
|
||||
v.stopWatchingFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfigFile explicitly defines the path, name and extension of the config file.
|
||||
|
|
|
@ -2545,6 +2545,46 @@ func TestWatchFile(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestStopWatching(t *testing.T) {
|
||||
t.Run(
|
||||
"file content changed after stop watching", func(t *testing.T) {
|
||||
// given a `config.yaml` file being watched
|
||||
v, configFile, cleanup := newViperWithConfigFile(t)
|
||||
defer cleanup()
|
||||
_, err := os.Stat(configFile)
|
||||
require.NoError(t, err)
|
||||
t.Logf("test config file: %s\n", configFile)
|
||||
|
||||
v.WatchConfig()
|
||||
v.StopWatching()
|
||||
|
||||
// overwriting the file after StopWatching called
|
||||
err = ioutil.WriteFile(configFile, []byte("foo: baz\n"), 0o640)
|
||||
time.Sleep(time.Second) // wait for file changed event
|
||||
// then the config value should not be changed
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, "bar", v.Get("foo"))
|
||||
|
||||
// watch again
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
var wgDoneOnce sync.Once // OnConfigChange is called twice on Windows
|
||||
v.OnConfigChange(
|
||||
func(in fsnotify.Event) {
|
||||
t.Logf("config file changed again")
|
||||
wgDoneOnce.Do(func() { wg.Done() })
|
||||
},
|
||||
)
|
||||
v.WatchConfig()
|
||||
// overwriting the file after StopWatching and Watch again
|
||||
err = ioutil.WriteFile(configFile, []byte("foo: qux\n"), 0o640)
|
||||
wg.Wait()
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, "qux", v.Get("foo"))
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestUnmarshal_DotSeparatorBackwardCompatibility(t *testing.T) {
|
||||
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
|
||||
flags.String("foo.bar", "cobra_flag", "")
|
||||
|
|
Loading…
Reference in New Issue