mirror of https://github.com/tidwall/redcon.git
added context
This commit is contained in:
parent
bef3c6ddbd
commit
a0e14aa66e
51
redcon.go
51
redcon.go
|
@ -3,6 +3,7 @@ package redcon
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -23,6 +24,7 @@ var (
|
|||
errDetached = errors.New("detached")
|
||||
errIncompleteCommand = errors.New("incomplete command")
|
||||
errTooMuchData = errors.New("too much data")
|
||||
errContextDone = errors.New("context done")
|
||||
)
|
||||
|
||||
type errProtocol struct {
|
||||
|
@ -114,27 +116,31 @@ type Conn interface {
|
|||
}
|
||||
|
||||
// NewServer returns a new Redcon server configured on "tcp" network net.
|
||||
func NewServer(addr string,
|
||||
func NewServer(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
handler func(conn Conn, cmd Command),
|
||||
accept func(conn Conn) bool,
|
||||
closed func(conn Conn, err error),
|
||||
) *Server {
|
||||
return NewServerNetwork("tcp", addr, handler, accept, closed)
|
||||
return NewServerNetwork(ctx, "tcp", addr, handler, accept, closed)
|
||||
}
|
||||
|
||||
// NewServerTLS returns a new Redcon TLS server configured on "tcp" network net.
|
||||
func NewServerTLS(addr string,
|
||||
func NewServerTLS(ctx context.Context,
|
||||
addr string,
|
||||
handler func(conn Conn, cmd Command),
|
||||
accept func(conn Conn) bool,
|
||||
closed func(conn Conn, err error),
|
||||
config *tls.Config,
|
||||
) *TLSServer {
|
||||
return NewServerNetworkTLS("tcp", addr, handler, accept, closed, config)
|
||||
return NewServerNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
|
||||
}
|
||||
|
||||
// NewServerNetwork returns a new Redcon server. The network net must be
|
||||
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
||||
func NewServerNetwork(
|
||||
ctx context.Context,
|
||||
net, laddr string,
|
||||
handler func(conn Conn, cmd Command),
|
||||
accept func(conn Conn) bool,
|
||||
|
@ -144,6 +150,7 @@ func NewServerNetwork(
|
|||
panic("handler is nil")
|
||||
}
|
||||
s := &Server{
|
||||
ctx: ctx,
|
||||
net: net,
|
||||
laddr: laddr,
|
||||
handler: handler,
|
||||
|
@ -157,6 +164,7 @@ func NewServerNetwork(
|
|||
// NewServerNetworkTLS returns a new TLS Redcon server. The network net must be
|
||||
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
||||
func NewServerNetworkTLS(
|
||||
ctx context.Context,
|
||||
net, laddr string,
|
||||
handler func(conn Conn, cmd Command),
|
||||
accept func(conn Conn) bool,
|
||||
|
@ -167,6 +175,7 @@ func NewServerNetworkTLS(
|
|||
panic("handler is nil")
|
||||
}
|
||||
s := Server{
|
||||
ctx: ctx,
|
||||
net: net,
|
||||
laddr: laddr,
|
||||
handler: handler,
|
||||
|
@ -241,50 +250,58 @@ func Serve(ln net.Listener,
|
|||
}
|
||||
|
||||
// ListenAndServe creates a new server and binds to addr configured on "tcp" network net.
|
||||
func ListenAndServe(addr string,
|
||||
func ListenAndServe(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
handler func(conn Conn, cmd Command),
|
||||
accept func(conn Conn) bool,
|
||||
closed func(conn Conn, err error),
|
||||
) error {
|
||||
return ListenAndServeNetwork("tcp", addr, handler, accept, closed)
|
||||
return ListenAndServeNetwork(ctx, "tcp", addr, handler, accept, closed)
|
||||
}
|
||||
|
||||
// ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net.
|
||||
func ListenAndServeTLS(addr string,
|
||||
func ListenAndServeTLS(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
handler func(conn Conn, cmd Command),
|
||||
accept func(conn Conn) bool,
|
||||
closed func(conn Conn, err error),
|
||||
config *tls.Config,
|
||||
) error {
|
||||
return ListenAndServeNetworkTLS("tcp", addr, handler, accept, closed, config)
|
||||
return ListenAndServeNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
|
||||
}
|
||||
|
||||
// ListenAndServeNetwork creates a new server and binds to addr. The network net must be
|
||||
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
||||
func ListenAndServeNetwork(
|
||||
ctx context.Context,
|
||||
net, laddr string,
|
||||
handler func(conn Conn, cmd Command),
|
||||
accept func(conn Conn) bool,
|
||||
closed func(conn Conn, err error),
|
||||
) error {
|
||||
return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe()
|
||||
return NewServerNetwork(ctx, net, laddr, handler, accept, closed).ListenAndServe()
|
||||
}
|
||||
|
||||
// ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be
|
||||
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
||||
func ListenAndServeNetworkTLS(
|
||||
ctx context.Context,
|
||||
net, laddr string,
|
||||
handler func(conn Conn, cmd Command),
|
||||
accept func(conn Conn) bool,
|
||||
closed func(conn Conn, err error),
|
||||
config *tls.Config,
|
||||
) error {
|
||||
return NewServerNetworkTLS(net, laddr, handler, accept, closed, config).ListenAndServe()
|
||||
return NewServerNetworkTLS(ctx, net, laddr, handler, accept, closed, config).ListenAndServe()
|
||||
}
|
||||
|
||||
// ListenServeAndSignal serves incoming connections and passes nil or error
|
||||
// when listening. signal can be nil.
|
||||
func (s *Server) ListenServeAndSignal(signal chan error) error {
|
||||
//var lc net.ListenConfig
|
||||
//ln, err := lc.Listen(s.ctx, s.net, s.laddr)
|
||||
ln, err := net.Listen(s.net, s.laddr)
|
||||
if err != nil {
|
||||
if signal != nil {
|
||||
|
@ -336,6 +353,14 @@ func serve(s *Server) error {
|
|||
s.conns = nil
|
||||
}()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
s.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
lnconn, err := s.ln.Accept()
|
||||
if err != nil {
|
||||
|
@ -343,6 +368,11 @@ func serve(s *Server) error {
|
|||
done := s.done
|
||||
s.mu.Unlock()
|
||||
if done {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return errContextDone
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if s.AcceptError != nil {
|
||||
|
@ -547,6 +577,7 @@ type Command struct {
|
|||
|
||||
// Server defines a server for clients for managing client connections.
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
mu sync.Mutex
|
||||
net string
|
||||
laddr string
|
||||
|
|
|
@ -3,6 +3,7 @@ package redcon
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
@ -220,7 +221,8 @@ func TestServerUnix(t *testing.T) {
|
|||
}
|
||||
|
||||
func testServerNetwork(t *testing.T, network, laddr string) {
|
||||
s := NewServerNetwork(network, laddr,
|
||||
ctx := context.Background()
|
||||
s := NewServerNetwork(ctx, network, laddr,
|
||||
func(conn Conn, cmd Command) {
|
||||
switch strings.ToLower(string(cmd.Args[0])) {
|
||||
default:
|
||||
|
@ -261,7 +263,7 @@ func testServerNetwork(t *testing.T, network, laddr string) {
|
|||
}
|
||||
go func() {
|
||||
time.Sleep(time.Second / 4)
|
||||
if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
|
||||
if err := ListenAndServeNetwork(ctx, network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
|
||||
panic("expected an error, should not be able to listen on the same port")
|
||||
}
|
||||
time.Sleep(time.Second / 4)
|
||||
|
@ -560,6 +562,7 @@ func TestParse(t *testing.T) {
|
|||
func TestPubSub(t *testing.T) {
|
||||
addr := ":12346"
|
||||
done := make(chan bool)
|
||||
ctx := context.Background()
|
||||
go func() {
|
||||
var ps PubSub
|
||||
go func() {
|
||||
|
@ -593,7 +596,7 @@ func TestPubSub(t *testing.T) {
|
|||
ps.Publish(channel, message)
|
||||
}
|
||||
}()
|
||||
panic(ListenAndServe(addr, func(conn Conn, cmd Command) {
|
||||
panic(ListenAndServe(ctx, addr, func(conn Conn, cmd Command) {
|
||||
switch strings.ToLower(string(cmd.Args[0])) {
|
||||
default:
|
||||
conn.WriteError("ERR unknown command '" +
|
||||
|
@ -738,3 +741,21 @@ func TestPubSub(t *testing.T) {
|
|||
// stop the timeout
|
||||
final <- true
|
||||
}
|
||||
|
||||
func TestContextDone(t *testing.T) {
|
||||
var err error
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
s := NewServerNetwork(ctx, "tcp", ":12345", func(conn Conn, cmd Command) {}, nil, nil)
|
||||
go func() {
|
||||
err = s.ListenAndServe()
|
||||
wg.Done()
|
||||
}()
|
||||
time.Sleep(1 * time.Second)
|
||||
cancel()
|
||||
wg.Wait()
|
||||
if err != errContextDone {
|
||||
t.Fatalf("expected %v but found %v", errContextDone, err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue