added context

This commit is contained in:
Nathan Hack 2021-11-19 14:25:45 -05:00
parent bef3c6ddbd
commit a0e14aa66e
2 changed files with 65 additions and 13 deletions

View File

@ -3,6 +3,7 @@ package redcon
import ( import (
"bufio" "bufio"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -23,6 +24,7 @@ var (
errDetached = errors.New("detached") errDetached = errors.New("detached")
errIncompleteCommand = errors.New("incomplete command") errIncompleteCommand = errors.New("incomplete command")
errTooMuchData = errors.New("too much data") errTooMuchData = errors.New("too much data")
errContextDone = errors.New("context done")
) )
type errProtocol struct { type errProtocol struct {
@ -114,27 +116,31 @@ type Conn interface {
} }
// NewServer returns a new Redcon server configured on "tcp" network net. // NewServer returns a new Redcon server configured on "tcp" network net.
func NewServer(addr string, func NewServer(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
closed func(conn Conn, err error), closed func(conn Conn, err error),
) *Server { ) *Server {
return NewServerNetwork("tcp", addr, handler, accept, closed) return NewServerNetwork(ctx, "tcp", addr, handler, accept, closed)
} }
// NewServerTLS returns a new Redcon TLS server configured on "tcp" network net. // NewServerTLS returns a new Redcon TLS server configured on "tcp" network net.
func NewServerTLS(addr string, func NewServerTLS(ctx context.Context,
addr string,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
closed func(conn Conn, err error), closed func(conn Conn, err error),
config *tls.Config, config *tls.Config,
) *TLSServer { ) *TLSServer {
return NewServerNetworkTLS("tcp", addr, handler, accept, closed, config) return NewServerNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
} }
// NewServerNetwork returns a new Redcon server. The network net must be // NewServerNetwork returns a new Redcon server. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func NewServerNetwork( func NewServerNetwork(
ctx context.Context,
net, laddr string, net, laddr string,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
@ -144,6 +150,7 @@ func NewServerNetwork(
panic("handler is nil") panic("handler is nil")
} }
s := &Server{ s := &Server{
ctx: ctx,
net: net, net: net,
laddr: laddr, laddr: laddr,
handler: handler, handler: handler,
@ -157,6 +164,7 @@ func NewServerNetwork(
// NewServerNetworkTLS returns a new TLS Redcon server. The network net must be // NewServerNetworkTLS returns a new TLS Redcon server. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func NewServerNetworkTLS( func NewServerNetworkTLS(
ctx context.Context,
net, laddr string, net, laddr string,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
@ -167,6 +175,7 @@ func NewServerNetworkTLS(
panic("handler is nil") panic("handler is nil")
} }
s := Server{ s := Server{
ctx: ctx,
net: net, net: net,
laddr: laddr, laddr: laddr,
handler: handler, handler: handler,
@ -241,50 +250,58 @@ func Serve(ln net.Listener,
} }
// ListenAndServe creates a new server and binds to addr configured on "tcp" network net. // ListenAndServe creates a new server and binds to addr configured on "tcp" network net.
func ListenAndServe(addr string, func ListenAndServe(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
closed func(conn Conn, err error), closed func(conn Conn, err error),
) error { ) error {
return ListenAndServeNetwork("tcp", addr, handler, accept, closed) return ListenAndServeNetwork(ctx, "tcp", addr, handler, accept, closed)
} }
// ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net. // ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net.
func ListenAndServeTLS(addr string, func ListenAndServeTLS(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
closed func(conn Conn, err error), closed func(conn Conn, err error),
config *tls.Config, config *tls.Config,
) error { ) error {
return ListenAndServeNetworkTLS("tcp", addr, handler, accept, closed, config) return ListenAndServeNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
} }
// ListenAndServeNetwork creates a new server and binds to addr. The network net must be // ListenAndServeNetwork creates a new server and binds to addr. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func ListenAndServeNetwork( func ListenAndServeNetwork(
ctx context.Context,
net, laddr string, net, laddr string,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
closed func(conn Conn, err error), closed func(conn Conn, err error),
) error { ) error {
return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe() return NewServerNetwork(ctx, net, laddr, handler, accept, closed).ListenAndServe()
} }
// ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be // ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func ListenAndServeNetworkTLS( func ListenAndServeNetworkTLS(
ctx context.Context,
net, laddr string, net, laddr string,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
closed func(conn Conn, err error), closed func(conn Conn, err error),
config *tls.Config, config *tls.Config,
) error { ) error {
return NewServerNetworkTLS(net, laddr, handler, accept, closed, config).ListenAndServe() return NewServerNetworkTLS(ctx, net, laddr, handler, accept, closed, config).ListenAndServe()
} }
// ListenServeAndSignal serves incoming connections and passes nil or error // ListenServeAndSignal serves incoming connections and passes nil or error
// when listening. signal can be nil. // when listening. signal can be nil.
func (s *Server) ListenServeAndSignal(signal chan error) error { func (s *Server) ListenServeAndSignal(signal chan error) error {
//var lc net.ListenConfig
//ln, err := lc.Listen(s.ctx, s.net, s.laddr)
ln, err := net.Listen(s.net, s.laddr) ln, err := net.Listen(s.net, s.laddr)
if err != nil { if err != nil {
if signal != nil { if signal != nil {
@ -336,6 +353,14 @@ func serve(s *Server) error {
s.conns = nil s.conns = nil
}() }()
}() }()
go func() {
select {
case <-s.ctx.Done():
s.Close()
}
}()
for { for {
lnconn, err := s.ln.Accept() lnconn, err := s.ln.Accept()
if err != nil { if err != nil {
@ -343,6 +368,11 @@ func serve(s *Server) error {
done := s.done done := s.done
s.mu.Unlock() s.mu.Unlock()
if done { if done {
select {
case <-s.ctx.Done():
return errContextDone
default:
}
return nil return nil
} }
if s.AcceptError != nil { if s.AcceptError != nil {
@ -547,6 +577,7 @@ type Command struct {
// Server defines a server for clients for managing client connections. // Server defines a server for clients for managing client connections.
type Server struct { type Server struct {
ctx context.Context
mu sync.Mutex mu sync.Mutex
net string net string
laddr string laddr string

View File

@ -3,6 +3,7 @@ package redcon
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -220,7 +221,8 @@ func TestServerUnix(t *testing.T) {
} }
func testServerNetwork(t *testing.T, network, laddr string) { func testServerNetwork(t *testing.T, network, laddr string) {
s := NewServerNetwork(network, laddr, ctx := context.Background()
s := NewServerNetwork(ctx, network, laddr,
func(conn Conn, cmd Command) { func(conn Conn, cmd Command) {
switch strings.ToLower(string(cmd.Args[0])) { switch strings.ToLower(string(cmd.Args[0])) {
default: default:
@ -261,7 +263,7 @@ func testServerNetwork(t *testing.T, network, laddr string) {
} }
go func() { go func() {
time.Sleep(time.Second / 4) time.Sleep(time.Second / 4)
if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil { if err := ListenAndServeNetwork(ctx, network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
panic("expected an error, should not be able to listen on the same port") panic("expected an error, should not be able to listen on the same port")
} }
time.Sleep(time.Second / 4) time.Sleep(time.Second / 4)
@ -560,6 +562,7 @@ func TestParse(t *testing.T) {
func TestPubSub(t *testing.T) { func TestPubSub(t *testing.T) {
addr := ":12346" addr := ":12346"
done := make(chan bool) done := make(chan bool)
ctx := context.Background()
go func() { go func() {
var ps PubSub var ps PubSub
go func() { go func() {
@ -593,7 +596,7 @@ func TestPubSub(t *testing.T) {
ps.Publish(channel, message) ps.Publish(channel, message)
} }
}() }()
panic(ListenAndServe(addr, func(conn Conn, cmd Command) { panic(ListenAndServe(ctx, addr, func(conn Conn, cmd Command) {
switch strings.ToLower(string(cmd.Args[0])) { switch strings.ToLower(string(cmd.Args[0])) {
default: default:
conn.WriteError("ERR unknown command '" + conn.WriteError("ERR unknown command '" +
@ -738,3 +741,21 @@ func TestPubSub(t *testing.T) {
// stop the timeout // stop the timeout
final <- true final <- true
} }
func TestContextDone(t *testing.T) {
var err error
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(1)
s := NewServerNetwork(ctx, "tcp", ":12345", func(conn Conn, cmd Command) {}, nil, nil)
go func() {
err = s.ListenAndServe()
wg.Done()
}()
time.Sleep(1 * time.Second)
cancel()
wg.Wait()
if err != errContextDone {
t.Fatalf("expected %v but found %v", errContextDone, err)
}
}