a simple rpc frame, not full implemented
This commit is contained in:
parent
097357a9f9
commit
eeed48eee2
|
@ -0,0 +1,181 @@
|
|||
package rpc
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
sync.Mutex
|
||||
|
||||
addr string
|
||||
|
||||
maxIdleConns int
|
||||
|
||||
conns *list.List
|
||||
}
|
||||
|
||||
func NewClient(addr string, maxIdleConns int) *Client {
|
||||
RegisterType(RpcError{})
|
||||
|
||||
c := new(Client)
|
||||
c.addr = addr
|
||||
|
||||
c.maxIdleConns = maxIdleConns
|
||||
|
||||
c.conns = list.New()
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.Lock()
|
||||
|
||||
for {
|
||||
if c.conns.Len() > 0 {
|
||||
v := c.conns.Front()
|
||||
|
||||
co := v.Value.(*conn)
|
||||
co.Close()
|
||||
c.conns.Remove(v)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
c.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) MakeRpc(rpcName string, fptr interface{}) (err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = fmt.Errorf("make rpc error")
|
||||
}
|
||||
}()
|
||||
|
||||
fn := reflect.ValueOf(fptr).Elem()
|
||||
|
||||
nOut := fn.Type().NumOut()
|
||||
if nOut == 0 || fn.Type().Out(nOut-1).Kind() != reflect.Interface {
|
||||
err = fmt.Errorf("%s return final output param must be error interface", rpcName)
|
||||
return
|
||||
}
|
||||
|
||||
_, b := fn.Type().Out(nOut - 1).MethodByName("Error")
|
||||
if !b {
|
||||
err = fmt.Errorf("%s return final output param must be error interface", rpcName)
|
||||
return
|
||||
}
|
||||
|
||||
f := func(in []reflect.Value) []reflect.Value {
|
||||
return c.call(fn, rpcName, in)
|
||||
}
|
||||
|
||||
v := reflect.MakeFunc(fn.Type(), f)
|
||||
fn.Set(v)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) call(fn reflect.Value, name string, in []reflect.Value) []reflect.Value {
|
||||
inArgs := make([]interface{}, len(in))
|
||||
for i := 0; i < len(in); i++ {
|
||||
inArgs[i] = in[i].Interface()
|
||||
}
|
||||
|
||||
data, err := encodeData(name, inArgs)
|
||||
if err != nil {
|
||||
return c.returnCallError(fn, err)
|
||||
}
|
||||
|
||||
var co *conn
|
||||
var buf []byte
|
||||
for i := 0; i < 3; i++ {
|
||||
if co, err = c.popConn(); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
buf, err = co.Call(data)
|
||||
if err == nil {
|
||||
c.pushConn(co)
|
||||
break
|
||||
} else {
|
||||
co.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return c.returnCallError(fn, err)
|
||||
}
|
||||
|
||||
n, out, e := decodeData(buf)
|
||||
if e != nil {
|
||||
return c.returnCallError(fn, e)
|
||||
}
|
||||
|
||||
if n != name {
|
||||
return c.returnCallError(fn, fmt.Errorf("rpc name %s != %s", n, name))
|
||||
}
|
||||
|
||||
last := out[len(out)-1]
|
||||
if last != nil {
|
||||
if err, ok := last.(error); ok {
|
||||
return c.returnCallError(fn, err)
|
||||
} else {
|
||||
return c.returnCallError(fn, fmt.Errorf("rpc final return type %T must be error", last))
|
||||
}
|
||||
}
|
||||
|
||||
outValues := make([]reflect.Value, len(out))
|
||||
for i := 0; i < len(out); i++ {
|
||||
if out[i] == nil {
|
||||
outValues[i] = reflect.Zero(fn.Type().Out(i))
|
||||
} else {
|
||||
outValues[i] = reflect.ValueOf(out[i])
|
||||
}
|
||||
}
|
||||
|
||||
return outValues
|
||||
}
|
||||
|
||||
func (c *Client) returnCallError(fn reflect.Value, err error) []reflect.Value {
|
||||
println("return call error", err.Error())
|
||||
|
||||
nOut := fn.Type().NumOut()
|
||||
out := make([]reflect.Value, nOut)
|
||||
for i := 0; i < nOut-1; i++ {
|
||||
out[i] = reflect.Zero(fn.Type().Out(i))
|
||||
}
|
||||
|
||||
out[nOut-1] = reflect.ValueOf(&err).Elem()
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *Client) popConn() (*conn, error) {
|
||||
c.Lock()
|
||||
if c.conns.Len() > 0 {
|
||||
v := c.conns.Front()
|
||||
c.conns.Remove(v)
|
||||
c.Unlock()
|
||||
|
||||
return v.Value.(*conn), nil
|
||||
}
|
||||
c.Unlock()
|
||||
return newConn(c.addr)
|
||||
}
|
||||
|
||||
func (c *Client) pushConn(co *conn) error {
|
||||
c.Lock()
|
||||
if c.conns.Len() >= c.maxIdleConns {
|
||||
c.Unlock()
|
||||
co.Close()
|
||||
return nil
|
||||
} else {
|
||||
c.conns.PushBack(co)
|
||||
}
|
||||
c.Unlock()
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
package rpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type RpcError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (r RpcError) Error() string {
|
||||
return r.Message
|
||||
}
|
||||
|
||||
func RegisterType(value interface{}) (err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = fmt.Errorf("register error")
|
||||
}
|
||||
}()
|
||||
gob.Register(value)
|
||||
return
|
||||
}
|
||||
|
||||
type rpcData struct {
|
||||
Name string
|
||||
Args []interface{}
|
||||
}
|
||||
|
||||
func encodeData(name string, args []interface{}) ([]byte, error) {
|
||||
d := rpcData{}
|
||||
d.Name = name
|
||||
|
||||
d.Args = args
|
||||
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
|
||||
if err := enc.Encode(d); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func decodeData(data []byte) (name string, args []interface{}, err error) {
|
||||
var d rpcData
|
||||
|
||||
var buf = bytes.NewBuffer(data)
|
||||
|
||||
dec := gob.NewDecoder(buf)
|
||||
|
||||
if err = dec.Decode(&d); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
name = d.Name
|
||||
args = d.Args
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package rpc
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
co net.Conn
|
||||
}
|
||||
|
||||
func newConn(addr string) (*conn, error) {
|
||||
c, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
co := new(conn)
|
||||
co.co = c
|
||||
return co, nil
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
return c.co.Close()
|
||||
}
|
||||
|
||||
func (c *conn) Call(data []byte) ([]byte, error) {
|
||||
if err := c.WriteMessage(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if buf, err := c.ReadMessage(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return buf, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) WriteMessage(data []byte) error {
|
||||
buf := make([]byte, 4+len(data))
|
||||
|
||||
binary.LittleEndian.PutUint32(buf[0:4], uint32(len(data)))
|
||||
|
||||
copy(buf[4:], data)
|
||||
|
||||
n, err := c.co.Write(buf)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return err
|
||||
} else if n != len(buf) {
|
||||
c.Close()
|
||||
return fmt.Errorf("write %d less than %d", n, len(buf))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) ReadMessage() ([]byte, error) {
|
||||
l := make([]byte, 4)
|
||||
|
||||
_, err := io.ReadFull(c.co, l)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
length := binary.LittleEndian.Uint32(l)
|
||||
|
||||
data := make([]byte, length)
|
||||
_, err = io.ReadFull(c.co, data)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
} else {
|
||||
return data, nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
package rpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var testServerOnce sync.Once
|
||||
var testClientOnce sync.Once
|
||||
|
||||
var testServer *Server
|
||||
var testClient *Client
|
||||
|
||||
func newTestServer() *Server {
|
||||
f := func() {
|
||||
testServer = NewServer("127.0.0.1:11182")
|
||||
go testServer.Start()
|
||||
}
|
||||
|
||||
testServerOnce.Do(f)
|
||||
|
||||
return testServer
|
||||
}
|
||||
|
||||
func newTestClient() *Client {
|
||||
f := func() {
|
||||
testClient = NewClient("127.0.0.1:11182", 10)
|
||||
}
|
||||
|
||||
testClientOnce.Do(f)
|
||||
|
||||
return testClient
|
||||
}
|
||||
|
||||
func OnlineRpc(id int) (int, string, error) {
|
||||
return id * 10, "abc", errors.New("hello world")
|
||||
}
|
||||
|
||||
func TestRpc(t *testing.T) {
|
||||
defer func() {
|
||||
e := recover()
|
||||
if s, ok := e.(string); ok {
|
||||
println(s)
|
||||
}
|
||||
|
||||
if err, ok := e.(error); ok {
|
||||
println(err.Error())
|
||||
}
|
||||
}()
|
||||
s := newTestServer()
|
||||
|
||||
s.Register("online_rpc", OnlineRpc)
|
||||
|
||||
c := newTestClient()
|
||||
|
||||
var r func(int) (int, string, error)
|
||||
if err := c.MakeRpc("online_rpc", &r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r(10)
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
package rpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
sync.Mutex
|
||||
|
||||
addr string
|
||||
funcs map[string]reflect.Value
|
||||
|
||||
listener net.Listener
|
||||
running bool
|
||||
}
|
||||
|
||||
func NewServer(addr string) *Server {
|
||||
RegisterType(RpcError{})
|
||||
|
||||
s := new(Server)
|
||||
s.addr = addr
|
||||
|
||||
s.funcs = make(map[string]reflect.Value)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
var err error
|
||||
s.listener, err = net.Listen("tcp", s.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.running = true
|
||||
|
||||
for s.running {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
go s.onConn(conn)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
s.running = false
|
||||
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Register(name string, f interface{}) (err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = fmt.Errorf("%s is not callable", name)
|
||||
}
|
||||
}()
|
||||
|
||||
v := reflect.ValueOf(f)
|
||||
|
||||
//to check f is function
|
||||
v.Type().NumIn()
|
||||
|
||||
nOut := v.Type().NumOut()
|
||||
if nOut == 0 || v.Type().Out(nOut-1).Kind() != reflect.Interface {
|
||||
err = fmt.Errorf("%s return final output param must be error interface", name)
|
||||
return
|
||||
}
|
||||
|
||||
_, b := v.Type().Out(nOut - 1).MethodByName("Error")
|
||||
if !b {
|
||||
err = fmt.Errorf("%s return final output param must be error interface", name)
|
||||
return
|
||||
}
|
||||
|
||||
s.Lock()
|
||||
if _, ok := s.funcs[name]; ok {
|
||||
err = fmt.Errorf("%s has registered", name)
|
||||
s.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
s.funcs[name] = v
|
||||
s.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) onConn(co net.Conn) {
|
||||
println("onconn")
|
||||
c := new(conn)
|
||||
c.co = co
|
||||
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
//later log
|
||||
if err, ok := e.(error); ok {
|
||||
println("recover", err.Error())
|
||||
}
|
||||
}
|
||||
c.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
data, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
println("read error ", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
data, err = s.handle(data)
|
||||
if err != nil {
|
||||
println("handle error ", err.Error())
|
||||
return
|
||||
}
|
||||
err = c.WriteMessage(data)
|
||||
if err != nil {
|
||||
println("write error ", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handle(data []byte) ([]byte, error) {
|
||||
name, args, err := decodeData(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.Lock()
|
||||
f, ok := s.funcs[name]
|
||||
s.Unlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("rpc %s not registered", name)
|
||||
}
|
||||
|
||||
inValues := make([]reflect.Value, len(args))
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
if args[i] == nil {
|
||||
inValues[i] = reflect.Zero(f.Type().In(i))
|
||||
} else {
|
||||
inValues[i] = reflect.ValueOf(args[i])
|
||||
}
|
||||
}
|
||||
|
||||
out := f.Call(inValues)
|
||||
|
||||
outArgs := make([]interface{}, len(out))
|
||||
for i := 0; i < len(outArgs); i++ {
|
||||
outArgs[i] = out[i].Interface()
|
||||
}
|
||||
|
||||
p := out[len(out)-1].Interface()
|
||||
if p != nil {
|
||||
if e, ok := p.(error); ok {
|
||||
outArgs[len(out)-1] = RpcError{e.Error()}
|
||||
} else {
|
||||
return nil, fmt.Errorf("final param must be error")
|
||||
}
|
||||
}
|
||||
|
||||
return encodeData(name, outArgs)
|
||||
}
|
Loading…
Reference in New Issue