added context

This commit is contained in:
Nathan Hack 2021-11-19 14:25:45 -05:00
parent bef3c6ddbd
commit a0e14aa66e
2 changed files with 65 additions and 13 deletions

View File

@ -3,6 +3,7 @@ package redcon
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
@ -23,6 +24,7 @@ var (
errDetached = errors.New("detached")
errIncompleteCommand = errors.New("incomplete command")
errTooMuchData = errors.New("too much data")
errContextDone = errors.New("context done")
)
type errProtocol struct {
@ -114,27 +116,31 @@ type Conn interface {
}
// NewServer returns a new Redcon server configured on "tcp" network net.
func NewServer(addr string,
func NewServer(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) *Server {
return NewServerNetwork("tcp", addr, handler, accept, closed)
return NewServerNetwork(ctx, "tcp", addr, handler, accept, closed)
}
// NewServerTLS returns a new Redcon TLS server configured on "tcp" network net.
func NewServerTLS(addr string,
func NewServerTLS(ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
config *tls.Config,
) *TLSServer {
return NewServerNetworkTLS("tcp", addr, handler, accept, closed, config)
return NewServerNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
}
// NewServerNetwork returns a new Redcon server. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func NewServerNetwork(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
@ -144,6 +150,7 @@ func NewServerNetwork(
panic("handler is nil")
}
s := &Server{
ctx: ctx,
net: net,
laddr: laddr,
handler: handler,
@ -157,6 +164,7 @@ func NewServerNetwork(
// NewServerNetworkTLS returns a new TLS Redcon server. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func NewServerNetworkTLS(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
@ -167,6 +175,7 @@ func NewServerNetworkTLS(
panic("handler is nil")
}
s := Server{
ctx: ctx,
net: net,
laddr: laddr,
handler: handler,
@ -241,50 +250,58 @@ func Serve(ln net.Listener,
}
// ListenAndServe creates a new server and binds to addr configured on "tcp" network net.
func ListenAndServe(addr string,
func ListenAndServe(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) error {
return ListenAndServeNetwork("tcp", addr, handler, accept, closed)
return ListenAndServeNetwork(ctx, "tcp", addr, handler, accept, closed)
}
// ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net.
func ListenAndServeTLS(addr string,
func ListenAndServeTLS(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
config *tls.Config,
) error {
return ListenAndServeNetworkTLS("tcp", addr, handler, accept, closed, config)
return ListenAndServeNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
}
// ListenAndServeNetwork creates a new server and binds to addr. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func ListenAndServeNetwork(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) error {
return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe()
return NewServerNetwork(ctx, net, laddr, handler, accept, closed).ListenAndServe()
}
// ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func ListenAndServeNetworkTLS(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
config *tls.Config,
) error {
return NewServerNetworkTLS(net, laddr, handler, accept, closed, config).ListenAndServe()
return NewServerNetworkTLS(ctx, net, laddr, handler, accept, closed, config).ListenAndServe()
}
// ListenServeAndSignal serves incoming connections and passes nil or error
// when listening. signal can be nil.
func (s *Server) ListenServeAndSignal(signal chan error) error {
//var lc net.ListenConfig
//ln, err := lc.Listen(s.ctx, s.net, s.laddr)
ln, err := net.Listen(s.net, s.laddr)
if err != nil {
if signal != nil {
@ -336,6 +353,14 @@ func serve(s *Server) error {
s.conns = nil
}()
}()
go func() {
select {
case <-s.ctx.Done():
s.Close()
}
}()
for {
lnconn, err := s.ln.Accept()
if err != nil {
@ -343,6 +368,11 @@ func serve(s *Server) error {
done := s.done
s.mu.Unlock()
if done {
select {
case <-s.ctx.Done():
return errContextDone
default:
}
return nil
}
if s.AcceptError != nil {
@ -547,6 +577,7 @@ type Command struct {
// Server defines a server for clients for managing client connections.
type Server struct {
ctx context.Context
mu sync.Mutex
net string
laddr string

View File

@ -3,6 +3,7 @@ package redcon
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"log"
@ -220,7 +221,8 @@ func TestServerUnix(t *testing.T) {
}
func testServerNetwork(t *testing.T, network, laddr string) {
s := NewServerNetwork(network, laddr,
ctx := context.Background()
s := NewServerNetwork(ctx, network, laddr,
func(conn Conn, cmd Command) {
switch strings.ToLower(string(cmd.Args[0])) {
default:
@ -261,7 +263,7 @@ func testServerNetwork(t *testing.T, network, laddr string) {
}
go func() {
time.Sleep(time.Second / 4)
if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
if err := ListenAndServeNetwork(ctx, network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
panic("expected an error, should not be able to listen on the same port")
}
time.Sleep(time.Second / 4)
@ -560,6 +562,7 @@ func TestParse(t *testing.T) {
func TestPubSub(t *testing.T) {
addr := ":12346"
done := make(chan bool)
ctx := context.Background()
go func() {
var ps PubSub
go func() {
@ -593,7 +596,7 @@ func TestPubSub(t *testing.T) {
ps.Publish(channel, message)
}
}()
panic(ListenAndServe(addr, func(conn Conn, cmd Command) {
panic(ListenAndServe(ctx, addr, func(conn Conn, cmd Command) {
switch strings.ToLower(string(cmd.Args[0])) {
default:
conn.WriteError("ERR unknown command '" +
@ -738,3 +741,21 @@ func TestPubSub(t *testing.T) {
// stop the timeout
final <- true
}
func TestContextDone(t *testing.T) {
var err error
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(1)
s := NewServerNetwork(ctx, "tcp", ":12345", func(conn Conn, cmd Command) {}, nil, nil)
go func() {
err = s.ListenAndServe()
wg.Done()
}()
time.Sleep(1 * time.Second)
cancel()
wg.Wait()
if err != errContextDone {
t.Fatalf("expected %v but found %v", errContextDone, err)
}
}