support for unix domain sockets, fixes #6

This commit is contained in:
Josh Baker 2016-10-18 15:17:35 -07:00
parent bd8eb49594
commit 89a9dbebb2
2 changed files with 52 additions and 18 deletions

View File

@ -90,17 +90,29 @@ type Conn interface {
PeekPipeline() []Command PeekPipeline() []Command
} }
// NewServer returns a new Redcon server. // NewServer returns a new Redcon server configured on "tcp" network net.
func NewServer(addr string, func NewServer(addr string,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
closed func(conn Conn, err error), closed func(conn Conn, err error),
) *Server {
return NewServerNetwork("tcp", addr, handler, accept, closed)
}
// NewServerNetworkType returns a new Redcon server. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func NewServerNetwork(
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) *Server { ) *Server {
if handler == nil { if handler == nil {
panic("handler is nil") panic("handler is nil")
} }
s := &Server{ s := &Server{
addr: addr, net: net,
laddr: laddr,
handler: handler, handler: handler,
accept: accept, accept: accept,
closed: closed, closed: closed,
@ -126,20 +138,30 @@ func (s *Server) ListenAndServe() error {
return s.ListenServeAndSignal(nil) return s.ListenServeAndSignal(nil)
} }
// ListenAndServe creates a new server and binds to addr. // ListenAndServe creates a new server and binds to addr configured on "tcp" network net.
func ListenAndServe(addr string, func ListenAndServe(addr string,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
closed func(conn Conn, err error), closed func(conn Conn, err error),
) error { ) error {
return NewServer(addr, handler, accept, closed).ListenAndServe() return ListenAndServeNetwork("tcp", addr, handler, accept, closed)
}
// ListenAndServe creates a new server and binds to addr. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func ListenAndServeNetwork(
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) error {
return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe()
} }
// ListenServeAndSignal serves incoming connections and passes nil or error // ListenServeAndSignal serves incoming connections and passes nil or error
// when listening. signal can be nil. // when listening. signal can be nil.
func (s *Server) ListenServeAndSignal(signal chan error) error { func (s *Server) ListenServeAndSignal(signal chan error) error {
var addr = s.addr ln, err := net.Listen(s.net, s.laddr)
ln, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
if signal != nil { if signal != nil {
signal <- err signal <- err
@ -149,9 +171,8 @@ func (s *Server) ListenServeAndSignal(signal chan error) error {
if signal != nil { if signal != nil {
signal <- nil signal <- nil
} }
tln := ln.(*net.TCPListener)
s.mu.Lock() s.mu.Lock()
s.ln = tln s.ln = ln
s.mu.Unlock() s.mu.Unlock()
defer func() { defer func() {
ln.Close() ln.Close()
@ -165,7 +186,7 @@ func (s *Server) ListenServeAndSignal(signal chan error) error {
}() }()
}() }()
for { for {
tcpc, err := tln.AcceptTCP() lnconn, err := ln.Accept()
if err != nil { if err != nil {
s.mu.Lock() s.mu.Lock()
done := s.done done := s.done
@ -175,8 +196,8 @@ func (s *Server) ListenServeAndSignal(signal chan error) error {
} }
return err return err
} }
c := &conn{conn: tcpc, addr: tcpc.RemoteAddr().String(), c := &conn{conn: lnconn, addr: lnconn.RemoteAddr().String(),
wr: NewWriter(tcpc), rd: NewReader(tcpc)} wr: NewWriter(lnconn), rd: NewReader(lnconn)}
s.mu.Lock() s.mu.Lock()
s.conns[c] = true s.conns[c] = true
s.mu.Unlock() s.mu.Unlock()
@ -253,7 +274,7 @@ func handle(s *Server, c *conn) {
// conn represents a client connection // conn represents a client connection
type conn struct { type conn struct {
conn *net.TCPConn conn net.Conn
wr *Writer wr *Writer
rd *Reader rd *Reader
addr string addr string
@ -361,12 +382,13 @@ type Command struct {
// Server defines a server for clients for managing client connections. // Server defines a server for clients for managing client connections.
type Server struct { type Server struct {
mu sync.Mutex mu sync.Mutex
addr string net string
laddr string
handler func(conn Conn, cmd Command) handler func(conn Conn, cmd Command)
accept func(conn Conn) bool accept func(conn Conn) bool
closed func(conn Conn, err error) closed func(conn Conn, err error)
conns map[*conn]bool conns map[*conn]bool
ln *net.TCPListener ln net.Listener
done bool done bool
} }

View File

@ -7,6 +7,7 @@ import (
"log" "log"
"math/rand" "math/rand"
"net" "net"
"os"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
@ -207,9 +208,16 @@ func testDetached(t *testing.T, conn DetachedConn) {
t.Fatal(err) t.Fatal(err)
} }
} }
func TestServerTCP(t *testing.T) {
testServerNetwork(t, "tcp", ":12345")
}
func TestServerUnix(t *testing.T) {
defer os.RemoveAll("unix.net")
testServerNetwork(t, "unix", "unix.net")
}
func TestServer(t *testing.T) { func testServerNetwork(t *testing.T, network, laddr string) {
s := NewServer(":12345", s := NewServerNetwork(network, laddr,
func(conn Conn, cmd Command) { func(conn Conn, cmd Command) {
switch strings.ToLower(string(cmd.Args[0])) { switch strings.ToLower(string(cmd.Args[0])) {
default: default:
@ -249,8 +257,12 @@ func TestServer(t *testing.T) {
t.Fatalf("expected an error, should not be able to close before serving") t.Fatalf("expected an error, should not be able to close before serving")
} }
go func() { go func() {
if network == "unix" {
os.RemoveAll(laddr)
}
time.Sleep(time.Second / 4) time.Sleep(time.Second / 4)
if err := ListenAndServe(":12345", func(conn Conn, cmd Command) {}, nil, nil); err == nil { if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
t.Fatalf("expected an error, should not be able to listen on the same port") t.Fatalf("expected an error, should not be able to listen on the same port")
} }
time.Sleep(time.Second / 4) time.Sleep(time.Second / 4)
@ -274,7 +286,7 @@ func TestServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c, err := net.Dial("tcp", ":12345") c, err := net.Dial(network, laddr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }