diff --git a/redcon.go b/redcon.go index 4830730..0d71f31 100644 --- a/redcon.go +++ b/redcon.go @@ -90,17 +90,29 @@ type Conn interface { PeekPipeline() []Command } -// NewServer returns a new Redcon server. +// NewServer returns a new Redcon server configured on "tcp" network net. func NewServer(addr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, closed func(conn Conn, err error), +) *Server { + return NewServerNetwork("tcp", addr, handler, accept, closed) +} + +// NewServerNetworkType returns a new Redcon server. The network net must be +// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" +func NewServerNetwork( + net, laddr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), ) *Server { if handler == nil { panic("handler is nil") } s := &Server{ - addr: addr, + net: net, + laddr: laddr, handler: handler, accept: accept, closed: closed, @@ -126,20 +138,30 @@ func (s *Server) ListenAndServe() error { return s.ListenServeAndSignal(nil) } -// ListenAndServe creates a new server and binds to addr. +// ListenAndServe creates a new server and binds to addr configured on "tcp" network net. func ListenAndServe(addr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, closed func(conn Conn, err error), ) error { - return NewServer(addr, handler, accept, closed).ListenAndServe() + return ListenAndServeNetwork("tcp", addr, handler, accept, closed) +} + +// ListenAndServe creates a new server and binds to addr. The network net must be +// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" +func ListenAndServeNetwork( + net, laddr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), +) error { + return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe() } // ListenServeAndSignal serves incoming connections and passes nil or error // when listening. signal can be nil. func (s *Server) ListenServeAndSignal(signal chan error) error { - var addr = s.addr - ln, err := net.Listen("tcp", addr) + ln, err := net.Listen(s.net, s.laddr) if err != nil { if signal != nil { signal <- err @@ -149,9 +171,8 @@ func (s *Server) ListenServeAndSignal(signal chan error) error { if signal != nil { signal <- nil } - tln := ln.(*net.TCPListener) s.mu.Lock() - s.ln = tln + s.ln = ln s.mu.Unlock() defer func() { ln.Close() @@ -165,7 +186,7 @@ func (s *Server) ListenServeAndSignal(signal chan error) error { }() }() for { - tcpc, err := tln.AcceptTCP() + lnconn, err := ln.Accept() if err != nil { s.mu.Lock() done := s.done @@ -175,8 +196,8 @@ func (s *Server) ListenServeAndSignal(signal chan error) error { } return err } - c := &conn{conn: tcpc, addr: tcpc.RemoteAddr().String(), - wr: NewWriter(tcpc), rd: NewReader(tcpc)} + c := &conn{conn: lnconn, addr: lnconn.RemoteAddr().String(), + wr: NewWriter(lnconn), rd: NewReader(lnconn)} s.mu.Lock() s.conns[c] = true s.mu.Unlock() @@ -253,7 +274,7 @@ func handle(s *Server, c *conn) { // conn represents a client connection type conn struct { - conn *net.TCPConn + conn net.Conn wr *Writer rd *Reader addr string @@ -361,12 +382,13 @@ type Command struct { // Server defines a server for clients for managing client connections. type Server struct { mu sync.Mutex - addr string + net string + laddr string handler func(conn Conn, cmd Command) accept func(conn Conn) bool closed func(conn Conn, err error) conns map[*conn]bool - ln *net.TCPListener + ln net.Listener done bool } diff --git a/redcon_test.go b/redcon_test.go index 97029dd..6c4e250 100644 --- a/redcon_test.go +++ b/redcon_test.go @@ -7,6 +7,7 @@ import ( "log" "math/rand" "net" + "os" "strconv" "strings" "testing" @@ -207,9 +208,16 @@ func testDetached(t *testing.T, conn DetachedConn) { t.Fatal(err) } } +func TestServerTCP(t *testing.T) { + testServerNetwork(t, "tcp", ":12345") +} +func TestServerUnix(t *testing.T) { + defer os.RemoveAll("unix.net") + testServerNetwork(t, "unix", "unix.net") +} -func TestServer(t *testing.T) { - s := NewServer(":12345", +func testServerNetwork(t *testing.T, network, laddr string) { + s := NewServerNetwork(network, laddr, func(conn Conn, cmd Command) { switch strings.ToLower(string(cmd.Args[0])) { default: @@ -249,8 +257,12 @@ func TestServer(t *testing.T) { t.Fatalf("expected an error, should not be able to close before serving") } go func() { + + if network == "unix" { + os.RemoveAll(laddr) + } time.Sleep(time.Second / 4) - if err := ListenAndServe(":12345", func(conn Conn, cmd Command) {}, nil, nil); err == nil { + if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil { t.Fatalf("expected an error, should not be able to listen on the same port") } time.Sleep(time.Second / 4) @@ -274,7 +286,7 @@ func TestServer(t *testing.T) { if err != nil { t.Fatal(err) } - c, err := net.Dial("tcp", ":12345") + c, err := net.Dial(network, laddr) if err != nil { t.Fatal(err) }