diff --git a/example/clone.go b/example/clone.go index 0c76182..155fd99 100644 --- a/example/clone.go +++ b/example/clone.go @@ -20,6 +20,15 @@ func main() { switch strings.ToLower(args[0]) { default: conn.WriteError("ERR unknown command '" + args[0] + "'") + case "hijack": + hconn := conn.Hijack() + log.Printf("connection is hijacked") + go func() { + defer hconn.Close() + hconn.WriteString("OK") + hconn.Flush() + }() + return case "ping": conn.WriteString("PONG") case "quit": diff --git a/redcon.go b/redcon.go index e836dcf..d8cfb17 100644 --- a/redcon.go +++ b/redcon.go @@ -41,12 +41,43 @@ type Conn interface { Context() interface{} // SetContext sets a user-defined context SetContext(v interface{}) + // Hijack return an unmanaged connection. Useful for operations like PubSub. + // + // hconn := conn.Hijack() + // go func(){ + // defer hconn.Close() + // cmd, err := hconn.ReadCommand() + // if err != nil{ + // fmt.Printf("read failed: %v\n", err) + // return + // } + // fmt.Printf("received command: %v", cmd) + // hconn.WriteString("OK") + // if err := hconn.Flush(); err != nil{ + // fmt.Printf("write failed: %v\n", err) + // return + // } + // }() + Hijack() HijackedConn +} + +// HijackConn represents an unmanaged connection. +type HijackedConn interface { + // Conn is the original connection + Conn + // ReadCommand reads the next client command. + ReadCommand() ([]string, error) + // ReadCommandBytes reads the next client command as bytes. + ReadCommandBytes() ([][]byte, error) + // Flush flushes any writes to the network. + Flush() error } var ( errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"} errInvalidBulkLength = &errProtocol{"invalid bulk length"} errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"} + errHijacked = errors.New("hijacked") ) const ( @@ -174,7 +205,7 @@ func (s *Server) ListenServeAndSignal(signal chan error) error { newWriter(tcpc), newReader(tcpc, nil), tcpc.RemoteAddr().String(), - nil, + nil, false, } s.mu.Lock() if len(s.rdpool) > 0 { @@ -233,7 +264,9 @@ func handle( closed func(conn Conn, err error)) { var err error defer func() { - c.conn.Close() + if err != errHijacked { + c.conn.Close() + } func() { s.mu.Lock() defer s.mu.Unlock() @@ -244,16 +277,21 @@ func handle( } closed(c, err) } - if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen { - s.rdpool = append(s.rdpool, c.rd.buf) - } - if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen { - s.wrpool = append(s.wrpool, c.wr.b[:0]) + if err != errHijacked { + if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen { + s.rdpool = append(s.rdpool, c.rd.buf) + } + if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen { + s.wrpool = append(s.wrpool, c.wr.b[:0]) + } } }() }() err = func() error { for { + if c.hj { + return errHijacked + } cmds, err := c.rd.ReadCommands() if err != nil { if err, ok := err.(*errProtocol); ok { @@ -306,6 +344,7 @@ type conn struct { rd *reader addr string ctx interface{} + hj bool } func (c *conn) Close() error { @@ -345,6 +384,51 @@ func (c *conn) RemoteAddr() string { } func (c *conn) SetReadBuffer(bytes int) { } +func (c *conn) Hijack() HijackedConn { + c.hj = true + return &hijackedConn{conn: c} +} + +type hijackedConn struct { + *conn + cmds [][][]byte +} + +func (hjc *hijackedConn) Flush() error { + return hjc.conn.wr.Flush() +} + +func (hjc *hijackedConn) ReadCommandBytes() ([][]byte, error) { + if len(hjc.cmds) > 0 { + args := hjc.cmds[0] + hjc.cmds = hjc.cmds[1:] + for i, arg := range args { + nb := make([]byte, len(arg)) + copy(nb, arg) + args[i] = nb + } + return args, nil + } + cmds, err := hjc.rd.ReadCommands() + if err != nil { + return nil, err + } + hjc.cmds = cmds + return hjc.ReadCommandBytes() +} + +func (hjc *hijackedConn) ReadCommand() ([]string, error) { + if len(hjc.cmds) > 0 { + args := hjc.cmds[0] + hjc.cmds = hjc.cmds[1:] + nargs := make([]string, len(args)) + for i, arg := range args { + nargs[i] = string(arg) + } + return nargs, nil + } + return hjc.ReadCommand() +} // Reader represents a RESP command reader. type reader struct { diff --git a/redcon_test.go b/redcon_test.go index 11a1446..e8fcb71 100644 --- a/redcon_test.go +++ b/redcon_test.go @@ -200,6 +200,12 @@ func TestRandomCommands(t *testing.T) { fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second))) } } +func testHijack(t *testing.T, conn HijackedConn) { + conn.WriteString("HIJACKED") + if err := conn.Flush(); err != nil { + t.Fatal(err) + } +} func TestServer(t *testing.T) { s := NewServer(":12345", @@ -213,6 +219,8 @@ func TestServer(t *testing.T) { case "quit": conn.WriteString("OK") conn.Close() + case "hijack": + go testHijack(t, conn.Hijack()) case "int": conn.WriteInt(100) case "bulk": @@ -257,7 +265,16 @@ func TestServer(t *testing.T) { t.Fatalf("expected an error") } }() + done := make(chan bool) + signal := make(chan error) go func() { + defer func() { + done <- true + }() + err := <-signal + if err != nil { + t.Fatal(err) + } c, err := net.Dial("tcp", ":12345") if err != nil { t.Fatal(err) @@ -321,16 +338,19 @@ func TestServer(t *testing.T) { if res != "-ERR error\r\n" { t.Fatal("expecting array, got '%v'", res) } + res, err = do("HIJACK\r\n") + if err != nil { + t.Fatal(err) + } + if res != "+HIJACKED\r\n" { + t.Fatal("expecting string, got '%v'", res) + } }() - signal := make(chan error) go func() { err := s.ListenServeAndSignal(signal) if err != nil { t.Fatal(err) } }() - err := <-signal - if err != nil { - t.Fatal(err) - } + <-done }