diff --git a/pkg/endpoint/endpoint.go b/pkg/endpoint/endpoint.go index 6c60ec65..238a558b 100644 --- a/pkg/endpoint/endpoint.go +++ b/pkg/endpoint/endpoint.go @@ -76,11 +76,14 @@ type Endpoint struct { DeliveryMode uint8 } MQTT struct { - Host string - Port int - QueueName string - Qos byte - Retained bool + Host string + Port int + QueueName string + Qos byte + Retained bool + CACertFile string + CertFile string + KeyFile string } SQS struct { @@ -406,6 +409,12 @@ func parseEndpoint(s string) (Endpoint, error) { if n == 1 { endpoint.MQTT.Retained = true } + case "cacert": + endpoint.MQTT.CACertFile = val[0] + case "cert": + endpoint.MQTT.CertFile = val[0] + case "key": + endpoint.MQTT.KeyFile = val[0] } } } @@ -469,7 +478,7 @@ func parseEndpoint(s string) (Endpoint, error) { // Basic AMQP connection strings in HOOKS interface // amqp://guest:guest@localhost:5672//?params=value - // or amqp://guest:guest@localhost:5672///?params=value + // or amqp://guest:guest@localhost:5672///?params=value // // Default params are: // @@ -487,15 +496,15 @@ func parseEndpoint(s string) (Endpoint, error) { endpoint.AMQP.Durable = true endpoint.AMQP.DeliveryMode = amqp.Transient - // Fix incase of namespace, e.g. example.com/namespace/queue - // but not example.com/queue/ - with an endslash. - if len(sp) > 2 && len(sp[2]) > 0 { - endpoint.AMQP.URI = endpoint.AMQP.URI + "/" + sp[1] - sp = append([]string{endpoint.AMQP.URI}, sp[2:]...) - } - - // Bind queue name with no namespace - if len(sp) > 1 { + // Fix incase of namespace, e.g. example.com/namespace/queue + // but not example.com/queue/ - with an endslash. + if len(sp) > 2 && len(sp[2]) > 0 { + endpoint.AMQP.URI = endpoint.AMQP.URI + "/" + sp[1] + sp = append([]string{endpoint.AMQP.URI}, sp[2:]...) + } + + // Bind queue name with no namespace + if len(sp) > 1 { var err error endpoint.AMQP.QueueName, err = url.QueryUnescape(sp[1]) if err != nil { diff --git a/pkg/endpoint/mqtt.go b/pkg/endpoint/mqtt.go index 1dab8beb..eea4682f 100644 --- a/pkg/endpoint/mqtt.go +++ b/pkg/endpoint/mqtt.go @@ -1,7 +1,10 @@ package endpoint import ( + "crypto/tls" + "crypto/x509" "fmt" + "io/ioutil" "sync" "time" @@ -56,7 +59,31 @@ func (conn *MQTTConn) Send(msg string) error { if conn.conn == nil { uri := fmt.Sprintf("tcp://%s:%d", conn.ep.MQTT.Host, conn.ep.MQTT.Port) - ops := paho.NewClientOptions().SetClientID("tile38").AddBroker(uri) + ops := paho.NewClientOptions() + if conn.ep.MQTT.CertFile != "" || conn.ep.MQTT.KeyFile != "" || + conn.ep.MQTT.CACertFile != "" { + var config tls.Config + if conn.ep.MQTT.CertFile != "" || conn.ep.MQTT.KeyFile != "" { + cert, err := tls.LoadX509KeyPair(conn.ep.MQTT.CertFile, + conn.ep.MQTT.KeyFile) + if err != nil { + return err + } + config.Certificates = append(config.Certificates, cert) + } + if conn.ep.MQTT.CACertFile != "" { + // Load CA cert + caCert, err := ioutil.ReadFile(conn.ep.MQTT.CACertFile) + if err != nil { + return err + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + config.RootCAs = caCertPool + } + ops = ops.SetTLSConfig(&config) + } + ops = ops.SetClientID("tile38").AddBroker(uri) c := paho.NewClient(ops) if token := c.Connect(); token.Wait() && token.Error() != nil { @@ -66,7 +93,8 @@ func (conn *MQTTConn) Send(msg string) error { conn.conn = c } - t := conn.conn.Publish(conn.ep.MQTT.QueueName, conn.ep.MQTT.Qos, conn.ep.MQTT.Retained, msg) + t := conn.conn.Publish(conn.ep.MQTT.QueueName, conn.ep.MQTT.Qos, + conn.ep.MQTT.Retained, msg) t.Wait() if t.Error() != nil {