From 9471b43256b579c601d23430b9bf84d0e189138a Mon Sep 17 00:00:00 2001 From: Josh Baker Date: Tue, 7 Nov 2017 06:49:33 -0700 Subject: [PATCH] wip --- evio.go | 180 ++++++++++++++++--- evio_loop.go | 321 ++++++++++++++++++++-------------- evio_net.go | 8 +- evio_translate.go | 14 +- examples/redis-server/main.go | 86 ++++----- internal/internal_test.go | 34 ++++ internal/timeoutqueue.go | 70 ++++++++ 7 files changed, 489 insertions(+), 224 deletions(-) create mode 100644 internal/internal_test.go create mode 100644 internal/timeoutqueue.go diff --git a/evio.go b/evio.go index f8a3bc9..b36d0f1 100644 --- a/evio.go +++ b/evio.go @@ -32,40 +32,48 @@ type Options struct { TCPKeepAlive time.Duration } -// Addr represents the connection's remote and local addresses. -type Addr struct { - // Index is the index of server address that was passed to the Serve call. - Index int - // Local is the connection's local socket address. - Local net.Addr - // Local is the connection's remote peer address. - Remote net.Addr +// Conn represents a connection context which provides information +// about the connection. +type Conn struct { + // Closing is true when the connection is about to close. Expect a Closed + // event to fire soon. + Closing bool + // AddrIndex is the index of server address that was passed to the Serve call. + AddrIndex int + // LocalAddr is the connection's local socket address. + LocalAddr net.Addr + // RemoteAddr is the connection's remote peer address. + RemoteAddr net.Addr } -// Context represents a server context which provides information about the +// Server represents a server context which provides information about the // running server and has control functions for managing some state -type Context struct { - Addrs []net.Addr - Wake func(id int) bool - Attach func(v interface{}) error +type Server struct { + // The addrs parameter is an array of listening addresses that align + // with the addr strings passed to the Serve function. + Addrs []net.Addr + // Wake is a goroutine-safe function that triggers a Data event + // (with a nil `in` parameter) for the specified id. + Wake func(id int) bool + // Dial makes a connection to an external server and returns a new + // connection id. The new connection is added to the event loop and + // is managed exactly the same way as all the other connections. + Dial func(addr string, timeout time.Duration) (id int, err error) } // Events represents the server events for the Serve call. // Each event has an Action return value that is used manage the state // of the connection and server. type Events struct { - // Serving fires when the server can accept connections. - // The wake parameter is a goroutine-safe function that triggers - // a Data event (with a nil `in` parameter) for the specified id. - // The addrs parameter is an array of listening addresses that align - // with the addr strings passed to the Serve function. - Serving func(c Context) (action Action) + // Serving fires when the server can accept connections. The context + // parameter has various utilities that may help with managing the + // event loop. + Serving func(s Server) (action Action) // Opened fires when a new connection has opened. // The addr parameter is the connection's local and remote addresses. // Use the out return value to write data to the connection. // The opts return value is used to set connection options. - Opened func(id int, addr Addr) (out []byte, opts Options, action Action) - Attached func(id int, v interface{}) (out []byte, opts Options, action Action) + Opened func(id int, c Conn) (out []byte, opts Options, action Action) // Closed fires when a connection has closed. // The err parameter is the last known connection error, usually nil. Closed func(id int, err error) (action Action) @@ -116,14 +124,11 @@ func Serve(events Events, addr ...string) error { }() var stdlib bool for _, addr := range addr { - ln := listener{network: "tcp", addr: addr} - if strings.Contains(addr, "://") { - ln.network = strings.Split(addr, "://")[0] - ln.addr = strings.Split(addr, "://")[1] - } - if strings.HasSuffix(ln.network, "-net") { + var ln listener + var stdlibt bool + ln.network, ln.addr, stdlibt = parseAddr(addr) + if stdlibt { stdlib = true - ln.network = ln.network[:len(ln.network)-4] } if ln.network == "unix" { os.RemoveAll(ln.addr) @@ -181,3 +186,122 @@ type listener struct { addr string naddr net.Addr } + +func parseAddr(addr string) (network, address string, stdlib bool) { + network = "tcp" + address = addr + if strings.Contains(address, "://") { + network = strings.Split(address, "://")[0] + address = strings.Split(address, "://")[1] + } + if strings.HasSuffix(network, "-net") { + stdlib = true + network = network[:len(network)-4] + } + return +} + +// // type timeoutHeap []timeoutHeapItem + +// // func (h timeoutHeap) Len() int { return len(h) } +// // func (h timeoutHeap) Less(i, j int) bool { return h[i].timeout < h[j].timeout } +// // func (h timeoutHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +// // func (h *timeoutHeap) Push(x interface{}) { +// // *h = append(*h, x.(timeoutHeapItem)) +// // } +// // func (h *timeoutHeap) Pop() interface{} { +// // old := *h +// // n := len(old) +// // x := old[n-1] +// // *h = old[0 : n-1] +// // return x +// // } + +// type timeoutQueue struct { +// h *timeoutHeap +// } + +// func newTimeoutQueue() *timeoutQueue { +// q := &timeoutQueue{&timeoutHeap{}} +// heap.Init(q.h) +// return q +// } +// func (q *timeoutQueue) len() int { +// return q.h.Len() +// } +// func (q *timeoutQueue) push(id int, timeout int64) { +// heap.Push(q.h, timeoutHeapItem{id: id, timeout: timeout}) +// } +// func (q *timeoutQueue) peek() (id int, timeout int64) { +// if q.len() > 0 { +// id = (*(q.h))[0].id +// timeout = (*(q.h))[0].timeout +// } +// return +// } +// func (q *timeoutQueue) pop() (id int, timeout int64) { +// if q.len() > 0 { +// item := q.h.Pop().(timeoutHeapItem) +// id = item.id +// timeout = item.timeout +// } +// return +// } + +// // func init() { +// // rand.Seed(time.Now().UnixNano()) +// // q := newTimeoutQueue() +// // for i := 0; i < 1000; i++ { +// // q.push(i, rand.Int63()%9000) +// // } +// // for q.len() > 0 { +// // id, timeout := q.pop() +// // fmt.Printf("%05d %05d\n", id, timeout) +// // } +// // } + +// type timeoutHeapItem struct { +// id int +// timeout int64 +// } +// type timeoutHeap []timeoutHeapItem + +// func (h timeoutHeap) Len() int { return len(h) } +// func (h timeoutHeap) Less(i, j int) bool { return h[i].timeout < h[j].timeout } +// func (h timeoutHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +// func (h *timeoutHeap) Push(x interface{}) { +// // Push and Pop use pointer receivers because they modify the slice's length, +// // not just its contents. +// *h = append(*h, x.(timeoutHeapItem)) +// } + +// func (h *timeoutHeap) Pop() interface{} { +// old := *h +// n := len(old) +// x := old[n-1] +// *h = old[0 : n-1] +// return x +// } + +// // This example inserts several ints into an IntHeap, checks the minimum, +// // and removes them in order of priority. +// func init() { +// q := newTimeoutQueue() +// rand.Seed(time.Now().UnixNano()) +// // h := &timeoutHeap{} +// // heap.Init(h) + +// for i := 10; i < 20; i++ { +// //heap.Push(h, timeoutHeapItem{i, rand.Int63() % 10}) +// q.push(i, rand.Int63()%10) +// } +// _, timeout := q.peek() +// fmt.Printf("minimum: %d\n", timeout) +// for q.len() > 0 { +// //v := heap.Pop(h).(timeoutHeapItem) +// _, timeout = q.pop() +// fmt.Printf("%d ", timeout) +// } +// fmt.Printf("\n") +// } diff --git a/evio_loop.go b/evio_loop.go index 0270577..486e82c 100644 --- a/evio_loop.go +++ b/evio_loop.go @@ -54,19 +54,22 @@ func (ln *listener) system() error { } type unixConn struct { - id, fd, p int - outbuf []byte - outpos int - action Action - opts Options - raddr net.Addr - laddr net.Addr - err error - wake bool - writeon bool - detached bool - attaching bool - closed bool + id, fd int + outbuf []byte + outpos int + action Action + opts Options + timeout time.Time + err error + wake bool + writeon bool + detached bool + closed bool + opening bool +} + +func (c *unixConn) Timeout() time.Time { + return c.timeout } func (c *unixConn) Read(p []byte) (n int, err error) { @@ -113,7 +116,6 @@ func (c *unixConn) Close() error { c.closed = true return err } - func serve(events Events, lns []*listener) error { p, err := internal.MakePoll() if err != nil { @@ -130,8 +132,10 @@ func serve(events Events, lns []*listener) error { unlock := func() { mu.Unlock() } fdconn := make(map[int]*unixConn) idconn := make(map[int]*unixConn) + + timeoutqueue := internal.NewTimeoutQueue() var id int - ctx := Context{ + ctx := Server{ Wake: func(id int) bool { var ok = true var err error @@ -141,7 +145,7 @@ func serve(events Events, lns []*listener) error { ok = false } else if !c.wake { c.wake = true - err = internal.AddWrite(c.p, c.fd, &c.writeon) + err = internal.AddWrite(p, c.fd, &c.writeon) } unlock() if err != nil { @@ -149,69 +153,77 @@ func serve(events Events, lns []*listener) error { } return ok }, - Attach: func(v interface{}) error { - var fd int + Dial: func(addr string, timeout time.Duration) (int, error) { + network, address, _ := parseAddr(addr) + var taddr net.Addr var err error - switch v := v.(type) { + switch network { default: - return errors.New("invalid type") - case *net.TCPConn: - f, err := v.File() + return 0, errors.New("invalid network") + case "unix": + case "tcp", "tcp4", "tcp6": + taddr, err = net.ResolveTCPAddr(network, address) if err != nil { - return err + return 0, err + } + } + var fd int + var sa syscall.Sockaddr + switch taddr := taddr.(type) { + case *net.UnixAddr: + sa = &syscall.SockaddrUnix{Name: taddr.Name} + case *net.TCPAddr: + if len(taddr.IP) == 4 { + var sa4 syscall.SockaddrInet4 + copy(sa4.Addr[:], taddr.IP[:]) + sa4.Port = taddr.Port + sa = &sa4 + fd, err = syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0) + } else if len(taddr.IP) == 16 { + var sa6 syscall.SockaddrInet6 + copy(sa6.Addr[:], taddr.IP[:]) + sa6.Port = taddr.Port + sa = &sa6 + fd, err = syscall.Socket(syscall.AF_INET6, syscall.SOCK_STREAM, 0) + } else { + return 0, errors.New("invalid network") } - fd = int(f.Fd()) - case *net.UnixConn: - f, err := v.File() - if err != nil { - return err - } - fd = int(f.Fd()) - case *os.File: - fd = int(v.Fd()) - case int: - fd = v - case uintptr: - fd = int(v) } - err = syscall.SetNonblock(fd, true) if err != nil { - println(456) - return err + return 0, err + } + if err := syscall.SetNonblock(fd, true); err != nil { + syscall.Close(fd) + return 0, err + } + err = syscall.Connect(fd, sa) + if err != nil && err != syscall.EINPROGRESS { + syscall.Close(fd) + return 0, err } lock() err = internal.AddRead(p, fd) if err != nil { unlock() - return err + syscall.Close(fd) + return 0, err } id++ - c := &unixConn{id: id, fd: fd, p: p} - c.attaching = true - fdconn[fd] = c - idconn[id] = c - if events.Attached != nil { - unlock() - out, opts, action := events.Attached(id, v) - lock() - if opts.TCPKeepAlive > 0 { - internal.SetKeepAlive(fd, int(c.opts.TCPKeepAlive/time.Second)) - } - c.action = action - if len(out) > 0 { - c.outbuf = append(c.outbuf, out...) - } - } - // if len(c.outbuf) > 0 || c.action != None { + c := &unixConn{id: id, fd: fd, opening: true} err = internal.AddWrite(p, fd, &c.writeon) if err != nil { unlock() - panic(err) + syscall.Close(fd) + return 0, err + } + fdconn[fd] = c + idconn[id] = c + if timeout != 0 { + c.timeout = time.Now().Add(timeout) + timeoutqueue.Push(c) } - // } unlock() - // println("---") - return nil + return id, nil }, } ctx.Addrs = make([]net.Addr, len(lns)) @@ -227,47 +239,109 @@ func serve(events Events, lns []*listener) error { defer func() { lock() type fdid struct { - fd, id int - opts Options + fd, id int + opening bool } var fdids []fdid for fd, c := range fdconn { - fdids = append(fdids, fdid{fd, c.id, c.opts}) + fdids = append(fdids, fdid{fd, c.id, c.opening}) } sort.Slice(fdids, func(i, j int) bool { return fdids[j].id < fdids[i].id }) for _, fdid := range fdids { syscall.Close(fdid.fd) + if fdid.opening { + if events.Opened != nil { + laddr := getlocaladdr(fdid.fd) + raddr := getremoteaddr(fdid.fd) + unlock() + events.Opened(fdid.id, Conn{ + Closing: true, + AddrIndex: -1, + LocalAddr: laddr, + RemoteAddr: raddr, + }) + lock() + } + } if events.Closed != nil { unlock() events.Closed(fdid.id, nil) lock() } } + syscall.Close(p) unlock() }() var packet [0xFFFF]byte var evs = internal.MakeEvents(64) - var lastTicker time.Time - var tickerDelay time.Duration - if events.Tick == nil { - tickerDelay = time.Hour - } + nextTicker := time.Now() for { - pn, err := internal.Wait(p, evs, tickerDelay) + delay := nextTicker.Sub(time.Now()) + if delay < 0 { + delay = 0 + } else if delay > time.Second/4 { + delay = time.Second / 4 + } + pn, err := internal.Wait(p, evs, delay) if err != nil && err != syscall.EINTR { return err } if events.Tick != nil { - now := time.Now() - if now.Sub(lastTicker) > tickerDelay { + remain := nextTicker.Sub(time.Now()) + if remain < 0 { + var tickerDelay time.Duration var action Action - tickerDelay, action = events.Tick() - if action == Shutdown { - return nil + if events.Tick != nil { + tickerDelay, action = events.Tick() + if action == Shutdown { + return nil + } + } else { + tickerDelay = time.Hour } - lastTicker = now + nextTicker = time.Now().Add(tickerDelay + remain) + } + } + // check timeouts + if timeoutqueue.Len() > 0 { + var count int + now := time.Now() + for { + v := timeoutqueue.Peek() + if v == nil { + break + } + c := v.(*unixConn) + if now.After(v.Timeout()) { + timeoutqueue.Pop() + if _, ok := idconn[c.id]; ok { + delete(idconn, c.id) + delete(fdconn, c.fd) + syscall.Close(c.fd) + if events.Opened != nil { + laddr := getlocaladdr(c.fd) + raddr := getremoteaddr(c.fd) + events.Opened(c.id, Conn{ + Closing: true, + AddrIndex: -1, + LocalAddr: laddr, + RemoteAddr: raddr, + }) + } + if events.Closed != nil { + events.Closed(c.id, syscall.ETIMEDOUT) + } + count++ + } + } else { + break + } + } + if count > 0 { + // invalidate the current events and wait for more + continue } } lock() @@ -280,39 +354,25 @@ func serve(events Events, lns []*listener) error { var ln *listener var lnidx int var fd = internal.GetFD(evs, i) - var sa syscall.Sockaddr for lnidx, ln = range lns { if fd == ln.fd { goto accept } } + ln = nil c = fdconn[fd] if c == nil { syscall.Close(fd) goto next } - if c.attaching { - println(fd) - goto next - // opt, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ERROR) - // if opt != 0 { + if c.opening { - // } - // fmt.Printf(">> %v %v\n", opt, err) - switch evs.([]syscall.Kevent_t)[i].Filter { - case syscall.EVFILT_WRITE: - println(123) - goto write - case syscall.EVFILT_READ: - println(456) - goto read - default: - goto next - } + lnidx = -1 + goto opened } goto read accept: - nfd, sa, err = syscall.Accept(fd) + nfd, _, err = syscall.Accept(fd) if err != nil { goto next } @@ -323,26 +383,32 @@ func serve(events Events, lns []*listener) error { goto fail } id++ - c = &unixConn{id: id, fd: nfd, p: p} + c = &unixConn{id: id, fd: nfd} fdconn[nfd] = c idconn[id] = c - c.laddr = getlocaladdr(fd, ln.ln) - c.raddr = getaddr(sa, ln.ln) + goto opened + opened: if events.Opened != nil { + laddr := getlocaladdr(fd) + raddr := getremoteaddr(fd) unlock() - out, c.opts, c.action = events.Opened(c.id, Addr{lnidx, c.laddr, c.raddr}) + out, c.opts, c.action = events.Opened(c.id, Conn{ + AddrIndex: lnidx, + LocalAddr: laddr, + RemoteAddr: raddr, + }) lock() if c.opts.TCPKeepAlive > 0 { - if _, ok := ln.ln.(*net.TCPListener); ok { - if err = internal.SetKeepAlive(c.fd, int(c.opts.TCPKeepAlive/time.Second)); err != nil { - goto fail - } - } + internal.SetKeepAlive(c.fd, int(c.opts.TCPKeepAlive/time.Second)) } if len(out) > 0 { c.outbuf = append(c.outbuf, out...) } } + if c.opening { + c.opening = false + goto next + } goto write read: if c.action != None { @@ -394,12 +460,11 @@ func serve(events Events, lns []*listener) error { } } if n == 0 || err != nil { - println("C") if c.action == Shutdown { goto close } if err == syscall.EAGAIN { - if err = internal.AddWrite(c.p, c.fd, &c.writeon); err != nil { + if err = internal.AddWrite(p, c.fd, &c.writeon); err != nil { goto fail } goto next @@ -418,7 +483,7 @@ func serve(events Events, lns []*listener) error { } if len(c.outbuf)-c.outpos == 0 { if !c.wake { - if err = internal.DelWrite(c.p, c.fd, &c.writeon); err != nil { + if err = internal.DelWrite(p, c.fd, &c.writeon); err != nil { goto fail } } @@ -426,7 +491,7 @@ func serve(events Events, lns []*listener) error { goto close } } else { - if err = internal.AddWrite(c.p, c.fd, &c.writeon); err != nil { + if err = internal.AddWrite(p, c.fd, &c.writeon); err != nil { goto fail } } @@ -434,6 +499,7 @@ func serve(events Events, lns []*listener) error { close: delete(fdconn, c.fd) delete(idconn, c.id) + //delete(idtimeout, c.id) if c.action == Detach { if events.Detached != nil { c.detached = true @@ -476,30 +542,23 @@ func serve(events Events, lns []*listener) error { } } -func getlocaladdr(fd int, ln net.Listener) net.Addr { +func getlocaladdr(fd int) net.Addr { sa, _ := syscall.Getsockname(fd) - return getaddr(sa, ln) + return getaddr(sa) } - -func getaddr(sa syscall.Sockaddr, ln net.Listener) net.Addr { - switch ln.(type) { - case *net.UnixListener: - return ln.Addr() - case *net.TCPListener: - var addr net.TCPAddr - switch sa := sa.(type) { - case *syscall.SockaddrInet4: - addr.IP = net.IP(sa.Addr[:]) - addr.Port = sa.Port - return &addr - case *syscall.SockaddrInet6: - addr.IP = net.IP(sa.Addr[:]) - addr.Port = sa.Port - if sa.ZoneId != 0 { - addr.Zone = strconv.FormatInt(int64(sa.ZoneId), 10) - } - return &addr - } +func getremoteaddr(fd int) net.Addr { + sa, _ := syscall.Getpeername(fd) + return getaddr(sa) +} +func getaddr(sa syscall.Sockaddr) net.Addr { + switch sa := sa.(type) { + default: + return nil + case *syscall.SockaddrInet4: + return &net.TCPAddr{IP: net.IP(sa.Addr[:]), Port: sa.Port} + case *syscall.SockaddrInet6: + return &net.TCPAddr{IP: net.IP(sa.Addr[:]), Port: sa.Port, Zone: strconv.FormatInt(int64(sa.ZoneId), 10)} + case *syscall.SockaddrUnix: + return &net.UnixAddr{Net: "unix", Name: sa.Name} } - return nil } diff --git a/evio_net.go b/evio_net.go index 4409709..419944b 100644 --- a/evio_net.go +++ b/evio_net.go @@ -68,7 +68,7 @@ func servenet(events Events, lns []*listener) error { var cmu sync.Mutex var idconn = make(map[int]*netConn) var done bool - ctx := Context{ + ctx := Server{ Wake: func(id int) bool { cmu.Lock() c := idconn[id] @@ -166,7 +166,11 @@ func servenet(events Events, lns []*listener) error { var action Action mu.Lock() if !done { - out, opts, action = events.Opened(id, Addr{lnidx, conn.LocalAddr(), conn.RemoteAddr()}) + out, opts, action = events.Opened(id, Conn{ + AddrIndex: lnidx, + LocalAddr: conn.LocalAddr(), + RemoteAddr: conn.RemoteAddr(), + }) } mu.Unlock() if opts.TCPKeepAlive > 0 { diff --git a/evio_translate.go b/evio_translate.go index bc7f156..f95aa9f 100644 --- a/evio_translate.go +++ b/evio_translate.go @@ -70,11 +70,11 @@ func NopConn(rw io.ReadWriter) net.Conn { // that performs the actual translation. func Translate( events Events, - should func(id int, addr Addr) bool, + should func(id int, conn Conn) bool, translate func(id int, rd io.ReadWriter) io.ReadWriter, ) Events { tevents := events - var ctx Context + var ctx Server var mu sync.Mutex idc := make(map[int]*tconn) get := func(id int) *tconn { @@ -180,23 +180,23 @@ func Translate( c.mu.Unlock() return err } - tevents.Serving = func(ctxin Context) (action Action) { + tevents.Serving = func(ctxin Server) (action Action) { ctx = ctxin if events.Serving != nil { action = events.Serving(ctx) } return } - tevents.Opened = func(id int, addr Addr) (out []byte, opts Options, action Action) { - if should != nil && !should(id, addr) { + tevents.Opened = func(id int, conn Conn) (out []byte, opts Options, action Action) { + if should != nil && !should(id, conn) { if events.Opened != nil { - out, opts, action = events.Opened(id, addr) + out, opts, action = events.Opened(id, conn) } return } c := create(id) if events.Opened != nil { - out, opts, c.action = events.Opened(id, addr) + out, opts, c.action = events.Opened(id, conn) if len(out) > 0 { c.write(1, out) out = nil diff --git a/examples/redis-server/main.go b/examples/redis-server/main.go index fce2071..b045236 100644 --- a/examples/redis-server/main.go +++ b/examples/redis-server/main.go @@ -8,9 +8,8 @@ import ( "flag" "fmt" "log" - "net" + "strconv" "strings" - "syscall" "time" "github.com/tidwall/evio" @@ -23,61 +22,49 @@ type conn struct { wget bool } -func Dial(network, addr string) (fd int, err error) { - taddr, err := net.ResolveTCPAddr(network, addr) - if err != nil { - return 0, err - } - fd, err = syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0) - if err != nil { - return 0, err - } - if err := syscall.SetNonblock(fd, true); err != nil { - syscall.Close(fd) - return 0, err - } - var sa syscall.SockaddrInet4 - copy(sa.Addr[:], taddr.IP[:]) - sa.Port = taddr.Port - err = syscall.Connect(fd, &sa) - if err != nil && err != syscall.EINPROGRESS { - syscall.Close(fd) - return 0, err - } - return fd, nil -} - func main() { var port int var unixsocket string - var ctx evio.Context + var srv evio.Server flag.IntVar(&port, "port", 6380, "server port") flag.StringVar(&unixsocket, "unixsocket", "socket", "unix socket") flag.Parse() var conns = make(map[int]*conn) var keys = make(map[string]string) var events evio.Events - events.Serving = func(ctxin evio.Context) (action evio.Action) { - ctx = ctxin + events.Serving = func(srvin evio.Server) (action evio.Action) { + srv = srvin log.Printf("redis server started on port %d", port) if unixsocket != "" { log.Printf("redis server started at %s", unixsocket) } return } - events.Attached = func(id int, v interface{}) (out []byte, opts evio.Options, action evio.Action) { - conns[id] = &conn{wget: true} - println("attached", id) - out = []byte("GET / HTTP/1.0\r\n\r\n") + wgetids := make(map[int]time.Time) + events.Opened = func(id int, cn evio.Conn) (out []byte, opts evio.Options, action evio.Action) { + c := &conn{} + if !wgetids[id].IsZero() { + delete(wgetids, id) + c.wget = true + } + conns[id] = c + println("opened", id, c.wget) + if c.wget { + out = []byte("GET / HTTP/1.0\r\n\r\n") + } return } - events.Opened = func(id int, addr evio.Addr) (out []byte, opts evio.Options, action evio.Action) { - println("opened", id) - conns[id] = &conn{} + events.Tick = func() (delay time.Duration, action evio.Action) { + now := time.Now() + for id, t := range wgetids { + if now.Sub(t) > time.Second { + srv.Wake(id) + } + } + delay = time.Second return } events.Closed = func(id int, err error) (action evio.Action) { - fmt.Printf("closed %d %v\n", id, err) delete(conns, id) return @@ -86,6 +73,7 @@ func main() { c := conns[id] if c.wget { println(string(in)) + action = evio.Close return } data := c.is.Begin(in) @@ -109,33 +97,19 @@ func main() { default: out = redcon.AppendError(out, "ERR unknown command '"+string(args[0])+"'") case "WGET": - if len(args) != 2 { + if len(args) != 3 { out = redcon.AppendError(out, "ERR wrong number of arguments for '"+string(args[0])+"' command") } else { start := time.Now() - fd, err := Dial("tcp", string(args[1])) + n, _ := strconv.ParseInt(string(args[2]), 10, 63) + cid, err := srv.Dial("tcp://"+string(args[1]), time.Duration(n)*time.Second) if err != nil { out = redcon.AppendError(out, err.Error()) } else { - time.Since(start) + wgetids[cid] = time.Now() + println(cid, time.Since(start).String()) out = redcon.AppendOK(out) - ctx.Attach(fd) } - // conn, err := net.Dial("tcp", string(args[1])) - // if err != nil { - // out = redcon.AppendError(out, err.Error()) - // } else { - // println(time.Since(start).String()) - // f, err := conn.(*net.TCPConn).File() - // if err != nil { - // conn.Close() - // out = redcon.AppendError(out, err.Error()) - // } else { - // out = redcon.AppendOK(out) - - // ctx.Attach(f.Fd()) - // } - // } } case "PING": if len(args) > 2 { diff --git a/internal/internal_test.go b/internal/internal_test.go new file mode 100644 index 0000000..bc911bc --- /dev/null +++ b/internal/internal_test.go @@ -0,0 +1,34 @@ +package internal + +import ( + "fmt" + "testing" + "time" +) + +type queueItem struct { + timeout time.Time +} + +func (item *queueItem) Timeout() time.Time { + return item.timeout +} + +func TestQueue(t *testing.T) { + q := NewTimeoutQueue() + item := &queueItem{timeout: time.Unix(0, 5)} + q.Push(item) + q.Push(&queueItem{timeout: time.Unix(0, 3)}) + q.Push(&queueItem{timeout: time.Unix(0, 20)}) + q.Push(&queueItem{timeout: time.Unix(0, 13)}) + var out string + for q.Len() > 0 { + pitem := q.Peek() + item := q.Pop() + out += fmt.Sprintf("(%v:%v) ", pitem.Timeout().UnixNano(), item.Timeout().UnixNano()) + } + exp := "(3:3) (5:5) (13:13) (20:20) " + if out != exp { + t.Fatalf("expected '%v', got '%v'", exp, out) + } +} diff --git a/internal/timeoutqueue.go b/internal/timeoutqueue.go new file mode 100644 index 0000000..49ae1d5 --- /dev/null +++ b/internal/timeoutqueue.go @@ -0,0 +1,70 @@ +package internal + +import ( + "container/heap" + "time" +) + +// TimeoutQueueItem is an item for TimeoutQueue +type TimeoutQueueItem interface { + Timeout() time.Time +} + +type timeoutPriorityQueue []TimeoutQueueItem + +func (pq timeoutPriorityQueue) Len() int { return len(pq) } + +func (pq timeoutPriorityQueue) Less(i, j int) bool { + return pq[i].Timeout().Before(pq[j].Timeout()) +} + +func (pq timeoutPriorityQueue) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] +} + +func (pq *timeoutPriorityQueue) Push(x interface{}) { + *pq = append(*pq, x.(TimeoutQueueItem)) +} + +func (pq *timeoutPriorityQueue) Pop() interface{} { + old := *pq + n := len(old) + item := old[n-1] + *pq = old[0 : n-1] + return item +} + +// TimeoutQueue is a priority queue ordere be ascending time.Time. +type TimeoutQueue struct { + pq timeoutPriorityQueue +} + +// NewTimeoutQueue returns a new TimeoutQueue. +func NewTimeoutQueue() *TimeoutQueue { + q := &TimeoutQueue{} + heap.Init(&q.pq) + return q +} + +// Push adds a new item. +func (q *TimeoutQueue) Push(x TimeoutQueueItem) { + heap.Push(&q.pq, x) +} + +// Pop removes and returns the items with the smallest value. +func (q *TimeoutQueue) Pop() TimeoutQueueItem { + return heap.Pop(&q.pq).(TimeoutQueueItem) +} + +// Peek returns the items with the smallest value, but does not remove it. +func (q *TimeoutQueue) Peek() TimeoutQueueItem { + if q.Len() > 0 { + return q.pq[0] + } + return nil +} + +// Len returns the number of items in the queue +func (q *TimeoutQueue) Len() int { + return q.pq.Len() +}