diff --git a/remote.go b/remote.go index d20c037..2d1847c 100644 --- a/remote.go +++ b/remote.go @@ -21,9 +21,11 @@ const ( T_ISTTY_REPORT T_RAW T_ERAW // exit raw + T_EOF ) type RemoteSvr struct { + eof int32 closed int32 width int32 reciveChan chan struct{} @@ -31,6 +33,7 @@ type RemoteSvr struct { conn net.Conn isTerminal bool funcWidthChan func() + stopChan chan struct{} dataBufM sync.Mutex dataBuf bytes.Buffer @@ -59,6 +62,7 @@ func NewRemoteSvr(conn net.Conn) (*RemoteSvr, error) { conn: conn, writeChan: make(chan *writeCtx), reciveChan: make(chan struct{}), + stopChan: make(chan struct{}), } buf := bufio.NewReader(rs.conn) @@ -113,16 +117,35 @@ func (r *RemoteSvr) IsTerminal() bool { return r.isTerminal } +func (r *RemoteSvr) checkEOF() error { + if atomic.LoadInt32(&r.eof) == 1 { + return io.EOF + } + return nil +} + func (r *RemoteSvr) Read(b []byte) (int, error) { r.dataBufM.Lock() n, err := r.dataBuf.Read(b) r.dataBufM.Unlock() + if n == 0 { + if err := r.checkEOF(); err != nil { + return 0, err + } + } + if n == 0 && err == io.EOF { <-r.reciveChan r.dataBufM.Lock() n, err = r.dataBuf.Read(b) r.dataBufM.Unlock() } + if n == 0 { + if err := r.checkEOF(); err != nil { + return 0, err + } + } + return n, err } @@ -151,19 +174,24 @@ func (r *RemoteSvr) ExitRawMode() error { func (r *RemoteSvr) writeLoop() { defer r.Close() +loop: for { - ctx, ok := <-r.writeChan - if !ok { - break + select { + case ctx, ok := <-r.writeChan: + if !ok { + break + } + n, err := ctx.msg.WriteTo(r.conn) + ctx.reply <- &writeReply{n, err} + case <-r.stopChan: + break loop } - n, err := ctx.msg.WriteTo(r.conn) - ctx.reply <- &writeReply{n, err} } } func (r *RemoteSvr) Close() { if atomic.CompareAndSwapInt32(&r.closed, 0, 1) { - close(r.writeChan) + close(r.stopChan) r.conn.Close() } } @@ -176,6 +204,12 @@ func (r *RemoteSvr) readLoop(buf *bufio.Reader) { break } switch m.Type { + case T_EOF: + atomic.StoreInt32(&r.eof, 1) + select { + case r.reciveChan <- struct{}{}: + default: + } case T_DATA: r.dataBufM.Lock() r.dataBuf.Write(m.Data) @@ -350,6 +384,7 @@ func (r *RemoteCli) Serve() error { for { n, _ := io.Copy(r, os.Stdin) if n == 0 { + r.writeMsg(NewMessage(T_EOF, nil)) break } }