Fix race condition between writing and prompting.

Introduce a new operation.isPrompting field which is true from when a
prompt has been written until data is returned to the caller.
When Write is called on the wrapWriter to write to stdout or stderr,
check if we are currently prompting the user for input and if so
clean up the prompt and write the data before redrawing the
prompt at its new location after the written data.

Previously terminal.IsReading() was used for this, but this had various
race conditions and it was not correct to check this field to make
prompt and buffer redrawing decisions. In turn, I removed all the
isReading code also. The old isReading() check was actually checking
if the terminal goroutine was actively waiting for more input.
This commit is contained in:
Thomas O'Dowd 2023-03-03 18:01:34 +09:00
parent 586d8eebeb
commit 62ab2cfd17
3 changed files with 56 additions and 23 deletions

View File

@ -27,6 +27,8 @@ type Operation struct {
errchan chan error errchan chan error
w io.Writer w io.Writer
isPrompting bool // true when prompt written and waiting for input
history *opHistory history *opHistory
*opSearch *opSearch
*opCompleter *opCompleter
@ -39,29 +41,43 @@ func (o *Operation) SetBuffer(what string) {
} }
type wrapWriter struct { type wrapWriter struct {
r *Operation o *Operation
t *Terminal
target io.Writer target io.Writer
} }
func (w *wrapWriter) Write(b []byte) (int, error) { func (w *wrapWriter) Write(b []byte) (int, error) {
if !w.t.IsReading() { return w.o.write(w.target, b)
return w.target.Write(b) }
func (o *Operation) write(target io.Writer, b []byte) (int, error) {
o.m.Lock()
defer o.m.Unlock()
if !o.isPrompting {
return target.Write(b)
} }
var ( var (
n int n int
err error err error
) )
w.r.buf.Refresh(func() { o.buf.Refresh(func() {
n, err = w.target.Write(b) n, err = target.Write(b)
// Adjust the prompt start position by b
rout := runes.ColorFilter([]rune(string(b[:])))
sp := SplitByLine(rout, []rune{}, o.buf.ppos, o.buf.width, 1)
if len(sp) > 1 {
o.buf.ppos = len(sp[len(sp)-1])
} else {
o.buf.ppos += len(rout)
}
}) })
if w.r.IsSearchMode() { if o.IsSearchMode() {
w.r.SearchRefresh(-1) o.SearchRefresh(-1)
} }
if w.r.IsInCompleteMode() { if o.IsInCompleteMode() {
w.r.CompleteRefresh() o.CompleteRefresh()
} }
return n, err return n, err
} }
@ -352,7 +368,7 @@ func (o *Operation) ioloop() {
} else if o.IsInCompleteMode() { } else if o.IsInCompleteMode() {
if !keepInCompleteMode { if !keepInCompleteMode {
o.ExitCompleteMode(false) o.ExitCompleteMode(false)
o.Refresh() o.refresh()
} else { } else {
o.buf.Refresh(nil) o.buf.Refresh(nil)
o.CompleteRefresh() o.CompleteRefresh()
@ -367,11 +383,11 @@ func (o *Operation) ioloop() {
} }
func (o *Operation) Stderr() io.Writer { func (o *Operation) Stderr() io.Writer {
return &wrapWriter{target: o.GetConfig().Stderr, r: o, t: o.t} return &wrapWriter{target: o.GetConfig().Stderr, o: o}
} }
func (o *Operation) Stdout() io.Writer { func (o *Operation) Stdout() io.Writer {
return &wrapWriter{target: o.GetConfig().Stdout, r: o, t: o.t} return &wrapWriter{target: o.GetConfig().Stdout, o: o}
} }
func (o *Operation) String() (string, error) { func (o *Operation) String() (string, error) {
@ -388,6 +404,11 @@ func (o *Operation) Runes() ([]rune, error) {
listener.OnChange(nil, 0, 0) listener.OnChange(nil, 0, 0)
} }
// Before writing the prompt and starting to read, get a lock
// so we don't race with wrapWriter trying to write and refresh.
o.m.Lock()
o.isPrompting = true
// Query cursor position before printing the prompt as there // Query cursor position before printing the prompt as there
// maybe existing text on the same line that ideally we don't // maybe existing text on the same line that ideally we don't
// want to overwrite and cause prompt to jump left. Note that // want to overwrite and cause prompt to jump left. Note that
@ -396,6 +417,16 @@ func (o *Operation) Runes() ([]rune, error) {
o.buf.Print() // print prompt & buffer contents o.buf.Print() // print prompt & buffer contents
o.t.KickRead() o.t.KickRead()
// Prompt written safely, unlock until read completes and then
// lock again to unset.
o.m.Unlock()
defer func() {
o.m.Lock()
o.isPrompting = false
o.buf.SetOffset("1;1")
o.m.Unlock()
}()
select { select {
case r := <-o.outchan: case r := <-o.outchan:
return r, nil return r, nil
@ -508,7 +539,13 @@ func (o *Operation) SaveHistory(content string) error {
} }
func (o *Operation) Refresh() { func (o *Operation) Refresh() {
if o.t.IsReading() { o.m.Lock()
defer o.m.Unlock()
o.refresh()
}
func (o *Operation) refresh() {
if o.isPrompting {
o.buf.Refresh(nil) o.buf.Refresh(nil)
} }
} }

View File

@ -17,7 +17,6 @@ type Terminal struct {
stopChan chan struct{} stopChan chan struct{}
kickChan chan struct{} kickChan chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
isReading int32
sleeping int32 sleeping int32
sizeChan chan string sizeChan chan string
@ -104,10 +103,6 @@ func (t *Terminal) ReadRune() rune {
return ch return ch
} }
func (t *Terminal) IsReading() bool {
return atomic.LoadInt32(&t.isReading) == 1
}
func (t *Terminal) KickRead() { func (t *Terminal) KickRead() {
select { select {
case t.kickChan <- struct{}{}: case t.kickChan <- struct{}{}:
@ -132,10 +127,8 @@ func (t *Terminal) ioloop() {
buf := bufio.NewReader(t.getStdin()) buf := bufio.NewReader(t.getStdin())
for { for {
if !expectNextChar { if !expectNextChar {
atomic.StoreInt32(&t.isReading, 0)
select { select {
case <-t.kickChan: case <-t.kickChan:
atomic.StoreInt32(&t.isReading, 1)
case <-t.stopChan: case <-t.stopChan:
return return
} }
@ -210,7 +203,6 @@ func (t *Terminal) ioloop() {
t.outchan <- r t.outchan <- r
} }
} }
} }
func (t *Terminal) Bell() { func (t *Terminal) Bell() {

View File

@ -224,7 +224,11 @@ func SplitByLine(prompt, rs []rune, offset, screenWidth, nextWidth int) [][]rune
currentWidth := offset currentWidth := offset
for i, r := range prs { for i, r := range prs {
w := runes.Width(r) w := runes.Width(r)
if currentWidth + w > screenWidth { if r == '\n' {
ret = append(ret, prs[si:i+1])
si = i + 1
currentWidth = 0
} else if currentWidth + w > screenWidth {
ret = append(ret, prs[si:i]) ret = append(ret, prs[si:i])
si = i si = i
currentWidth = 0 currentWidth = 0