use std::fmt;
use std::str::FromStr;
use futures::StreamExt;
use futures_channel::mpsc::Sender;
use lazy_static::lazy_static;
use log::{debug, error, info};
use rumq_client::{
EventLoopError, MqttOptions, Notification, PacketIdentifier, Publish, Request, Suback,
Subscribe,
};
use serde_derive::{Deserialize, Serialize};
use super::*;
use crate::{AccountId, Addressable, AgentId, Authenticable, Error, SharedGroup};
#[cfg(feature = "queue-counter")]
use crate::queue_counter::QueueCounterHandle;
/// Default capacity of the bounded request channel between `Agent` and the MQTT event loop.
///
/// Used (via `default_mqtt_requests_chan_size`) when `requests_channel_size` is missing
/// from a deserialized `AgentConfig`.
const DEFAULT_MQTT_REQUESTS_CHAN_SIZE: Option<usize> = Some(10_000);
lazy_static! {
    /// Shared multi-threaded tokio runtime used by `AgentBuilder::start`.
    ///
    /// The number of core threads may be overridden with the `TOKIO_THREAD_COUNT`
    /// environment variable; a value that does not parse as `usize` panics when the
    /// runtime is first accessed.
    static ref TOKIO: tokio::runtime::Runtime = {
        let mut builder = tokio::runtime::Builder::new();
        builder.enable_all().threaded_scheduler();
        // An unset (or unreadable) variable leaves the builder's default thread count.
        if let Ok(raw_count) = std::env::var("TOKIO_THREAD_COUNT") {
            let core_threads = raw_count
                .parse::<usize>()
                .expect("Error converting TOKIO_THREAD_COUNT variable into usize");
            builder.core_threads(core_threads);
        }
        builder.build().expect("Failed to start tokio runtime")
    };
}
/// Connection settings for an MQTT agent, deserialized with serde and consumed by
/// `AgentBuilder::start` / `AgentBuilder::start_with_runtime`.
#[derive(Debug, Clone, Deserialize)]
pub struct AgentConfig {
    /// Broker URI; `AgentBuilder` requires it to contain both a host and a port.
    uri: String,
    /// Forwarded to `MqttOptions::set_clean_session` when present.
    clean_session: Option<bool>,
    /// Forwarded to `MqttOptions::set_keep_alive` (which takes a `u16`) when present.
    keep_alive_interval: Option<u64>,
    /// Seconds to wait before reconnecting once a connection ends or fails;
    /// when absent the notifications loop stops instead of reconnecting.
    reconnect_interval: Option<u64>,
    /// Forwarded to `MqttOptions::set_inflight` when present.
    outgoing_message_queue_size: Option<usize>,
    /// Forwarded to `MqttOptions::set_notification_channel_capacity` when present.
    incoming_message_queue_size: Option<usize>,
    /// MQTT password; an empty password is sent when absent.
    password: Option<String>,
    /// Forwarded to `MqttOptions::set_max_packet_size` when present.
    max_message_size: Option<usize>,
    /// Capacity of the bounded request channel between `Agent` and the MQTT event loop.
    /// Filled from `DEFAULT_MQTT_REQUESTS_CHAN_SIZE` when the key is missing.
    #[serde(default = "default_mqtt_requests_chan_size")]
    requests_channel_size: Option<usize>,
}
/// Serde default for `AgentConfig::requests_channel_size`; returns
/// `DEFAULT_MQTT_REQUESTS_CHAN_SIZE`.
fn default_mqtt_requests_chan_size() -> Option<usize> {
    DEFAULT_MQTT_REQUESTS_CHAN_SIZE
}
impl AgentConfig {
    /// Sets the MQTT password, replacing any configured one, and returns `self`
    /// so calls can be chained.
    pub fn set_password(&mut self, value: &str) -> &mut Self {
        let password = String::from(value);
        self.password = Some(password);
        self
    }
}
/// Builder that configures the MQTT connection of an agent and starts it.
#[derive(Debug)]
pub struct AgentBuilder {
    /// Connection identity (agent id, version, mode) used to build MQTT credentials.
    connection: Connection,
    /// API version stored in the resulting `Agent`'s `Address`.
    api_version: String,
}
impl AgentBuilder {
pub fn new(agent_id: AgentId, api_version: &str) -> Self {
Self {
connection: Connection::new(agent_id),
api_version: api_version.to_owned(),
}
}
pub fn connection_version(self, version: &str) -> Self {
let mut connection = self.connection;
connection.set_version(version);
Self { connection, ..self }
}
pub fn connection_mode(self, mode: ConnectionMode) -> Self {
let mut connection = self.connection;
connection.set_mode(mode);
Self { connection, ..self }
}
pub fn start(
self,
config: &AgentConfig,
) -> Result<(Agent, crossbeam_channel::Receiver<AgentNotification>), Error> {
self.start_with_runtime(config, TOKIO.handle().clone())
}
pub fn start_with_runtime(
self,
config: &AgentConfig,
rt_handle: tokio::runtime::Handle,
) -> Result<(Agent, crossbeam_channel::Receiver<AgentNotification>), Error> {
let options = Self::mqtt_options(&self.connection, &config)?;
let (mqtt_tx, mqtt_rx) = futures_channel::mpsc::channel::<Request>(
config
.requests_channel_size
.expect("requests_channel_size is not specified"),
);
let mut eventloop = rumq_client::eventloop(options, mqtt_rx);
let reconnect_interval = config.reconnect_interval.to_owned();
let (tx, rx) = crossbeam_channel::unbounded::<AgentNotification>();
#[cfg(feature = "queue-counter")]
let queue_counter = QueueCounterHandle::start();
#[cfg(feature = "queue-counter")]
let queue_counter_ = queue_counter.clone();
std::thread::Builder::new()
.name("svc-agent-notifications-loop".to_owned())
.spawn(move || {
let mut initial_connect = true;
loop {
let tx = tx.clone();
let connect_fut = eventloop.connect();
#[cfg(feature = "queue-counter")]
let queue_counter_ = queue_counter_.clone();
rt_handle.block_on(async {
match connect_fut.await {
Err(err) => error!("Error connecting to broker: {:?}", err),
Ok(mut stream) => {
if initial_connect {
info!("Doing initial connection");
initial_connect = false;
} else {
info!("Was connected before, reconnecting");
if let Err(e) = tx.send(AgentNotification::Reconnection) {
error!("Failed to notify about reconnection: {}", e);
}
}
while let Some(message) = stream.next().await {
if let Notification::Publish(ref content) = message {
info!("Incoming message = '{:?}'", content);
} else {
debug!("Incoming item = {:?}", message);
}
let mut msg: AgentNotification = message.into();
if let AgentNotification::Message(Ok(ref mut content), _) = msg
{
if let IncomingMessage::Request(req) = content {
let method = req.properties().method().to_owned();
req.properties_mut().set_method(&method);
}
#[cfg(feature = "queue-counter")]
queue_counter_.add_incoming_message(content);
}
if let Err(e) = tx.send(msg) {
error!("Failed to transmit message, reason = {}", e);
};
}
if let Err(e) = tx.send(AgentNotification::Disconnection) {
error!("Failed to notify about disconnection: {}", e);
}
}
}
});
match reconnect_interval {
Some(value) => std::thread::sleep(std::time::Duration::from_secs(value)),
None => break,
}
}
})
.map_err(|e| {
Error::new(&format!("Failed starting notifications loop thread, {}", e))
})?;
let agent = Agent::new(
self.connection.agent_id,
&self.api_version,
mqtt_tx,
#[cfg(feature = "queue-counter")]
queue_counter,
);
Ok((agent, rx))
}
fn mqtt_options(connection: &Connection, config: &AgentConfig) -> Result<MqttOptions, Error> {
let uri = config
.uri
.parse::<http::Uri>()
.map_err(|e| Error::new(&format!("error parsing MQTT connection URL, {}", e)))?;
let host = uri.host().ok_or_else(|| Error::new("missing MQTT host"))?;
let port = uri
.port_part()
.ok_or_else(|| Error::new("missing MQTT port"))?;
let username = format!("{}::{}", connection.version, connection.mode);
let password = config
.password
.to_owned()
.unwrap_or_else(|| String::from(""));
let mut opts = MqttOptions::new(connection.agent_id.to_string(), host, port.as_u16());
opts.set_credentials(username, password);
if let Some(value) = config.clean_session {
opts.set_clean_session(value);
}
if let Some(value) = config.keep_alive_interval {
opts.set_keep_alive(value as u16);
}
if let Some(value) = config.incoming_message_queue_size {
opts.set_notification_channel_capacity(value);
}
if let Some(value) = config.outgoing_message_queue_size {
opts.set_inflight(value);
}
if let Some(value) = config.max_message_size {
opts.set_max_packet_size(value);
};
Ok(opts)
}
}
/// Address of an agent: its id together with the API version it speaks.
#[derive(Clone, Debug)]
pub struct Address {
    /// Agent identifier.
    id: AgentId,
    /// API version string.
    version: String,
}
impl Address {
    /// Builds the address of agent `id` speaking API `version`.
    pub fn new(id: AgentId, version: &str) -> Self {
        let version = String::from(version);
        Self { id, version }
    }

    /// Returns the agent id part of the address.
    pub fn id(&self) -> &AgentId {
        &self.id
    }

    /// Returns the API version part of the address.
    pub fn version(&self) -> &str {
        self.version.as_str()
    }
}
/// Handle used to publish messages and create subscriptions over MQTT.
///
/// Requests are pushed into a bounded channel consumed by the event loop started
/// in `AgentBuilder::start_with_runtime`.
#[derive(Clone)]
pub struct Agent {
    /// Address (id + API version) this agent publishes from.
    address: Address,
    /// Sending half of the request channel to the MQTT event loop.
    tx: Sender<Request>,
    /// Counter of incoming/outgoing messages (only with the `queue-counter` feature).
    #[cfg(feature = "queue-counter")]
    queue_counter: QueueCounterHandle,
}
impl Agent {
#[cfg(feature = "queue-counter")]
fn new(
id: AgentId,
api_version: &str,
tx: Sender<Request>,
queue_counter: QueueCounterHandle,
) -> Self {
Self {
address: Address::new(id, api_version),
tx,
queue_counter,
}
}
#[cfg(not(feature = "queue-counter"))]
fn new(id: AgentId, api_version: &str, tx: Sender<Request>) -> Self {
Self {
address: Address::new(id, api_version),
tx,
}
}
pub fn address(&self) -> &Address {
&self.address
}
pub fn id(&self) -> &AgentId {
&self.address.id()
}
pub fn publish<T: serde::Serialize>(
&mut self,
message: OutgoingMessage<T>,
) -> Result<(), Error> {
let dump = Box::new(message).into_dump(&self.address)?;
self.publish_dump(dump)
}
pub fn publish_publishable(
&mut self,
message: Box<dyn IntoPublishableMessage>,
) -> Result<(), Error> {
let dump = message.into_dump(&self.address)?;
self.publish_dump(dump)
}
fn publish_dump(&mut self, dump: PublishableMessage) -> Result<(), Error> {
#[cfg(feature = "queue-counter")]
self.queue_counter.add_outgoing_message(&dump);
let dump = match dump {
PublishableMessage::Event(dump) => dump,
PublishableMessage::Request(dump) => dump,
PublishableMessage::Response(dump) => dump,
};
info!(
"Outgoing message = '{}' sending to the topic = '{}'",
dump.payload(),
dump.topic(),
);
let publish = Publish::new(dump.topic(), dump.qos(), dump.payload());
self.tx.try_send(Request::Publish(publish)).map_err(|e| {
if e.is_full() {
error!(
"Rumq Requests channel reached maximum capacity, no space to publish, {:?}",
&e
)
}
Error::new(&format!("error publishing MQTT message, {}", &e))
})
}
pub fn subscribe<S>(
&mut self,
subscription: &S,
qos: QoS,
maybe_group: Option<&SharedGroup>,
) -> Result<(), Error>
where
S: SubscriptionTopic,
{
let mut topic = subscription.subscription_topic(self.id(), self.address.version())?;
if let Some(ref group) = maybe_group {
topic = format!("$share/{group}/{topic}", group = group, topic = topic);
};
self.tx
.try_send(Request::Subscribe(Subscribe::new(topic, qos)))
.map_err(|e| Error::new(&format!("error creating MQTT subscription, {}", e)))?;
Ok(())
}
#[cfg(feature = "queue-counter")]
pub fn get_queue_counter(&self) -> QueueCounterHandle {
self.queue_counter.clone()
}
}
/// Mode of an MQTT connection; its lowercase name is part of the MQTT username.
#[derive(Debug, Clone)]
pub enum ConnectionMode {
    /// Rendered as `"default"`.
    Default,
    /// Rendered as `"service"`.
    Service,
    /// Rendered as `"observer"`.
    Observer,
    /// Rendered as `"bridge"`.
    Bridge,
}

impl fmt::Display for ConnectionMode {
    /// Writes the lowercase name of the mode (ignoring width/fill flags).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match *self {
            Self::Default => "default",
            Self::Service => "service",
            Self::Observer => "observer",
            Self::Bridge => "bridge",
        };
        f.write_str(label)
    }
}
impl FromStr for ConnectionMode {
type Err = Error;
fn from_str(val: &str) -> Result<Self, Self::Err> {
match val {
"default" => Ok(ConnectionMode::Default),
"service" => Ok(ConnectionMode::Service),
"observer" => Ok(ConnectionMode::Observer),
"bridge" => Ok(ConnectionMode::Bridge),
_ => Err(Error::new(&format!(
"invalid value for the connection mode: {}",
val
))),
}
}
}
/// Identity of an MQTT connection: agent id, connection version and mode.
///
/// Displayed as `version/mode/agent_id` and parsed back from that form by `FromStr`.
#[derive(Debug, Clone)]
pub struct Connection {
    /// Agent that owns the connection (also used as the MQTT client id).
    agent_id: AgentId,
    /// Connection version; `"v2"` unless overridden.
    version: String,
    /// Connection mode; `ConnectionMode::Default` unless overridden.
    mode: ConnectionMode,
}
impl Connection {
    /// Creates a connection for `agent_id` with version `"v2"` and the default mode.
    fn new(agent_id: AgentId) -> Self {
        let version = "v2".to_owned();
        let mode = ConnectionMode::Default;
        Self {
            agent_id,
            version,
            mode,
        }
    }

    /// Replaces the connection version; returns `self` for chaining.
    fn set_version(&mut self, value: &str) -> &mut Self {
        self.version.clear();
        self.version.push_str(value);
        self
    }

    /// Replaces the connection mode; returns `self` for chaining.
    fn set_mode(&mut self, value: ConnectionMode) -> &mut Self {
        self.mode = value;
        self
    }

    /// Returns the agent owning this connection.
    pub fn agent_id(&self) -> &AgentId {
        &self.agent_id
    }

    /// Returns the connection version.
    pub fn version(&self) -> &str {
        self.version.as_str()
    }

    /// Returns the connection mode.
    pub fn mode(&self) -> &ConnectionMode {
        &self.mode
    }
}
impl fmt::Display for Connection {
    /// Writes the connection as `version/mode/agent_id`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self {
            agent_id,
            version,
            mode,
        } = self;
        write!(f, "{}/{}/{}", version, mode, agent_id)
    }
}
impl FromStr for Connection {
type Err = Error;
fn from_str(val: &str) -> Result<Self, Self::Err> {
match val.split('/').collect::<Vec<&str>>().as_slice() {
[version_str, mode_str, agent_id_str] => {
let version = (*version_str).to_string();
let mode = ConnectionMode::from_str(mode_str)?;
let agent_id = AgentId::from_str(agent_id_str)?;
Ok(Self {
version,
mode,
agent_id,
})
}
_ => Err(Error::new(&format!(
"invalid value for connection: {}",
val
))),
}
}
}
/// Serializable connection identity; `version` and `mode` are (de)serialized under
/// the keys `connection_version` and `connection_mode`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ConnectionProperties {
    /// Agent that owns the connection.
    agent_id: AgentId,
    /// Connection version.
    #[serde(rename = "connection_version")]
    version: String,
    /// Connection mode.
    #[serde(rename = "connection_mode")]
    mode: ConnectionMode,
}
impl ConnectionProperties {
    /// Builds an equivalent `Connection`, cloning the agent id and mode.
    pub(crate) fn to_connection(&self) -> Connection {
        let mut connection = Connection::new(self.agent_id.clone());
        connection
            .set_version(&self.version)
            .set_mode(self.mode.clone());
        connection
    }
}
impl Authenticable for ConnectionProperties {
    /// Returns the account id of the connection's agent, as reported by
    /// `AgentId::as_account_id`.
    fn as_account_id(&self) -> &AccountId {
        // `as_account_id` already returns a reference; borrowing it again only
        // produced a `&&AccountId` that had to be auto-dereferenced.
        self.agent_id.as_account_id()
    }
}

impl Addressable for ConnectionProperties {
    /// Returns the agent id of the connection.
    fn as_agent_id(&self) -> &AgentId {
        &self.agent_id
    }
}
/// Notification delivered to the consumer of the receiver returned by
/// `AgentBuilder::start` / `start_with_runtime`.
#[derive(Debug)]
pub enum AgentNotification {
    /// Incoming publish: the decoded message (or a description of why decoding
    /// failed) plus the publish packet metadata.
    Message(Result<IncomingMessage<String>, String>, MessageData),
    /// Sent when a connection is established after the initial one.
    Reconnection,
    /// Sent when the notification stream of a connection ends.
    Disconnection,
    /// Mirrors `Notification::Puback`.
    Puback(PacketIdentifier),
    /// Mirrors `Notification::Pubrec`.
    Pubrec(PacketIdentifier),
    /// Mirrors `Notification::Pubcomp`.
    Pubcomp(PacketIdentifier),
    /// Mirrors `Notification::Suback`.
    Suback(Suback),
    /// Mirrors `Notification::Unsuback`.
    Unsuback(PacketIdentifier),
    /// Mirrors `Notification::Abort` (event loop error).
    Abort(EventLoopError),
}
/// Metadata copied from an incoming MQTT publish packet.
#[derive(Debug, Clone, PartialEq)]
pub struct MessageData {
    /// Duplicate-delivery flag of the packet.
    pub dup: bool,
    /// QoS the packet was delivered with.
    pub qos: QoS,
    /// Retain flag of the packet.
    pub retain: bool,
    /// Topic the packet was published to.
    pub topic: String,
    /// Packet identifier, if the packet carried one.
    pub pkid: Option<PacketIdentifier>,
}
impl From<Notification> for AgentNotification {
fn from(notification: Notification) -> Self {
match notification {
Notification::Publish(message) => {
let message_data = MessageData {
dup: message.dup,
qos: message.qos,
retain: message.retain,
topic: message.topic_name,
pkid: message.pkid,
};
let env_result =
serde_json::from_slice::<compat::IncomingEnvelope>(&message.payload)
.map_err(|err| format!("Failed to parse incoming envelope: {}", err))
.and_then(|env| match env.properties() {
compat::IncomingEnvelopeProperties::Request(_) => {
compat::into_request(env)
.map_err(|e| format!("Failed to convert into request: {}", e))
}
compat::IncomingEnvelopeProperties::Response(_) => {
compat::into_response(env)
.map_err(|e| format!("Failed to convert into response: {}", e))
}
compat::IncomingEnvelopeProperties::Event(_) => compat::into_event(env)
.map_err(|e| format!("Failed to convert into event: {}", e)),
});
Self::Message(env_result, message_data)
}
Notification::Puback(p) => Self::Puback(p),
Notification::Pubrec(p) => Self::Pubrec(p),
Notification::Pubcomp(p) => Self::Pubcomp(p),
Notification::Suback(s) => Self::Suback(s),
Notification::Unsuback(p) => Self::Unsuback(p),
Notification::Abort(err) => Self::Abort(err),
}
}
}