diff --git a/cmd/ledis-server/main.go b/cmd/ledis-server/main.go index 31e056f..f1c695b 100644 --- a/cmd/ledis-server/main.go +++ b/cmd/ledis-server/main.go @@ -88,17 +88,15 @@ func main() { syscall.SIGTERM, syscall.SIGQUIT) - go func() { - <-sc - - app.Close() - }() - if *usePprof { go func() { log.Println(http.ListenAndServe(fmt.Sprintf(":%d", *pprofPort), nil)) }() } - app.Run() + go app.Run() + + <-sc + + app.Close() } diff --git a/ledis/replication.go b/ledis/replication.go index 593388d..4f02259 100644 --- a/ledis/replication.go +++ b/ledis/replication.go @@ -202,7 +202,7 @@ func (l *Ledis) ReadLogsTo(startLogID uint64, w io.Writer) (n int, nextLogID uin } // try to read events, if no events read, try to wait the new event singal until timeout seconds -func (l *Ledis) ReadLogsToTimeout(startLogID uint64, w io.Writer, timeout int) (n int, nextLogID uint64, err error) { +func (l *Ledis) ReadLogsToTimeout(startLogID uint64, w io.Writer, timeout int, quitCh chan struct{}) (n int, nextLogID uint64, err error) { n, nextLogID, err = l.ReadLogsTo(startLogID, w) if err != nil { return @@ -213,6 +213,8 @@ func (l *Ledis) ReadLogsToTimeout(startLogID uint64, w io.Writer, timeout int) ( select { case <-l.r.WaitLog(): case <-time.After(time.Duration(timeout) * time.Second): + case <-quitCh: + return } return l.ReadLogsTo(startLogID, w) } diff --git a/server/app.go b/server/app.go index a021865..393da09 100644 --- a/server/app.go +++ b/server/app.go @@ -37,6 +37,11 @@ type App struct { slaveSyncAck chan uint64 snap *snapshotStore + + connWait sync.WaitGroup + + rcm sync.Mutex + rcs map[*respClient]struct{} } func netType(s string) string { @@ -64,6 +69,8 @@ func NewApp(cfg *config.Config) (*App, error) { app.slaves = make(map[string]*client) app.slaveSyncAck = make(chan uint64) + app.rcs = make(map[*respClient]struct{}) + var err error if app.info, err = newInfo(app); err != nil { @@ -129,6 +136,11 @@ func (app *App) Close() { app.httpListener.Close() } + app.closeAllRespClients() + + //wait all connection closed + app.connWait.Wait() + app.closeScript() app.m.Close() diff --git a/server/client_http.go b/server/client_http.go index d039533..4383673 100644 --- a/server/client_http.go +++ b/server/client_http.go @@ -40,15 +40,18 @@ type httpWriter struct { } func newClientHTTP(app *App, w http.ResponseWriter, r *http.Request) { + app.connWait.Add(1) + defer app.connWait.Done() + var err error c := new(httpClient) - c.client = newClient(app) err = c.makeRequest(app, r, w) if err != nil { w.Write([]byte(err.Error())) return } + c.client = newClient(app) c.perform() c.client.close() } diff --git a/server/client_resp.go b/server/client_resp.go index b538f1a..65a83b4 100644 --- a/server/client_resp.go +++ b/server/client_resp.go @@ -16,6 +16,7 @@ import ( ) var errReadRequest = errors.New("invalid request protocol") +var errClientQuit = errors.New("remote client quit") type respClient struct { *client @@ -24,18 +25,51 @@ type respClient struct { rb *bufio.Reader ar *arena.Arena + + activeQuit bool } type respWriter struct { buff *bufio.Writer } +func (app *App) addRespClient(c *respClient) { + app.rcm.Lock() + app.rcs[c] = struct{}{} + app.rcm.Unlock() +} + +func (app *App) delRespClient(c *respClient) { + app.rcm.Lock() + delete(app.rcs, c) + app.rcm.Unlock() +} + +func (app *App) closeAllRespClients() { + app.rcm.Lock() + + for c := range app.rcs { + c.conn.Close() + } + + app.rcm.Unlock() +} + +func (app *App) respClientNum() int { + app.rcm.Lock() + n := len(app.rcs) + app.rcm.Unlock() + return n +} + func newClientRESP(conn net.Conn, app *App) { c := new(respClient) c.client = newClient(app) c.conn = conn + c.activeQuit = false + if tcpConn, ok := conn.(*net.TCPConn); ok { tcpConn.SetReadBuffer(app.cfg.ConnReadBufferSize) tcpConn.SetWriteBuffer(app.cfg.ConnWriteBufferSize) @@ -49,17 +83,15 @@ func newClientRESP(conn net.Conn, app *App) { //maybe another config? c.ar = arena.NewArena(app.cfg.ConnReadBufferSize) + app.connWait.Add(1) + + app.addRespClient(c) + go c.run() } func (c *respClient) run() { - c.app.info.addClients(1) - defer func() { - c.client.close() - - c.app.info.addClients(-1) - if e := recover(); e != nil { buf := make([]byte, 4096) n := runtime.Stack(buf, false) @@ -68,21 +100,30 @@ func (c *respClient) run() { log.Fatal("client run panic %s:%v", buf, e) } - handleQuit := true - if c.conn != nil { - //if handle quit command before, conn is nil - handleQuit = false - c.conn.Close() - } + c.client.close() + + c.conn.Close() if c.tx != nil { c.tx.Rollback() c.tx = nil } - c.app.removeSlave(c.client, handleQuit) + c.app.removeSlave(c.client, c.activeQuit) + + c.app.delRespClient(c) + + c.app.connWait.Done() }() + select { + case <-c.app.quit: + //check app closed + return + default: + break + } + kc := time.Duration(c.app.cfg.ConnKeepaliveInterval) * time.Second for { if kc > 0 { @@ -91,16 +132,12 @@ func (c *respClient) run() { reqData, err := c.readRequest() if err == nil { - c.handleRequest(reqData) + err = c.handleRequest(reqData) } if err != nil { return } - - if c.conn == nil { - return - } } } @@ -108,7 +145,7 @@ func (c *respClient) readRequest() ([][]byte, error) { return ReadRequest(c.rb, c.ar) } -func (c *respClient) handleRequest(reqData [][]byte) { +func (c *respClient) handleRequest(reqData [][]byte) error { if len(reqData) == 0 { c.cmd = "" c.args = reqData[0:0] @@ -117,11 +154,11 @@ func (c *respClient) handleRequest(reqData [][]byte) { c.args = reqData[1:] } if c.cmd == "quit" { + c.activeQuit = true c.resp.writeStatus(OK) c.resp.flush() c.conn.Close() - c.conn = nil - return + return errClientQuit } c.perform() @@ -131,7 +168,7 @@ func (c *respClient) handleRequest(reqData [][]byte) { c.ar.Reset() - return + return nil } // response writer diff --git a/server/cmd_replication.go b/server/cmd_replication.go index bc26968..b910e51 100644 --- a/server/cmd_replication.go +++ b/server/cmd_replication.go @@ -131,7 +131,7 @@ func syncCommand(c *client) error { c.syncBuf.Write(dummyBuf) - if _, _, err := c.app.ldb.ReadLogsToTimeout(logId, &c.syncBuf, 30); err != nil { + if _, _, err := c.app.ldb.ReadLogsToTimeout(logId, &c.syncBuf, 30, c.app.quit); err != nil { return err } else { buf := c.syncBuf.Bytes() diff --git a/server/info.go b/server/info.go index 0e807c1..b06b084 100644 --- a/server/info.go +++ b/server/info.go @@ -9,7 +9,6 @@ import ( "runtime/debug" "strings" "sync" - "sync/atomic" "time" ) @@ -23,10 +22,6 @@ type info struct { ProceessId int } - Clients struct { - ConnectedClients int64 - } - Replication struct { PubLogNum sync2.AtomicInt64 PubLogAckNum sync2.AtomicInt64 @@ -47,10 +42,6 @@ func newInfo(app *App) (i *info, err error) { return i, nil } -func (i *info) addClients(delta int64) { - atomic.AddInt64(&i.Clients.ConnectedClients, delta) -} - func (i *info) Close() { } @@ -116,7 +107,7 @@ func (i *info) dumpServer(buf *bytes.Buffer) { infoPair{"readonly", i.app.cfg.Readonly}, infoPair{"goroutine_num", runtime.NumGoroutine()}, infoPair{"cgo_call_num", runtime.NumCgoCall()}, - infoPair{"client_num", i.Clients.ConnectedClients}, + infoPair{"resp_client_num", i.app.respClientNum()}, ) }