a simple rpc frame, not full implemented

This commit is contained in:
siddontang 2014-02-20 13:55:45 +08:00
parent 097357a9f9
commit eeed48eee2
5 changed files with 558 additions and 0 deletions

181
rpc/client.go Normal file
View File

@ -0,0 +1,181 @@
package rpc
import (
"container/list"
"fmt"
"reflect"
"sync"
)
type Client struct {
sync.Mutex
addr string
maxIdleConns int
conns *list.List
}
func NewClient(addr string, maxIdleConns int) *Client {
RegisterType(RpcError{})
c := new(Client)
c.addr = addr
c.maxIdleConns = maxIdleConns
c.conns = list.New()
return c
}
func (c *Client) Close() error {
c.Lock()
for {
if c.conns.Len() > 0 {
v := c.conns.Front()
co := v.Value.(*conn)
co.Close()
c.conns.Remove(v)
} else {
break
}
}
c.Unlock()
return nil
}
func (c *Client) MakeRpc(rpcName string, fptr interface{}) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("make rpc error")
}
}()
fn := reflect.ValueOf(fptr).Elem()
nOut := fn.Type().NumOut()
if nOut == 0 || fn.Type().Out(nOut-1).Kind() != reflect.Interface {
err = fmt.Errorf("%s return final output param must be error interface", rpcName)
return
}
_, b := fn.Type().Out(nOut - 1).MethodByName("Error")
if !b {
err = fmt.Errorf("%s return final output param must be error interface", rpcName)
return
}
f := func(in []reflect.Value) []reflect.Value {
return c.call(fn, rpcName, in)
}
v := reflect.MakeFunc(fn.Type(), f)
fn.Set(v)
return
}
func (c *Client) call(fn reflect.Value, name string, in []reflect.Value) []reflect.Value {
inArgs := make([]interface{}, len(in))
for i := 0; i < len(in); i++ {
inArgs[i] = in[i].Interface()
}
data, err := encodeData(name, inArgs)
if err != nil {
return c.returnCallError(fn, err)
}
var co *conn
var buf []byte
for i := 0; i < 3; i++ {
if co, err = c.popConn(); err != nil {
continue
}
buf, err = co.Call(data)
if err == nil {
c.pushConn(co)
break
} else {
co.Close()
}
}
if err != nil {
return c.returnCallError(fn, err)
}
n, out, e := decodeData(buf)
if e != nil {
return c.returnCallError(fn, e)
}
if n != name {
return c.returnCallError(fn, fmt.Errorf("rpc name %s != %s", n, name))
}
last := out[len(out)-1]
if last != nil {
if err, ok := last.(error); ok {
return c.returnCallError(fn, err)
} else {
return c.returnCallError(fn, fmt.Errorf("rpc final return type %T must be error", last))
}
}
outValues := make([]reflect.Value, len(out))
for i := 0; i < len(out); i++ {
if out[i] == nil {
outValues[i] = reflect.Zero(fn.Type().Out(i))
} else {
outValues[i] = reflect.ValueOf(out[i])
}
}
return outValues
}
func (c *Client) returnCallError(fn reflect.Value, err error) []reflect.Value {
println("return call error", err.Error())
nOut := fn.Type().NumOut()
out := make([]reflect.Value, nOut)
for i := 0; i < nOut-1; i++ {
out[i] = reflect.Zero(fn.Type().Out(i))
}
out[nOut-1] = reflect.ValueOf(&err).Elem()
return out
}
func (c *Client) popConn() (*conn, error) {
c.Lock()
if c.conns.Len() > 0 {
v := c.conns.Front()
c.conns.Remove(v)
c.Unlock()
return v.Value.(*conn), nil
}
c.Unlock()
return newConn(c.addr)
}
func (c *Client) pushConn(co *conn) error {
c.Lock()
if c.conns.Len() >= c.maxIdleConns {
c.Unlock()
co.Close()
return nil
} else {
c.conns.PushBack(co)
}
c.Unlock()
return nil
}

63
rpc/codec.go Normal file
View File

@ -0,0 +1,63 @@
package rpc
import (
"bytes"
"encoding/gob"
"fmt"
)
type RpcError struct {
Message string
}
func (r RpcError) Error() string {
return r.Message
}
func RegisterType(value interface{}) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("register error")
}
}()
gob.Register(value)
return
}
type rpcData struct {
Name string
Args []interface{}
}
func encodeData(name string, args []interface{}) ([]byte, error) {
d := rpcData{}
d.Name = name
d.Args = args
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(d); err != nil {
return nil, err
} else {
return buf.Bytes(), nil
}
}
func decodeData(data []byte) (name string, args []interface{}, err error) {
var d rpcData
var buf = bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
if err = dec.Decode(&d); err != nil {
return
}
name = d.Name
args = d.Args
return
}

78
rpc/conn.go Normal file
View File

