mirror of https://github.com/tidwall/redcon.git
added Hijack func for detached connections
This commit is contained in:
parent
8d2bcf15cc
commit
08e1ceff58
|
@ -20,6 +20,15 @@ func main() {
|
|||
switch strings.ToLower(args[0]) {
|
||||
default:
|
||||
conn.WriteError("ERR unknown command '" + args[0] + "'")
|
||||
case "hijack":
|
||||
hconn := conn.Hijack()
|
||||
log.Printf("connection is hijacked")
|
||||
go func() {
|
||||
defer hconn.Close()
|
||||
hconn.WriteString("OK")
|
||||
hconn.Flush()
|
||||
}()
|
||||
return
|
||||
case "ping":
|
||||
conn.WriteString("PONG")
|
||||
case "quit":
|
||||
|
|
98
redcon.go
98
redcon.go
|
@ -41,12 +41,43 @@ type Conn interface {
|
|||
Context() interface{}
|
||||
// SetContext sets a user-defined context
|
||||
SetContext(v interface{})
|
||||
// Hijack return an unmanaged connection. Useful for operations like PubSub.
|
||||
//
|
||||
// hconn := conn.Hijack()
|
||||
// go func(){
|
||||
// defer hconn.Close()
|
||||
// cmd, err := hconn.ReadCommand()
|
||||
// if err != nil{
|
||||
// fmt.Printf("read failed: %v\n", err)
|
||||
// return
|
||||
// }
|
||||
// fmt.Printf("received command: %v", cmd)
|
||||
// hconn.WriteString("OK")
|
||||
// if err := hconn.Flush(); err != nil{
|
||||
// fmt.Printf("write failed: %v\n", err)
|
||||
// return
|
||||
// }
|
||||
// }()
|
||||
Hijack() HijackedConn
|
||||
}
|
||||
|
||||
// HijackConn represents an unmanaged connection.
|
||||
type HijackedConn interface {
|
||||
// Conn is the original connection
|
||||
Conn
|
||||
// ReadCommand reads the next client command.
|
||||
ReadCommand() ([]string, error)
|
||||
// ReadCommandBytes reads the next client command as bytes.
|
||||
ReadCommandBytes() ([][]byte, error)
|
||||
// Flush flushes any writes to the network.
|
||||
Flush() error
|
||||
}
|
||||
|
||||
var (
|
||||
errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"}
|
||||
errInvalidBulkLength = &errProtocol{"invalid bulk length"}
|
||||
errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"}
|
||||
errHijacked = errors.New("hijacked")
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -174,7 +205,7 @@ func (s *Server) ListenServeAndSignal(signal chan error) error {
|
|||
newWriter(tcpc),
|
||||
newReader(tcpc, nil),
|
||||
tcpc.RemoteAddr().String(),
|
||||
nil,
|
||||
nil, false,
|
||||
}
|
||||
s.mu.Lock()
|
||||
if len(s.rdpool) > 0 {
|
||||
|
@ -233,7 +264,9 @@ func handle(
|
|||
closed func(conn Conn, err error)) {
|
||||
var err error
|
||||
defer func() {
|
||||
c.conn.Close()
|
||||
if err != errHijacked {
|
||||
c.conn.Close()
|
||||
}
|
||||
func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
@ -244,16 +277,21 @@ func handle(
|
|||
}
|
||||
closed(c, err)
|
||||
}
|
||||
if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen {
|
||||
s.rdpool = append(s.rdpool, c.rd.buf)
|
||||
}
|
||||
if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen {
|
||||
s.wrpool = append(s.wrpool, c.wr.b[:0])
|
||||
if err != errHijacked {
|
||||
if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen {
|
||||
s.rdpool = append(s.rdpool, c.rd.buf)
|
||||
}
|
||||
if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen {
|
||||
s.wrpool = append(s.wrpool, c.wr.b[:0])
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
err = func() error {
|
||||
for {
|
||||
if c.hj {
|
||||
return errHijacked
|
||||
}
|
||||
cmds, err := c.rd.ReadCommands()
|
||||
if err != nil {
|
||||
if err, ok := err.(*errProtocol); ok {
|
||||
|
@ -306,6 +344,7 @@ type conn struct {
|
|||
rd *reader
|
||||
addr string
|
||||
ctx interface{}
|
||||
hj bool
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
|
@ -345,6 +384,51 @@ func (c *conn) RemoteAddr() string {
|
|||
}
|
||||
func (c *conn) SetReadBuffer(bytes int) {
|
||||
}
|
||||
func (c *conn) Hijack() HijackedConn {
|
||||
c.hj = true
|
||||
return &hijackedConn{conn: c}
|
||||
}
|
||||
|
||||
type hijackedConn struct {
|
||||
*conn
|
||||
cmds [][][]byte
|
||||
}
|
||||
|
||||
func (hjc *hijackedConn) Flush() error {
|
||||
return hjc.conn.wr.Flush()
|
||||
}
|
||||
|
||||
func (hjc *hijackedConn) ReadCommandBytes() ([][]byte, error) {
|
||||
if len(hjc.cmds) > 0 {
|
||||
args := hjc.cmds[0]
|
||||
hjc.cmds = hjc.cmds[1:]
|
||||
for i, arg := range args {
|
||||
nb := make([]byte, len(arg))
|
||||
copy(nb, arg)
|
||||
args[i] = nb
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
cmds, err := hjc.rd.ReadCommands()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hjc.cmds = cmds
|
||||
return hjc.ReadCommandBytes()
|
||||
}
|
||||
|
||||
func (hjc *hijackedConn) ReadCommand() ([]string, error) {
|
||||
if len(hjc.cmds) > 0 {
|
||||
args := hjc.cmds[0]
|
||||
hjc.cmds = hjc.cmds[1:]
|
||||
nargs := make([]string, len(args))
|
||||
for i, arg := range args {
|
||||
nargs[i] = string(arg)
|
||||
}
|
||||
return nargs, nil
|
||||
}
|
||||
return hjc.ReadCommand()
|
||||
}
|
||||
|
||||
// Reader represents a RESP command reader.
|
||||
type reader struct {
|
||||
|
|
|
@ -200,6 +200,12 @@ func TestRandomCommands(t *testing.T) {
|
|||
fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second)))
|
||||
}
|
||||
}
|
||||
func testHijack(t *testing.T, conn HijackedConn) {
|
||||
conn.WriteString("HIJACKED")
|
||||
if err := conn.Flush(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
s := NewServer(":12345",
|
||||
|
@ -213,6 +219,8 @@ func TestServer(t *testing.T) {
|
|||
case "quit":
|
||||
conn.WriteString("OK")
|
||||
conn.Close()
|
||||
case "hijack":
|
||||
go testHijack(t, conn.Hijack())
|
||||
case "int":
|
||||
conn.WriteInt(100)
|
||||
case "bulk":
|
||||
|
@ -257,7 +265,16 @@ func TestServer(t *testing.T) {
|
|||
t.Fatalf("expected an error")
|
||||
}
|
||||
}()
|
||||
done := make(chan bool)
|
||||
signal := make(chan error)
|
||||
go func() {
|
||||
defer func() {
|
||||
done <- true
|
||||
}()
|
||||
err := <-signal
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c, err := net.Dial("tcp", ":12345")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -321,16 +338,19 @@ func TestServer(t *testing.T) {
|
|||
if res != "-ERR error\r\n" {
|
||||
t.Fatal("expecting array, got '%v'", res)
|
||||
}
|
||||
res, err = do("HIJACK\r\n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if res != "+HIJACKED\r\n" {
|
||||
t.Fatal("expecting string, got '%v'", res)
|
||||
}
|
||||
}()
|
||||
signal := make(chan error)
|
||||
go func() {
|
||||
err := s.ListenServeAndSignal(signal)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
err := <-signal
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue