added Hijack func for detached connections

This commit is contained in:
Josh Baker 2016-09-13 11:32:47 -07:00
parent 8d2bcf15cc
commit 08e1ceff58
3 changed files with 125 additions and 12 deletions

View File

@ -20,6 +20,15 @@ func main() {
switch strings.ToLower(args[0]) { switch strings.ToLower(args[0]) {
default: default:
conn.WriteError("ERR unknown command '" + args[0] + "'") conn.WriteError("ERR unknown command '" + args[0] + "'")
case "hijack":
hconn := conn.Hijack()
log.Printf("connection is hijacked")
go func() {
defer hconn.Close()
hconn.WriteString("OK")
hconn.Flush()
}()
return
case "ping": case "ping":
conn.WriteString("PONG") conn.WriteString("PONG")
case "quit": case "quit":

View File

@ -41,12 +41,43 @@ type Conn interface {
Context() interface{} Context() interface{}
// SetContext sets a user-defined context // SetContext sets a user-defined context
SetContext(v interface{}) SetContext(v interface{})
// Hijack return an unmanaged connection. Useful for operations like PubSub.
//
// hconn := conn.Hijack()
// go func(){
// defer hconn.Close()
// cmd, err := hconn.ReadCommand()
// if err != nil{
// fmt.Printf("read failed: %v\n", err)
// return
// }
// fmt.Printf("received command: %v", cmd)
// hconn.WriteString("OK")
// if err := hconn.Flush(); err != nil{
// fmt.Printf("write failed: %v\n", err)
// return
// }
// }()
Hijack() HijackedConn
}
// HijackConn represents an unmanaged connection.
type HijackedConn interface {
// Conn is the original connection
Conn
// ReadCommand reads the next client command.
ReadCommand() ([]string, error)
// ReadCommandBytes reads the next client command as bytes.
ReadCommandBytes() ([][]byte, error)
// Flush flushes any writes to the network.
Flush() error
} }
var ( var (
errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"} errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"}
errInvalidBulkLength = &errProtocol{"invalid bulk length"} errInvalidBulkLength = &errProtocol{"invalid bulk length"}
errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"} errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"}
errHijacked = errors.New("hijacked")
) )
const ( const (
@ -174,7 +205,7 @@ func (s *Server) ListenServeAndSignal(signal chan error) error {
newWriter(tcpc), newWriter(tcpc),
newReader(tcpc, nil), newReader(tcpc, nil),
tcpc.RemoteAddr().String(), tcpc.RemoteAddr().String(),
nil, nil, false,
} }
s.mu.Lock() s.mu.Lock()
if len(s.rdpool) > 0 { if len(s.rdpool) > 0 {
@ -233,7 +264,9 @@ func handle(
closed func(conn Conn, err error)) { closed func(conn Conn, err error)) {
var err error var err error
defer func() { defer func() {
c.conn.Close() if err != errHijacked {
c.conn.Close()
}
func() { func() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -244,16 +277,21 @@ func handle(
} }
closed(c, err) closed(c, err)
} }
if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen { if err != errHijacked {
s.rdpool = append(s.rdpool, c.rd.buf) if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen {
} s.rdpool = append(s.rdpool, c.rd.buf)
if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen { }
s.wrpool = append(s.wrpool, c.wr.b[:0]) if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen {
s.wrpool = append(s.wrpool, c.wr.b[:0])
}
} }
}() }()
}() }()
err = func() error { err = func() error {
for { for {
if c.hj {
return errHijacked
}
cmds, err := c.rd.ReadCommands() cmds, err := c.rd.ReadCommands()
if err != nil { if err != nil {
if err, ok := err.(*errProtocol); ok { if err, ok := err.(*errProtocol); ok {
@ -306,6 +344,7 @@ type conn struct {
rd *reader rd *reader
addr string addr string
ctx interface{} ctx interface{}
hj bool
} }
func (c *conn) Close() error { func (c *conn) Close() error {
@ -345,6 +384,51 @@ func (c *conn) RemoteAddr() string {
} }
func (c *conn) SetReadBuffer(bytes int) { func (c *conn) SetReadBuffer(bytes int) {
} }
func (c *conn) Hijack() HijackedConn {
c.hj = true
return &hijackedConn{conn: c}
}
type hijackedConn struct {
*conn
cmds [][][]byte
}
func (hjc *hijackedConn) Flush() error {
return hjc.conn.wr.Flush()
}
func (hjc *hijackedConn) ReadCommandBytes() ([][]byte, error) {
if len(hjc.cmds) > 0 {
args := hjc.cmds[0]
hjc.cmds = hjc.cmds[1:]
for i, arg := range args {
nb := make([]byte, len(arg))
copy(nb, arg)
args[i] = nb
}
return args, nil
}
cmds, err := hjc.rd.ReadCommands()
if err != nil {
return nil, err
}
hjc.cmds = cmds
return hjc.ReadCommandBytes()
}
func (hjc *hijackedConn) ReadCommand() ([]string, error) {
if len(hjc.cmds) > 0 {
args := hjc.cmds[0]
hjc.cmds = hjc.cmds[1:]
nargs := make([]string, len(args))
for i, arg := range args {
nargs[i] = string(arg)
}
return nargs, nil
}
return hjc.ReadCommand()
}
// Reader represents a RESP command reader. // Reader represents a RESP command reader.
type reader struct { type reader struct {

View File

@ -200,6 +200,12 @@ func TestRandomCommands(t *testing.T) {
fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second))) fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second)))
} }
} }
func testHijack(t *testing.T, conn HijackedConn) {
conn.WriteString("HIJACKED")
if err := conn.Flush(); err != nil {
t.Fatal(err)
}
}
func TestServer(t *testing.T) { func TestServer(t *testing.T) {
s := NewServer(":12345", s := NewServer(":12345",
@ -213,6 +219,8 @@ func TestServer(t *testing.T) {
case "quit": case "quit":
conn.WriteString("OK") conn.WriteString("OK")
conn.Close() conn.Close()
case "hijack":
go testHijack(t, conn.Hijack())
case "int": case "int":
conn.WriteInt(100) conn.WriteInt(100)
case "bulk": case "bulk":
@ -257,7 +265,16 @@ func TestServer(t *testing.T) {
t.Fatalf("expected an error") t.Fatalf("expected an error")
} }
}() }()
done := make(chan bool)
signal := make(chan error)
go func() { go func() {
defer func() {
done <- true
}()
err := <-signal
if err != nil {
t.Fatal(err)
}
c, err := net.Dial("tcp", ":12345") c, err := net.Dial("tcp", ":12345")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -321,16 +338,19 @@ func TestServer(t *testing.T) {
if res != "-ERR error\r\n" { if res != "-ERR error\r\n" {
t.Fatal("expecting array, got '%v'", res) t.Fatal("expecting array, got '%v'", res)
} }
res, err = do("HIJACK\r\n")
if err != nil {
t.Fatal(err)
}
if res != "+HIJACKED\r\n" {
t.Fatal("expecting string, got '%v'", res)
}
}() }()
signal := make(chan error)
go func() { go func() {
err := s.ListenServeAndSignal(signal) err := s.ListenServeAndSignal(signal)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
}() }()
err := <-signal <-done
if err != nil {
t.Fatal(err)
}
} }