added Hijack func for detached connections

This commit is contained in:
Josh Baker 2016-09-13 11:32:47 -07:00
parent 8d2bcf15cc
commit 08e1ceff58
3 changed files with 125 additions and 12 deletions

View File

@ -20,6 +20,15 @@ func main() {
switch strings.ToLower(args[0]) {
default:
conn.WriteError("ERR unknown command '" + args[0] + "'")
case "hijack":
hconn := conn.Hijack()
log.Printf("connection is hijacked")
go func() {
defer hconn.Close()
hconn.WriteString("OK")
hconn.Flush()
}()
return
case "ping":
conn.WriteString("PONG")
case "quit":

View File

@ -41,12 +41,43 @@ type Conn interface {
Context() interface{}
// SetContext sets a user-defined context
SetContext(v interface{})
// Hijack return an unmanaged connection. Useful for operations like PubSub.
//
// hconn := conn.Hijack()
// go func(){
// defer hconn.Close()
// cmd, err := hconn.ReadCommand()
// if err != nil{
// fmt.Printf("read failed: %v\n", err)
// return
// }
// fmt.Printf("received command: %v", cmd)
// hconn.WriteString("OK")
// if err := hconn.Flush(); err != nil{
// fmt.Printf("write failed: %v\n", err)
// return
// }
// }()
Hijack() HijackedConn
}
// HijackConn represents an unmanaged connection.
type HijackedConn interface {
// Conn is the original connection
Conn
// ReadCommand reads the next client command.
ReadCommand() ([]string, error)
// ReadCommandBytes reads the next client command as bytes.
ReadCommandBytes() ([][]byte, error)
// Flush flushes any writes to the network.
Flush() error
}
var (
errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"}
errInvalidBulkLength = &errProtocol{"invalid bulk length"}
errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"}
errHijacked = errors.New("hijacked")
)
const (
@ -174,7 +205,7 @@ func (s *Server) ListenServeAndSignal(signal chan error) error {
newWriter(tcpc),
newReader(tcpc, nil),
tcpc.RemoteAddr().String(),
nil,
nil, false,
}
s.mu.Lock()
if len(s.rdpool) > 0 {
@ -233,7 +264,9 @@ func handle(
closed func(conn Conn, err error)) {
var err error
defer func() {
c.conn.Close()
if err != errHijacked {
c.conn.Close()
}
func() {
s.mu.Lock()
defer s.mu.Unlock()
@ -244,16 +277,21 @@ func handle(
}
closed(c, err)
}
if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen {
s.rdpool = append(s.rdpool, c.rd.buf)
}
if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen {
s.wrpool = append(s.wrpool, c.wr.b[:0])
if err != errHijacked {
if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen {
s.rdpool = append(s.rdpool, c.rd.buf)
}
if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen {
s.wrpool = append(s.wrpool, c.wr.b[:0])
}
}
}()
}()
err = func() error {
for {
if c.hj {
return errHijacked
}
cmds, err := c.rd.ReadCommands()
if err != nil {
if err, ok := err.(*errProtocol); ok {
@ -306,6 +344,7 @@ type conn struct {
rd *reader
addr string
ctx interface{}
hj bool
}
func (c *conn) Close() error {
@ -345,6 +384,51 @@ func (c *conn) RemoteAddr() string {
}
func (c *conn) SetReadBuffer(bytes int) {
}
func (c *conn) Hijack() HijackedConn {
c.hj = true
return &hijackedConn{conn: c}
}
type hijackedConn struct {
*conn
cmds [][][]byte
}
func (hjc *hijackedConn) Flush() error {
return hjc.conn.wr.Flush()
}
func (hjc *hijackedConn) ReadCommandBytes() ([][]byte, error) {
if len(hjc.cmds) > 0 {
args := hjc.cmds[0]
hjc.cmds = hjc.cmds[1:]
for i, arg := range args {
nb := make([]byte, len(arg))
copy(nb, arg)
args[i] = nb
}
return args, nil
}
cmds, err := hjc.rd.ReadCommands()
if err != nil {
return nil, err
}
hjc.cmds = cmds
return hjc.ReadCommandBytes()
}
func (hjc *hijackedConn) ReadCommand() ([]string, error) {
if len(hjc.cmds) > 0 {
args := hjc.cmds[0]
hjc.cmds = hjc.cmds[1:]
nargs := make([]string, len(args))
for i, arg := range args {
nargs[i] = string(arg)
}
return nargs, nil
}
return hjc.ReadCommand()
}
// Reader represents a RESP command reader.
type reader struct {

View File

@ -200,6 +200,12 @@ func TestRandomCommands(t *testing.T) {
fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second)))
}
}
func testHijack(t *testing.T, conn HijackedConn) {
conn.WriteString("HIJACKED")
if err := conn.Flush(); err != nil {
t.Fatal(err)
}
}
func TestServer(t *testing.T) {
s := NewServer(":12345",
@ -213,6 +219,8 @@ func TestServer(t *testing.T) {
case "quit":
conn.WriteString("OK")
conn.Close()
case "hijack":
go testHijack(t, conn.Hijack())
case "int":
conn.WriteInt(100)
case "bulk":
@ -257,7 +265,16 @@ func TestServer(t *testing.T) {
t.Fatalf("expected an error")
}
}()
done := make(chan bool)
signal := make(chan error)
go func() {
defer func() {
done <- true
}()
err := <-signal
if err != nil {
t.Fatal(err)
}
c, err := net.Dial("tcp", ":12345")
if err != nil {
t.Fatal(err)
@ -321,16 +338,19 @@ func TestServer(t *testing.T) {
if res != "-ERR error\r\n" {
t.Fatal("expecting array, got '%v'", res)
}
res, err = do("HIJACK\r\n")
if err != nil {
t.Fatal(err)
}
if res != "+HIJACKED\r\n" {
t.Fatal("expecting string, got '%v'", res)
}
}()
signal := make(chan error)
go func() {
err := s.ListenServeAndSignal(signal)
if err != nil {
t.Fatal(err)
}
}()
err := <-signal
if err != nil {
t.Fatal(err)
}
<-done
}