mirror of https://github.com/tidwall/redcon.git
added context
This commit is contained in:
parent
bef3c6ddbd
commit
a0e14aa66e
51
redcon.go
51
redcon.go
|
@ -3,6 +3,7 @@ package redcon
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -23,6 +24,7 @@ var (
|
||||||
errDetached = errors.New("detached")
|
errDetached = errors.New("detached")
|
||||||
errIncompleteCommand = errors.New("incomplete command")
|
errIncompleteCommand = errors.New("incomplete command")
|
||||||
errTooMuchData = errors.New("too much data")
|
errTooMuchData = errors.New("too much data")
|
||||||
|
errContextDone = errors.New("context done")
|
||||||
)
|
)
|
||||||
|
|
||||||
type errProtocol struct {
|
type errProtocol struct {
|
||||||
|
@ -114,27 +116,31 @@ type Conn interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer returns a new Redcon server configured on "tcp" network net.
|
// NewServer returns a new Redcon server configured on "tcp" network net.
|
||||||
func NewServer(addr string,
|
func NewServer(
|
||||||
|
ctx context.Context,
|
||||||
|
addr string,
|
||||||
handler func(conn Conn, cmd Command),
|
handler func(conn Conn, cmd Command),
|
||||||
accept func(conn Conn) bool,
|
accept func(conn Conn) bool,
|
||||||
closed func(conn Conn, err error),
|
closed func(conn Conn, err error),
|
||||||
) *Server {
|
) *Server {
|
||||||
return NewServerNetwork("tcp", addr, handler, accept, closed)
|
return NewServerNetwork(ctx, "tcp", addr, handler, accept, closed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServerTLS returns a new Redcon TLS server configured on "tcp" network net.
|
// NewServerTLS returns a new Redcon TLS server configured on "tcp" network net.
|
||||||
func NewServerTLS(addr string,
|
func NewServerTLS(ctx context.Context,
|
||||||
|
addr string,
|
||||||
handler func(conn Conn, cmd Command),
|
handler func(conn Conn, cmd Command),
|
||||||
accept func(conn Conn) bool,
|
accept func(conn Conn) bool,
|
||||||
closed func(conn Conn, err error),
|
closed func(conn Conn, err error),
|
||||||
config *tls.Config,
|
config *tls.Config,
|
||||||
) *TLSServer {
|
) *TLSServer {
|
||||||
return NewServerNetworkTLS("tcp", addr, handler, accept, closed, config)
|
return NewServerNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServerNetwork returns a new Redcon server. The network net must be
|
// NewServerNetwork returns a new Redcon server. The network net must be
|
||||||
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
||||||
func NewServerNetwork(
|
func NewServerNetwork(
|
||||||
|
ctx context.Context,
|
||||||
net, laddr string,
|
net, laddr string,
|
||||||
handler func(conn Conn, cmd Command),
|
handler func(conn Conn, cmd Command),
|
||||||
accept func(conn Conn) bool,
|
accept func(conn Conn) bool,
|
||||||
|
@ -144,6 +150,7 @@ func NewServerNetwork(
|
||||||
panic("handler is nil")
|
panic("handler is nil")
|
||||||
}
|
}
|
||||||
s := &Server{
|
s := &Server{
|
||||||
|
ctx: ctx,
|
||||||
net: net,
|
net: net,
|
||||||
laddr: laddr,
|
laddr: laddr,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
@ -157,6 +164,7 @@ func NewServerNetwork(
|
||||||
// NewServerNetworkTLS returns a new TLS Redcon server. The network net must be
|
// NewServerNetworkTLS returns a new TLS Redcon server. The network net must be
|
||||||
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
||||||
func NewServerNetworkTLS(
|
func NewServerNetworkTLS(
|
||||||
|
ctx context.Context,
|
||||||
net, laddr string,
|
net, laddr string,
|
||||||
handler func(conn Conn, cmd Command),
|
handler func(conn Conn, cmd Command),
|
||||||
accept func(conn Conn) bool,
|
accept func(conn Conn) bool,
|
||||||
|
@ -167,6 +175,7 @@ func NewServerNetworkTLS(
|
||||||
panic("handler is nil")
|
panic("handler is nil")
|
||||||
}
|
}
|
||||||
s := Server{
|
s := Server{
|
||||||
|
ctx: ctx,
|
||||||
net: net,
|
net: net,
|
||||||
laddr: laddr,
|
laddr: laddr,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
@ -241,50 +250,58 @@ func Serve(ln net.Listener,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServe creates a new server and binds to addr configured on "tcp" network net.
|
// ListenAndServe creates a new server and binds to addr configured on "tcp" network net.
|
||||||
func ListenAndServe(addr string,
|
func ListenAndServe(
|
||||||
|
ctx context.Context,
|
||||||
|
addr string,
|
||||||
handler func(conn Conn, cmd Command),
|
handler func(conn Conn, cmd Command),
|
||||||
accept func(conn Conn) bool,
|
accept func(conn Conn) bool,
|
||||||
closed func(conn Conn, err error),
|
closed func(conn Conn, err error),
|
||||||
) error {
|
) error {
|
||||||
return ListenAndServeNetwork("tcp", addr, handler, accept, closed)
|
return ListenAndServeNetwork(ctx, "tcp", addr, handler, accept, closed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net.
|
// ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net.
|
||||||
func ListenAndServeTLS(addr string,
|
func ListenAndServeTLS(
|
||||||
|
ctx context.Context,
|
||||||
|
addr string,
|
||||||
handler func(conn Conn, cmd Command),
|
handler func(conn Conn, cmd Command),
|
||||||
accept func(conn Conn) bool,
|
accept func(conn Conn) bool,
|
||||||
closed func(conn Conn, err error),
|
closed func(conn Conn, err error),
|
||||||
config *tls.Config,
|
config *tls.Config,
|
||||||
) error {
|
) error {
|
||||||
return ListenAndServeNetworkTLS("tcp", addr, handler, accept, closed, config)
|
return ListenAndServeNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServeNetwork creates a new server and binds to addr. The network net must be
|
// ListenAndServeNetwork creates a new server and binds to addr. The network net must be
|
||||||
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
||||||
func ListenAndServeNetwork(
|
func ListenAndServeNetwork(
|
||||||
|
ctx context.Context,
|
||||||
net, laddr string,
|
net, laddr string,
|
||||||
handler func(conn Conn, cmd Command),
|
handler func(conn Conn, cmd Command),
|
||||||
accept func(conn Conn) bool,
|
accept func(conn Conn) bool,
|
||||||
closed func(conn Conn, err error),
|
closed func(conn Conn, err error),
|
||||||
) error {
|
) error {
|
||||||
return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe()
|
return NewServerNetwork(ctx, net, laddr, handler, accept, closed).ListenAndServe()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be
|
// ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be
|
||||||
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
||||||
func ListenAndServeNetworkTLS(
|
func ListenAndServeNetworkTLS(
|
||||||
|
ctx context.Context,
|
||||||
net, laddr string,
|
net, laddr string,
|
||||||
handler func(conn Conn, cmd Command),
|
handler func(conn Conn, cmd Command),
|
||||||
accept func(conn Conn) bool,
|
accept func(conn Conn) bool,
|
||||||
closed func(conn Conn, err error),
|
closed func(conn Conn, err error),
|
||||||
config *tls.Config,
|
config *tls.Config,
|
||||||
) error {
|
) error {
|
||||||
return NewServerNetworkTLS(net, laddr, handler, accept, closed, config).ListenAndServe()
|
return NewServerNetworkTLS(ctx, net, laddr, handler, accept, closed, config).ListenAndServe()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenServeAndSignal serves incoming connections and passes nil or error
|
// ListenServeAndSignal serves incoming connections and passes nil or error
|
||||||
// when listening. signal can be nil.
|
// when listening. signal can be nil.
|
||||||
func (s *Server) ListenServeAndSignal(signal chan error) error {
|
func (s *Server) ListenServeAndSignal(signal chan error) error {
|
||||||
|
//var lc net.ListenConfig
|
||||||
|
//ln, err := lc.Listen(s.ctx, s.net, s.laddr)
|
||||||
ln, err := net.Listen(s.net, s.laddr)
|
ln, err := net.Listen(s.net, s.laddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if signal != nil {
|
if signal != nil {
|
||||||
|
@ -336,6 +353,14 @@ func serve(s *Server) error {
|
||||||
s.conns = nil
|
s.conns = nil
|
||||||
}()
|
}()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
s.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
lnconn, err := s.ln.Accept()
|
lnconn, err := s.ln.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -343,6 +368,11 @@ func serve(s *Server) error {
|
||||||
done := s.done
|
done := s.done
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
if done {
|
if done {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return errContextDone
|
||||||
|
default:
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if s.AcceptError != nil {
|
if s.AcceptError != nil {
|
||||||
|
@ -547,6 +577,7 @@ type Command struct {
|
||||||
|
|
||||||
// Server defines a server for clients for managing client connections.
|
// Server defines a server for clients for managing client connections.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
|
ctx context.Context
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
net string
|
net string
|
||||||
laddr string
|
laddr string
|
||||||
|
|
|
@ -3,6 +3,7 @@ package redcon
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
@ -220,7 +221,8 @@ func TestServerUnix(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testServerNetwork(t *testing.T, network, laddr string) {
|
func testServerNetwork(t *testing.T, network, laddr string) {
|
||||||
s := NewServerNetwork(network, laddr,
|
ctx := context.Background()
|
||||||
|
s := NewServerNetwork(ctx, network, laddr,
|
||||||
func(conn Conn, cmd Command) {
|
func(conn Conn, cmd Command) {
|
||||||
switch strings.ToLower(string(cmd.Args[0])) {
|
switch strings.ToLower(string(cmd.Args[0])) {
|
||||||
default:
|
default:
|
||||||
|
@ -261,7 +263,7 @@ func testServerNetwork(t *testing.T, network, laddr string) {
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Second / 4)
|
time.Sleep(time.Second / 4)
|
||||||
if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
|
if err := ListenAndServeNetwork(ctx, network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
|
||||||
panic("expected an error, should not be able to listen on the same port")
|
panic("expected an error, should not be able to listen on the same port")
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second / 4)
|
time.Sleep(time.Second / 4)
|
||||||
|
@ -560,6 +562,7 @@ func TestParse(t *testing.T) {
|
||||||
func TestPubSub(t *testing.T) {
|
func TestPubSub(t *testing.T) {
|
||||||
addr := ":12346"
|
addr := ":12346"
|
||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
|
ctx := context.Background()
|
||||||
go func() {
|
go func() {
|
||||||
var ps PubSub
|
var ps PubSub
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -593,7 +596,7 @@ func TestPubSub(t *testing.T) {
|
||||||
ps.Publish(channel, message)
|
ps.Publish(channel, message)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
panic(ListenAndServe(addr, func(conn Conn, cmd Command) {
|
panic(ListenAndServe(ctx, addr, func(conn Conn, cmd Command) {
|
||||||
switch strings.ToLower(string(cmd.Args[0])) {
|
switch strings.ToLower(string(cmd.Args[0])) {
|
||||||
default:
|
default:
|
||||||
conn.WriteError("ERR unknown command '" +
|
conn.WriteError("ERR unknown command '" +
|
||||||
|
@ -738,3 +741,21 @@ func TestPubSub(t *testing.T) {
|
||||||
// stop the timeout
|
// stop the timeout
|
||||||
final <- true
|
final <- true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextDone(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
s := NewServerNetwork(ctx, "tcp", ":12345", func(conn Conn, cmd Command) {}, nil, nil)
|
||||||
|
go func() {
|
||||||
|
err = s.ListenAndServe()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
if err != errContextDone {
|
||||||
|
t.Fatalf("expected %v but found %v", errContextDone, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue