diff --git a/readline.go b/readline.go index 14af1dd..1e67f51 100644 --- a/readline.go +++ b/readline.go @@ -87,7 +87,7 @@ func (c *Config) Init() error { } c.inited = true if c.Stdin == nil { - c.Stdin = Stdin + c.Stdin = NewCancelableStdin(Stdin) } if c.Stdout == nil { c.Stdout = Stdout diff --git a/std.go b/std.go index 28439f9..622e66e 100644 --- a/std.go +++ b/std.go @@ -7,7 +7,7 @@ import ( ) var ( - Stdin io.ReadCloser = NewCancelableStdin() + Stdin io.ReadCloser = os.Stdin Stdout io.WriteCloser = os.Stdout Stderr io.WriteCloser = os.Stderr ) @@ -66,6 +66,7 @@ func Line(prompt string) (string, error) { } type CancelableStdin struct { + r io.Reader mutex sync.Mutex stop chan struct{} notify chan struct{} @@ -75,11 +76,13 @@ type CancelableStdin struct { ioloopFired bool } -func NewCancelableStdin() *CancelableStdin { +func NewCancelableStdin(r io.Reader) *CancelableStdin { c := &CancelableStdin{ + r: r, notify: make(chan struct{}), stop: make(chan struct{}), } + go c.ioloop() return c } @@ -88,7 +91,7 @@ loop: for { select { case <-c.notify: - c.read, c.err = os.Stdin.Read(c.data) + c.read, c.err = c.r.Read(c.data) c.notify <- struct{}{} case <-c.stop: break loop @@ -99,10 +102,6 @@ loop: func (c *CancelableStdin) Read(b []byte) (n int, err error) { c.mutex.Lock() defer c.mutex.Unlock() - if !c.ioloopFired { - c.ioloopFired = true - go c.ioloop() - } c.data = b c.notify <- struct{}{}