@ -0,0 +1,78 @@
package rpc
import (
"encoding/binary"
"fmt"
"io"
"net"
)
type conn struct {
co net.Conn
}
func newConn(addr string) (*conn, error) {
c, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
co := new(conn)
co.co = c
return co, nil
}
func (c *conn) Close() error {
return c.co.Close()
}
func (c *conn) Call(data []byte) ([]byte, error) {
if err := c.WriteMessage(data); err != nil {
return nil, err
}
if buf, err := c.ReadMessage(); err != nil {
return nil, err
} else {
return buf, nil
}
}
func (c *conn) WriteMessage(data []byte) error {
buf := make([]byte, 4+len(data))
binary.LittleEndian.PutUint32(buf[0:4], uint32(len(data)))
copy(buf[4:], data)
n, err := c.co.Write(buf)
if err != nil {
c.Close()
return err
} else if n != len(buf) {
c.Close()
return fmt.Errorf("write %d less than %d", n, len(buf))
}
return nil
}
func (c *conn) ReadMessage() ([]byte, error) {
l := make([]byte, 4)
_, err := io.ReadFull(c.co, l)
if err != nil {
c.Close()
return nil, err
}
length := binary.LittleEndian.Uint32(l)
data := make([]byte, length)
_, err = io.ReadFull(c.co, data)
if err != nil {
c.Close()
return nil, err
} else {
return data, nil
}
}

63
rpc/rpc_test.go Normal file
View File

@ -0,0 +1,63 @@
package rpc
import (
"errors"
"sync"
"testing"
)
var testServerOnce sync.Once
var testClientOnce sync.Once
var testServer *Server
var testClient *Client
func newTestServer() *Server {
f := func() {
testServer = NewServer("127.0.0.1:11182")
go testServer.Start()
}
testServerOnce.Do(f)
return testServer
}
func newTestClient() *Client {
f := func() {
testClient = NewClient("127.0.0.1:11182", 10)
}
testClientOnce.Do(f)
return testClient
}
func OnlineRpc(id int) (int, string, error) {
return id * 10, "abc", errors.New("hello world")
}
func TestRpc(t *testing.T) {
defer func() {
e := recover()
if s, ok := e.(string); ok {
println(s)
}
if err, ok := e.(error); ok {
println(err.Error())
}
}()
s := newTestServer()
s.Register("online_rpc", OnlineRpc)
c := newTestClient()
var r func(int) (int, string, error)
if err := c.MakeRpc("online_rpc", &r); err != nil {
t.Fatal(err)
}
r(10)
}

173
rpc/server.go Normal file
View File

@ -0,0 +1,173 @@
package rpc
import (
"fmt"
"net"
"reflect"
"sync"
)
type Server struct {
sync.Mutex
addr string
funcs map[string]reflect.Value
listener net.Listener
running bool
}
func NewServer(addr string) *Server {
RegisterType(RpcError{})
s := new(Server)
s.addr = addr
s.funcs = make(map[string]reflect.Value)
return s
}
func (s *Server) Start() error {
var err error
s.listener, err = net.Listen("tcp", s.addr)
if err != nil {
return err
}
s.running = true
for s.running {
conn, err := s.listener.Accept()
if err != nil {
continue
}
go s.onConn(conn)
}
return nil
}
func (s *Server) Stop() error {
s.running = false
if s.listener != nil {
s.listener.Close()
}
return nil
}
func (s *Server) Register(name string, f interface{}) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("%s is not callable", name)
}
}()
v := reflect.ValueOf(f)
//to check f is function
v.Type().NumIn()
nOut := v.Type().NumOut()
if nOut == 0 || v.Type().Out(nOut-1).Kind() != reflect.Interface {
err = fmt.Errorf("%s return final output param must be error interface", name)
return
}
_, b := v.Type().Out(nOut - 1).MethodByName("Error")
if !b {
err = fmt.Errorf("%s return final output param must be error interface", name)
return
}
s.Lock()
if _, ok := s.funcs[name]; ok {
err = fmt.Errorf("%s has registered", name)
s.Unlock()
return
}
s.funcs[name] = v
s.Unlock()
return
}
func (s *Server) onConn(co net.Conn) {
println("onconn")
c := new(conn)
c.co = co
defer func() {
if e := recover(); e != nil {
//later log
if err, ok := e.(error); ok {
println("recover", err.Error())
}
}
c.Close()
}()
for {
data, err := c.ReadMessage()
if err != nil {
println("read error ", err.Error())
return
}
data, err = s.handle(data)
if err != nil {
println("handle error ", err.Error())
return
}
err = c.WriteMessage(data)
if err != nil {
println("write error ", err.Error())
return
}
}
}
func (s *Server) handle(data []byte) ([]byte, error) {
name, args, err := decodeData(data)
if err != nil {
return nil, err
}
s.Lock()
f, ok := s.funcs[name]
s.Unlock()
if !ok {
return nil, fmt.Errorf("rpc %s not registered", name)
}
inValues := make([]reflect.Value, len(args))
for i := 0; i < len(args); i++ {
if args[i] == nil {
inValues[i] = reflect.Zero(f.Type().In(i))
} else {
inValues[i] = reflect.ValueOf(args[i])
}
}
out := f.Call(inValues)
outArgs := make([]interface{}, len(out))
for i := 0; i < len(outArgs); i++ {
outArgs[i] = out[i].Interface()
}
p := out[len(out)-1].Interface()
if p != nil {
if e, ok := p.(error); ok {
outArgs[len(out)-1] = RpcError{e.Error()}
} else {
return nil, fmt.Errorf("final param must be error")
}
}
return encodeData(name, outArgs)
}