diff --git a/rpc/client.go b/rpc/client.go index 70025d1..11f8264 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -10,17 +10,19 @@ import ( type Client struct { sync.Mutex - addr string + network string + addr string maxIdleConns int conns *list.List } -func NewClient(addr string, maxIdleConns int) *Client { +func NewClient(network, addr string, maxIdleConns int) *Client { RegisterType(RpcError{}) c := new(Client) + c.network = network c.addr = addr c.maxIdleConns = maxIdleConns @@ -162,7 +164,7 @@ func (c *Client) popConn() (*conn, error) { return v.Value.(*conn), nil } c.Unlock() - return newConn(c.addr) + return newConn(c.network, c.addr) } func (c *Client) pushConn(co *conn) error { diff --git a/rpc/conn.go b/rpc/conn.go index 888fc1c..eff4a2b 100644 --- a/rpc/conn.go +++ b/rpc/conn.go @@ -11,8 +11,8 @@ type conn struct { co net.Conn } -func newConn(addr string) (*conn, error) { - c, err := net.Dial("tcp", addr) +func newConn(network, addr string) (*conn, error) { + c, err := net.Dial(network, addr) if err != nil { return nil, err } diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go index 4390d70..fa9a138 100644 --- a/rpc/rpc_test.go +++ b/rpc/rpc_test.go @@ -14,7 +14,7 @@ var testClient *Client func newTestServer() *Server { f := func() { - testServer = NewServer("127.0.0.1:11182") + testServer = NewServer("tcp", "127.0.0.1:11182") go testServer.Start() } @@ -25,7 +25,7 @@ func newTestServer() *Server { func newTestClient() *Client { f := func() { - testClient = NewClient("127.0.0.1:11182", 10) + testClient = NewClient("tcp", "127.0.0.1:11182", 10) } testClientOnce.Do(f) diff --git a/rpc/server.go b/rpc/server.go index de070f1..98cd561 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -10,17 +10,19 @@ import ( type Server struct { sync.Mutex - addr string - funcs map[string]reflect.Value + network string + addr string + funcs map[string]reflect.Value listener net.Listener running bool } -func NewServer(addr string) *Server { +func NewServer(network, addr string) *Server { RegisterType(RpcError{}) s := new(Server) + s.network = network s.addr = addr s.funcs = make(map[string]reflect.Value) @@ -30,7 +32,7 @@ func NewServer(addr string) *Server { func (s *Server) Start() error { var err error - s.listener, err = net.Listen("tcp", s.addr) + s.listener, err = net.Listen(s.network, s.addr) if err != nil { return err }