From a0e14aa66e12ff5b5746b074e9ac7c2d0af1ee72 Mon Sep 17 00:00:00 2001 From: Nathan Hack Date: Fri, 19 Nov 2021 14:25:45 -0500 Subject: [PATCH] added context --- redcon.go | 51 ++++++++++++++++++++++++++++++++++++++++---------- redcon_test.go | 27 +++++++++++++++++++++++--- 2 files changed, 65 insertions(+), 13 deletions(-) diff --git a/redcon.go b/redcon.go index 46b7d65..80614ee 100644 --- a/redcon.go +++ b/redcon.go @@ -3,6 +3,7 @@ package redcon import ( "bufio" + "context" "crypto/tls" "errors" "fmt" @@ -23,6 +24,7 @@ var ( errDetached = errors.New("detached") errIncompleteCommand = errors.New("incomplete command") errTooMuchData = errors.New("too much data") + errContextDone = errors.New("context done") ) type errProtocol struct { @@ -114,27 +116,31 @@ type Conn interface { } // NewServer returns a new Redcon server configured on "tcp" network net. -func NewServer(addr string, +func NewServer( + ctx context.Context, + addr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, closed func(conn Conn, err error), ) *Server { - return NewServerNetwork("tcp", addr, handler, accept, closed) + return NewServerNetwork(ctx, "tcp", addr, handler, accept, closed) } // NewServerTLS returns a new Redcon TLS server configured on "tcp" network net. -func NewServerTLS(addr string, +func NewServerTLS(ctx context.Context, + addr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, closed func(conn Conn, err error), config *tls.Config, ) *TLSServer { - return NewServerNetworkTLS("tcp", addr, handler, accept, closed, config) + return NewServerNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config) } // NewServerNetwork returns a new Redcon server. The network net must be // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" func NewServerNetwork( + ctx context.Context, net, laddr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, @@ -144,6 +150,7 @@ func NewServerNetwork( panic("handler is nil") } s := &Server{ + ctx: ctx, net: net, laddr: laddr, handler: handler, @@ -157,6 +164,7 @@ func NewServerNetwork( // NewServerNetworkTLS returns a new TLS Redcon server. The network net must be // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" func NewServerNetworkTLS( + ctx context.Context, net, laddr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, @@ -167,6 +175,7 @@ func NewServerNetworkTLS( panic("handler is nil") } s := Server{ + ctx: ctx, net: net, laddr: laddr, handler: handler, @@ -241,50 +250,58 @@ func Serve(ln net.Listener, } // ListenAndServe creates a new server and binds to addr configured on "tcp" network net. -func ListenAndServe(addr string, +func ListenAndServe( + ctx context.Context, + addr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, closed func(conn Conn, err error), ) error { - return ListenAndServeNetwork("tcp", addr, handler, accept, closed) + return ListenAndServeNetwork(ctx, "tcp", addr, handler, accept, closed) } // ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net. -func ListenAndServeTLS(addr string, +func ListenAndServeTLS( + ctx context.Context, + addr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, closed func(conn Conn, err error), config *tls.Config, ) error { - return ListenAndServeNetworkTLS("tcp", addr, handler, accept, closed, config) + return ListenAndServeNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config) } // ListenAndServeNetwork creates a new server and binds to addr. The network net must be // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" func ListenAndServeNetwork( + ctx context.Context, net, laddr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, closed func(conn Conn, err error), ) error { - return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe() + return NewServerNetwork(ctx, net, laddr, handler, accept, closed).ListenAndServe() } // ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" func ListenAndServeNetworkTLS( + ctx context.Context, net, laddr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, closed func(conn Conn, err error), config *tls.Config, ) error { - return NewServerNetworkTLS(net, laddr, handler, accept, closed, config).ListenAndServe() + return NewServerNetworkTLS(ctx, net, laddr, handler, accept, closed, config).ListenAndServe() } // ListenServeAndSignal serves incoming connections and passes nil or error // when listening. signal can be nil. func (s *Server) ListenServeAndSignal(signal chan error) error { + //var lc net.ListenConfig + //ln, err := lc.Listen(s.ctx, s.net, s.laddr) ln, err := net.Listen(s.net, s.laddr) if err != nil { if signal != nil { @@ -336,6 +353,14 @@ func serve(s *Server) error { s.conns = nil }() }() + + go func() { + select { + case <-s.ctx.Done(): + s.Close() + } + }() + for { lnconn, err := s.ln.Accept() if err != nil { @@ -343,6 +368,11 @@ func serve(s *Server) error { done := s.done s.mu.Unlock() if done { + select { + case <-s.ctx.Done(): + return errContextDone + default: + } return nil } if s.AcceptError != nil { @@ -547,6 +577,7 @@ type Command struct { // Server defines a server for clients for managing client connections. type Server struct { + ctx context.Context mu sync.Mutex net string laddr string diff --git a/redcon_test.go b/redcon_test.go index 757320e..33cdece 100644 --- a/redcon_test.go +++ b/redcon_test.go @@ -3,6 +3,7 @@ package redcon import ( "bufio" "bytes" + "context" "fmt" "io" "log" @@ -220,7 +221,8 @@ func TestServerUnix(t *testing.T) { } func testServerNetwork(t *testing.T, network, laddr string) { - s := NewServerNetwork(network, laddr, + ctx := context.Background() + s := NewServerNetwork(ctx, network, laddr, func(conn Conn, cmd Command) { switch strings.ToLower(string(cmd.Args[0])) { default: @@ -261,7 +263,7 @@ func testServerNetwork(t *testing.T, network, laddr string) { } go func() { time.Sleep(time.Second / 4) - if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil { + if err := ListenAndServeNetwork(ctx, network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil { panic("expected an error, should not be able to listen on the same port") } time.Sleep(time.Second / 4) @@ -560,6 +562,7 @@ func TestParse(t *testing.T) { func TestPubSub(t *testing.T) { addr := ":12346" done := make(chan bool) + ctx := context.Background() go func() { var ps PubSub go func() { @@ -593,7 +596,7 @@ func TestPubSub(t *testing.T) { ps.Publish(channel, message) } }() - panic(ListenAndServe(addr, func(conn Conn, cmd Command) { + panic(ListenAndServe(ctx, addr, func(conn Conn, cmd Command) { switch strings.ToLower(string(cmd.Args[0])) { default: conn.WriteError("ERR unknown command '" + @@ -738,3 +741,21 @@ func TestPubSub(t *testing.T) { // stop the timeout final <- true } + +func TestContextDone(t *testing.T) { + var err error + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + s := NewServerNetwork(ctx, "tcp", ":12345", func(conn Conn, cmd Command) {}, nil, nil) + go func() { + err = s.ListenAndServe() + wg.Done() + }() + time.Sleep(1 * time.Second) + cancel() + wg.Wait() + if err != errContextDone { + t.Fatalf("expected %v but found %v", errContextDone, err) + } +}