diff --git a/controller/endpoint/endpoint.go b/controller/endpoint/endpoint.go index 6b3d2128..3b2ed27d 100644 --- a/controller/endpoint/endpoint.go +++ b/controller/endpoint/endpoint.go @@ -16,6 +16,7 @@ const ( HTTP = EndpointProtocol("http") // HTTP Disque = EndpointProtocol("disque") // Disque GRPC = EndpointProtocol("grpc") // GRPC + Redis = EndpointProtocol("redis") // Redis ) // Endpoint represents an endpoint. @@ -34,6 +35,11 @@ type Endpoint struct { Replicate int } } + Redis struct { + Host string + Port int + Channel string + } } type EndpointConn interface { @@ -93,6 +99,8 @@ func (epc *EndpointManager) Send(endpoint, val string) error { conn = newDisqueEndpointConn(ep) case GRPC: conn = newGRPCEndpointConn(ep) + case Redis: + conn = newRedisEndpointConn(ep) } epc.conns[endpoint] = conn } @@ -114,17 +122,22 @@ func parseEndpoint(s string) (Endpoint, error) { endpoint.Protocol = Disque case strings.HasPrefix(s, "grpc:"): endpoint.Protocol = GRPC + case strings.HasPrefix(s, "redis:"): + endpoint.Protocol = Redis } + s = s[strings.Index(s, ":")+1:] if !strings.HasPrefix(s, "//") { return endpoint, errors.New("missing the two slashes") } + sqp := strings.Split(s[2:], "?") sp := strings.Split(sqp[0], "/") s = sp[0] if s == "" { return endpoint, errors.New("missing host") } + if endpoint.Protocol == GRPC { dp := strings.Split(s, ":") switch len(dp) { @@ -142,6 +155,33 @@ func parseEndpoint(s string) (Endpoint, error) { endpoint.GRPC.Port = int(n) } } + + if endpoint.Protocol == Redis { + dp := strings.Split(s, ":") + switch len(dp) { + default: + return endpoint, errors.New("invalid redis url") + case 1: + endpoint.Redis.Host = dp[0] + endpoint.Redis.Port = 6379 + case 2: + endpoint.Redis.Host = dp[0] + n, err := strconv.ParseUint(dp[1], 10, 16) + if err != nil { + return endpoint, errors.New("invalid redis url port") + } + endpoint.Redis.Port = int(n) + } + + if len(sp) > 1 { + var err error + endpoint.Redis.Channel, err = url.QueryUnescape(sp[1]) + if err != nil { + return endpoint, errors.New("invalid redis channel name") + } + } + } + if endpoint.Protocol == Disque { dp := strings.Split(s, ":") switch len(dp) { @@ -187,7 +227,7 @@ func parseEndpoint(s string) (Endpoint, error) { if endpoint.Disque.QueueName == "" { return endpoint, errors.New("missing disque queue name") } - } + return endpoint, nil } diff --git a/controller/endpoint/redis.go b/controller/endpoint/redis.go new file mode 100644 index 00000000..1343ca1d --- /dev/null +++ b/controller/endpoint/redis.go @@ -0,0 +1,105 @@ +package endpoint + +import ( + "bufio" + "errors" + "fmt" + "net" + "sync" + "time" +) + +const ( + redisExpiresAfter = time.Second * 30 +) + +type RedisEndpointConn struct { + mu sync.Mutex + ep Endpoint + ex bool + t time.Time + conn net.Conn + rd *bufio.Reader +} + +func newRedisEndpointConn(ep Endpoint) *RedisEndpointConn { + return &RedisEndpointConn{ + ep: ep, + t: time.Now(), + } +} + +func (conn *RedisEndpointConn) Expired() bool { + conn.mu.Lock() + defer conn.mu.Unlock() + if !conn.ex { + if time.Now().Sub(conn.t) > redisExpiresAfter { + if conn.conn != nil { + conn.close() + } + conn.ex = true + } + } + return conn.ex +} + +func (conn *RedisEndpointConn) close() { + if conn.conn != nil { + conn.conn.Close() + conn.conn = nil + } + conn.rd = nil +} + +func (conn *RedisEndpointConn) Send(msg string) error { + conn.mu.Lock() + defer conn.mu.Unlock() + + if conn.ex { + return errors.New("expired") + } + + conn.t = time.Now() + if conn.conn == nil { + addr := fmt.Sprintf("%s:%d", conn.ep.Redis.Host, conn.ep.Redis.Port) + var err error + conn.conn, err = net.Dial("tcp", addr) + if err != nil { + return err + } + conn.rd = bufio.NewReader(conn.conn) + } + + var args []string + args = append(args, "PUBLISH", conn.ep.Redis.Channel, msg) + cmd := buildRedisCommand(args) + + if _, err := conn.conn.Write(cmd); err != nil { + conn.close() + return err + } + + c, err := conn.rd.ReadByte() + if err != nil { + conn.close() + return err + } + + if c != ':' { + conn.close() + return errors.New("invalid redis reply") + } + + ln, err := conn.rd.ReadBytes('\n') + if err != nil { + conn.close() + return err + } + + if string(ln[0:1]) != "1" { + conn.close() + return errors.New("invalid redis reply") + } + + return nil +}