From eeed48eee238d0f0ad254dce4ced142425c3353e Mon Sep 17 00:00:00 2001 From: siddontang Date: Thu, 20 Feb 2014 13:55:45 +0800 Subject: [PATCH] a simple rpc frame, not full implemented --- rpc/client.go | 181 ++++++++++++++++++++++++++++++++++++++++++++++++ rpc/codec.go | 63 +++++++++++++++++ rpc/conn.go | 78 +++++++++++++++++++++ rpc/rpc_test.go | 63 +++++++++++++++++ rpc/server.go | 173 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 558 insertions(+) create mode 100644 rpc/client.go create mode 100644 rpc/codec.go create mode 100644 rpc/conn.go create mode 100644 rpc/rpc_test.go create mode 100644 rpc/server.go diff --git a/rpc/client.go b/rpc/client.go new file mode 100644 index 0000000..0a4b544 --- /dev/null +++ b/rpc/client.go @@ -0,0 +1,181 @@ +package rpc + +import ( + "container/list" + "fmt" + "reflect" + "sync" +) + +type Client struct { + sync.Mutex + + addr string + + maxIdleConns int + + conns *list.List +} + +func NewClient(addr string, maxIdleConns int) *Client { + RegisterType(RpcError{}) + + c := new(Client) + c.addr = addr + + c.maxIdleConns = maxIdleConns + + c.conns = list.New() + + return c +} + +func (c *Client) Close() error { + c.Lock() + + for { + if c.conns.Len() > 0 { + v := c.conns.Front() + + co := v.Value.(*conn) + co.Close() + c.conns.Remove(v) + } else { + break + } + } + + c.Unlock() + return nil +} + +func (c *Client) MakeRpc(rpcName string, fptr interface{}) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("make rpc error") + } + }() + + fn := reflect.ValueOf(fptr).Elem() + + nOut := fn.Type().NumOut() + if nOut == 0 || fn.Type().Out(nOut-1).Kind() != reflect.Interface { + err = fmt.Errorf("%s return final output param must be error interface", rpcName) + return + } + + _, b := fn.Type().Out(nOut - 1).MethodByName("Error") + if !b { + err = fmt.Errorf("%s return final output param must be error interface", rpcName) + return + } + + f := func(in []reflect.Value) []reflect.Value { + return c.call(fn, rpcName, in) + } + + v := reflect.MakeFunc(fn.Type(), f) + fn.Set(v) + + return +} + +func (c *Client) call(fn reflect.Value, name string, in []reflect.Value) []reflect.Value { + inArgs := make([]interface{}, len(in)) + for i := 0; i < len(in); i++ { + inArgs[i] = in[i].Interface() + } + + data, err := encodeData(name, inArgs) + if err != nil { + return c.returnCallError(fn, err) + } + + var co *conn + var buf []byte + for i := 0; i < 3; i++ { + if co, err = c.popConn(); err != nil { + continue + } + + buf, err = co.Call(data) + if err == nil { + c.pushConn(co) + break + } else { + co.Close() + } + } + + if err != nil { + return c.returnCallError(fn, err) + } + + n, out, e := decodeData(buf) + if e != nil { + return c.returnCallError(fn, e) + } + + if n != name { + return c.returnCallError(fn, fmt.Errorf("rpc name %s != %s", n, name)) + } + + last := out[len(out)-1] + if last != nil { + if err, ok := last.(error); ok { + return c.returnCallError(fn, err) + } else { + return c.returnCallError(fn, fmt.Errorf("rpc final return type %T must be error", last)) + } + } + + outValues := make([]reflect.Value, len(out)) + for i := 0; i < len(out); i++ { + if out[i] == nil { + outValues[i] = reflect.Zero(fn.Type().Out(i)) + } else { + outValues[i] = reflect.ValueOf(out[i]) + } + } + + return outValues +} + +func (c *Client) returnCallError(fn reflect.Value, err error) []reflect.Value { + println("return call error", err.Error()) + + nOut := fn.Type().NumOut() + out := make([]reflect.Value, nOut) + for i := 0; i < nOut-1; i++ { + out[i] = reflect.Zero(fn.Type().Out(i)) + } + + out[nOut-1] = reflect.ValueOf(&err).Elem() + return out +} + +func (c *Client) popConn() (*conn, error) { + c.Lock() + if c.conns.Len() > 0 { + v := c.conns.Front() + c.conns.Remove(v) + c.Unlock() + + return v.Value.(*conn), nil + } + c.Unlock() + return newConn(c.addr) +} + +func (c *Client) pushConn(co *conn) error { + c.Lock() + if c.conns.Len() >= c.maxIdleConns { + c.Unlock() + co.Close() + return nil + } else { + c.conns.PushBack(co) + } + c.Unlock() + return nil +} diff --git a/rpc/codec.go b/rpc/codec.go new file mode 100644 index 0000000..68caf26 --- /dev/null +++ b/rpc/codec.go @@ -0,0 +1,63 @@ +package rpc + +import ( + "bytes" + "encoding/gob" + "fmt" +) + +type RpcError struct { + Message string +} + +func (r RpcError) Error() string { + return r.Message +} + +func RegisterType(value interface{}) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("register error") + } + }() + gob.Register(value) + return +} + +type rpcData struct { + Name string + Args []interface{} +} + +func encodeData(name string, args []interface{}) ([]byte, error) { + d := rpcData{} + d.Name = name + + d.Args = args + + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + + if err := enc.Encode(d); err != nil { + return nil, err + } else { + return buf.Bytes(), nil + } +} + +func decodeData(data []byte) (name string, args []interface{}, err error) { + var d rpcData + + var buf = bytes.NewBuffer(data) + + dec := gob.NewDecoder(buf) + + if err = dec.Decode(&d); err != nil { + return + } + + name = d.Name + args = d.Args + + return +} diff --git a/rpc/conn.go b/rpc/conn.go new file mode 100644 index 0000000..888fc1c --- /dev/null +++ b/rpc/conn.go @@ -0,0 +1,78 @@ +package rpc + +import ( + "encoding/binary" + "fmt" + "io" + "net" +) + +type conn struct { + co net.Conn +} + +func newConn(addr string) (*conn, error) { + c, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + + co := new(conn) + co.co = c + return co, nil +} + +func (c *conn) Close() error { + return c.co.Close() +} + +func (c *conn) Call(data []byte) ([]byte, error) { + if err := c.WriteMessage(data); err != nil { + return nil, err + } + + if buf, err := c.ReadMessage(); err != nil { + return nil, err + } else { + return buf, nil + } +} + +func (c *conn) WriteMessage(data []byte) error { + buf := make([]byte, 4+len(data)) + + binary.LittleEndian.PutUint32(buf[0:4], uint32(len(data))) + + copy(buf[4:], data) + + n, err := c.co.Write(buf) + if err != nil { + c.Close() + return err + } else if n != len(buf) { + c.Close() + return fmt.Errorf("write %d less than %d", n, len(buf)) + } + return nil +} + +func (c *conn) ReadMessage() ([]byte, error) { + l := make([]byte, 4) + + _, err := io.ReadFull(c.co, l) + if err != nil { + c.Close() + return nil, err + } + + length := binary.LittleEndian.Uint32(l) + + data := make([]byte, length) + _, err = io.ReadFull(c.co, data) + if err != nil { + c.Close() + return nil, err + } else { + return data, nil + } +} diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go new file mode 100644 index 0000000..6a4f024 --- /dev/null +++ b/rpc/rpc_test.go @@ -0,0 +1,63 @@ +package rpc + +import ( + "errors" + "sync" + "testing" +) + +var testServerOnce sync.Once +var testClientOnce sync.Once + +var testServer *Server +var testClient *Client + +func newTestServer() *Server { + f := func() { + testServer = NewServer("127.0.0.1:11182") + go testServer.Start() + } + + testServerOnce.Do(f) + + return testServer +} + +func newTestClient() *Client { + f := func() { + testClient = NewClient("127.0.0.1:11182", 10) + } + + testClientOnce.Do(f) + + return testClient +} + +func OnlineRpc(id int) (int, string, error) { + return id * 10, "abc", errors.New("hello world") +} + +func TestRpc(t *testing.T) { + defer func() { + e := recover() + if s, ok := e.(string); ok { + println(s) + } + + if err, ok := e.(error); ok { + println(err.Error()) + } + }() + s := newTestServer() + + s.Register("online_rpc", OnlineRpc) + + c := newTestClient() + + var r func(int) (int, string, error) + if err := c.MakeRpc("online_rpc", &r); err != nil { + t.Fatal(err) + } + + r(10) +} diff --git a/rpc/server.go b/rpc/server.go new file mode 100644 index 0000000..7893760 --- /dev/null +++ b/rpc/server.go @@ -0,0 +1,173 @@ +package rpc + +import ( + "fmt" + "net" + "reflect" + "sync" +) + +type Server struct { + sync.Mutex + + addr string + funcs map[string]reflect.Value + + listener net.Listener + running bool +} + +func NewServer(addr string) *Server { + RegisterType(RpcError{}) + + s := new(Server) + s.addr = addr + + s.funcs = make(map[string]reflect.Value) + + return s +} + +func (s *Server) Start() error { + var err error + s.listener, err = net.Listen("tcp", s.addr) + if err != nil { + return err + } + + s.running = true + + for s.running { + conn, err := s.listener.Accept() + if err != nil { + continue + } + + go s.onConn(conn) + } + + return nil +} + +func (s *Server) Stop() error { + s.running = false + + if s.listener != nil { + s.listener.Close() + } + + return nil +} + +func (s *Server) Register(name string, f interface{}) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("%s is not callable", name) + } + }() + + v := reflect.ValueOf(f) + + //to check f is function + v.Type().NumIn() + + nOut := v.Type().NumOut() + if nOut == 0 || v.Type().Out(nOut-1).Kind() != reflect.Interface { + err = fmt.Errorf("%s return final output param must be error interface", name) + return + } + + _, b := v.Type().Out(nOut - 1).MethodByName("Error") + if !b { + err = fmt.Errorf("%s return final output param must be error interface", name) + return + } + + s.Lock() + if _, ok := s.funcs[name]; ok { + err = fmt.Errorf("%s has registered", name) + s.Unlock() + return + } + + s.funcs[name] = v + s.Unlock() + return +} + +func (s *Server) onConn(co net.Conn) { + println("onconn") + c := new(conn) + c.co = co + + defer func() { + if e := recover(); e != nil { + //later log + if err, ok := e.(error); ok { + println("recover", err.Error()) + } + } + c.Close() + }() + + for { + data, err := c.ReadMessage() + if err != nil { + println("read error ", err.Error()) + return + } + + data, err = s.handle(data) + if err != nil { + println("handle error ", err.Error()) + return + } + err = c.WriteMessage(data) + if err != nil { + println("write error ", err.Error()) + return + } + } +} + +func (s *Server) handle(data []byte) ([]byte, error) { + name, args, err := decodeData(data) + if err != nil { + return nil, err + } + + s.Lock() + f, ok := s.funcs[name] + s.Unlock() + if !ok { + return nil, fmt.Errorf("rpc %s not registered", name) + } + + inValues := make([]reflect.Value, len(args)) + + for i := 0; i < len(args); i++ { + if args[i] == nil { + inValues[i] = reflect.Zero(f.Type().In(i)) + } else { + inValues[i] = reflect.ValueOf(args[i]) + } + } + + out := f.Call(inValues) + + outArgs := make([]interface{}, len(out)) + for i := 0; i < len(outArgs); i++ { + outArgs[i] = out[i].Interface() + } + + p := out[len(out)-1].Interface() + if p != nil { + if e, ok := p.(error); ok { + outArgs[len(out)-1] = RpcError{e.Error()} + } else { + return nil, fmt.Errorf("final param must be error") + } + } + + return encodeData(name, outArgs) +}