mirror of https://github.com/tidwall/redcon.git
added Hijack func for detached connections
This commit is contained in:
parent
8d2bcf15cc
commit
08e1ceff58
|
@ -20,6 +20,15 @@ func main() {
|
||||||
switch strings.ToLower(args[0]) {
|
switch strings.ToLower(args[0]) {
|
||||||
default:
|
default:
|
||||||
conn.WriteError("ERR unknown command '" + args[0] + "'")
|
conn.WriteError("ERR unknown command '" + args[0] + "'")
|
||||||
|
case "hijack":
|
||||||
|
hconn := conn.Hijack()
|
||||||
|
log.Printf("connection is hijacked")
|
||||||
|
go func() {
|
||||||
|
defer hconn.Close()
|
||||||
|
hconn.WriteString("OK")
|
||||||
|
hconn.Flush()
|
||||||
|
}()
|
||||||
|
return
|
||||||
case "ping":
|
case "ping":
|
||||||
conn.WriteString("PONG")
|
conn.WriteString("PONG")
|
||||||
case "quit":
|
case "quit":
|
||||||
|
|
98
redcon.go
98
redcon.go
|
@ -41,12 +41,43 @@ type Conn interface {
|
||||||
Context() interface{}
|
Context() interface{}
|
||||||
// SetContext sets a user-defined context
|
// SetContext sets a user-defined context
|
||||||
SetContext(v interface{})
|
SetContext(v interface{})
|
||||||
|
// Hijack return an unmanaged connection. Useful for operations like PubSub.
|
||||||
|
//
|
||||||
|
// hconn := conn.Hijack()
|
||||||
|
// go func(){
|
||||||
|
// defer hconn.Close()
|
||||||
|
// cmd, err := hconn.ReadCommand()
|
||||||
|
// if err != nil{
|
||||||
|
// fmt.Printf("read failed: %v\n", err)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// fmt.Printf("received command: %v", cmd)
|
||||||
|
// hconn.WriteString("OK")
|
||||||
|
// if err := hconn.Flush(); err != nil{
|
||||||
|
// fmt.Printf("write failed: %v\n", err)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
Hijack() HijackedConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// HijackConn represents an unmanaged connection.
|
||||||
|
type HijackedConn interface {
|
||||||
|
// Conn is the original connection
|
||||||
|
Conn
|
||||||
|
// ReadCommand reads the next client command.
|
||||||
|
ReadCommand() ([]string, error)
|
||||||
|
// ReadCommandBytes reads the next client command as bytes.
|
||||||
|
ReadCommandBytes() ([][]byte, error)
|
||||||
|
// Flush flushes any writes to the network.
|
||||||
|
Flush() error
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"}
|
errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"}
|
||||||
errInvalidBulkLength = &errProtocol{"invalid bulk length"}
|
errInvalidBulkLength = &errProtocol{"invalid bulk length"}
|
||||||
errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"}
|
errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"}
|
||||||
|
errHijacked = errors.New("hijacked")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -174,7 +205,7 @@ func (s *Server) ListenServeAndSignal(signal chan error) error {
|
||||||
newWriter(tcpc),
|
newWriter(tcpc),
|
||||||
newReader(tcpc, nil),
|
newReader(tcpc, nil),
|
||||||
tcpc.RemoteAddr().String(),
|
tcpc.RemoteAddr().String(),
|
||||||
nil,
|
nil, false,
|
||||||
}
|
}
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if len(s.rdpool) > 0 {
|
if len(s.rdpool) > 0 {
|
||||||
|
@ -233,7 +264,9 @@ func handle(
|
||||||
closed func(conn Conn, err error)) {
|
closed func(conn Conn, err error)) {
|
||||||
var err error
|
var err error
|
||||||
defer func() {
|
defer func() {
|
||||||
c.conn.Close()
|
if err != errHijacked {
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
func() {
|
func() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
@ -244,16 +277,21 @@ func handle(
|
||||||
}
|
}
|
||||||
closed(c, err)
|
closed(c, err)
|
||||||
}
|
}
|
||||||
if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen {
|
if err != errHijacked {
|
||||||
s.rdpool = append(s.rdpool, c.rd.buf)
|
if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen {
|
||||||
}
|
s.rdpool = append(s.rdpool, c.rd.buf)
|
||||||
if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen {
|
}
|
||||||
s.wrpool = append(s.wrpool, c.wr.b[:0])
|
if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen {
|
||||||
|
s.wrpool = append(s.wrpool, c.wr.b[:0])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}()
|
}()
|
||||||
err = func() error {
|
err = func() error {
|
||||||
for {
|
for {
|
||||||
|
if c.hj {
|
||||||
|
return errHijacked
|
||||||
|
}
|
||||||
cmds, err := c.rd.ReadCommands()
|
cmds, err := c.rd.ReadCommands()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err, ok := err.(*errProtocol); ok {
|
if err, ok := err.(*errProtocol); ok {
|
||||||
|
@ -306,6 +344,7 @@ type conn struct {
|
||||||
rd *reader
|
rd *reader
|
||||||
addr string
|
addr string
|
||||||
ctx interface{}
|
ctx interface{}
|
||||||
|
hj bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) Close() error {
|
func (c *conn) Close() error {
|
||||||
|
@ -345,6 +384,51 @@ func (c *conn) RemoteAddr() string {
|
||||||
}
|
}
|
||||||
func (c *conn) SetReadBuffer(bytes int) {
|
func (c *conn) SetReadBuffer(bytes int) {
|
||||||
}
|
}
|
||||||
|
func (c *conn) Hijack() HijackedConn {
|
||||||
|
c.hj = true
|
||||||
|
return &hijackedConn{conn: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
type hijackedConn struct {
|
||||||
|
*conn
|
||||||
|
cmds [][][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hjc *hijackedConn) Flush() error {
|
||||||
|
return hjc.conn.wr.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hjc *hijackedConn) ReadCommandBytes() ([][]byte, error) {
|
||||||
|
if len(hjc.cmds) > 0 {
|
||||||
|
args := hjc.cmds[0]
|
||||||
|
hjc.cmds = hjc.cmds[1:]
|
||||||
|
for i, arg := range args {
|
||||||
|
nb := make([]byte, len(arg))
|
||||||
|
copy(nb, arg)
|
||||||
|
args[i] = nb
|
||||||
|
}
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
cmds, err := hjc.rd.ReadCommands()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hjc.cmds = cmds
|
||||||
|
return hjc.ReadCommandBytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hjc *hijackedConn) ReadCommand() ([]string, error) {
|
||||||
|
if len(hjc.cmds) > 0 {
|
||||||
|
args := hjc.cmds[0]
|
||||||
|
hjc.cmds = hjc.cmds[1:]
|
||||||
|
nargs := make([]string, len(args))
|
||||||
|
for i, arg := range args {
|
||||||
|
nargs[i] = string(arg)
|
||||||
|
}
|
||||||
|
return nargs, nil
|
||||||
|
}
|
||||||
|
return hjc.ReadCommand()
|
||||||
|
}
|
||||||
|
|
||||||
// Reader represents a RESP command reader.
|
// Reader represents a RESP command reader.
|
||||||
type reader struct {
|
type reader struct {
|
||||||
|
|
|
@ -200,6 +200,12 @@ func TestRandomCommands(t *testing.T) {
|
||||||
fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second)))
|
fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func testHijack(t *testing.T, conn HijackedConn) {
|
||||||
|
conn.WriteString("HIJACKED")
|
||||||
|
if err := conn.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer(t *testing.T) {
|
func TestServer(t *testing.T) {
|
||||||
s := NewServer(":12345",
|
s := NewServer(":12345",
|
||||||
|
@ -213,6 +219,8 @@ func TestServer(t *testing.T) {
|
||||||
case "quit":
|
case "quit":
|
||||||
conn.WriteString("OK")
|
conn.WriteString("OK")
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
case "hijack":
|
||||||
|
go testHijack(t, conn.Hijack())
|
||||||
case "int":
|
case "int":
|
||||||
conn.WriteInt(100)
|
conn.WriteInt(100)
|
||||||
case "bulk":
|
case "bulk":
|
||||||
|
@ -257,7 +265,16 @@ func TestServer(t *testing.T) {
|
||||||
t.Fatalf("expected an error")
|
t.Fatalf("expected an error")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
done := make(chan bool)
|
||||||
|
signal := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
err := <-signal
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
c, err := net.Dial("tcp", ":12345")
|
c, err := net.Dial("tcp", ":12345")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -321,16 +338,19 @@ func TestServer(t *testing.T) {
|
||||||
if res != "-ERR error\r\n" {
|
if res != "-ERR error\r\n" {
|
||||||
t.Fatal("expecting array, got '%v'", res)
|
t.Fatal("expecting array, got '%v'", res)
|
||||||
}
|
}
|
||||||
|
res, err = do("HIJACK\r\n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res != "+HIJACKED\r\n" {
|
||||||
|
t.Fatal("expecting string, got '%v'", res)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
signal := make(chan error)
|
|
||||||
go func() {
|
go func() {
|
||||||
err := s.ListenServeAndSignal(signal)
|
err := s.ListenServeAndSignal(signal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
err := <-signal
|
<-done
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue