Skip to main content

styrene_mqtt/
client.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use rumqttc::v5::mqttbytes::QoS;
5use rumqttc::v5::{AsyncClient, MqttOptions};
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use tokio::sync::mpsc;
9use tokio::task::JoinHandle;
10
11use crate::envelope::{
12    decode_payload, decode_user_properties, encode_payload, encode_user_properties, Envelope,
13    Message, Metadata,
14};
15use crate::error::{MqttError, Result};
16use crate::qos::QosOverride;
17use crate::stream::{RawMessage, Subscription};
18use crate::topic::{TopicAddress, TopicBuilder};
19
/// Identity of this client on the Aether fabric.
///
/// All three parts are used to build the topics this client publishes to and
/// the metadata attached to every outgoing message.
#[derive(Clone, Debug)]
pub struct ServiceIdentity {
    /// Operator the service belongs to (first identity segment of published topics).
    pub operator_id: String,
    /// Service name; combined with `instance_id` to form the default MQTT client ID.
    pub service: String,
    /// Identifier of this particular running instance of the service.
    pub instance_id: String,
}
27
/// Where a [`Client`] connects to reach the MQTT broker.
pub enum ConnectionTarget {
    /// In-process link to an embedded rumqttd broker (Tier 1).
    ///
    /// Only available when the `embedded-broker` feature is enabled.
    #[cfg(feature = "embedded-broker")]
    InProcess {
        /// Broker link whose `tx`/`rx` halves the client takes over.
        link: crate::broker::BrokerLink,
    },
    /// Remote MQTT 5.0 broker reached over TCP.
    Remote {
        /// Broker host, handed to rumqttc as-is.
        host: String,
        /// Broker TCP port.
        port: u16,
    },
}
37
38/// Configuration for creating a [`Client`].
39pub struct ClientConfig {
40    pub identity: ServiceIdentity,
41    pub target: ConnectionTarget,
42    /// MQTT client ID. Defaults to `"{service}-{instance_id}"`.
43    pub client_id: Option<String>,
44    /// Channel capacity for the internal event loop. Default: 128.
45    pub channel_capacity: usize,
46    /// Keep-alive interval (remote only). Default: 30s.
47    pub keep_alive: Duration,
48}
49
50impl ClientConfig {
51    pub fn new(identity: ServiceIdentity, target: ConnectionTarget) -> Self {
52        Self {
53            identity,
54            target,
55            client_id: None,
56            channel_capacity: 128,
57            keep_alive: Duration::from_secs(30),
58        }
59    }
60}
61
62/// High-level Aether MQTT 5.0 client.
63///
64/// Publishes typed events and subscribes to topic patterns with automatic
65/// deserialization. Supports both in-process (embedded broker) and remote
66/// (TCP) connections.
67pub struct Client {
68    identity: ServiceIdentity,
69    inner: ClientInner,
70    raw_subscribers: Arc<tokio::sync::Mutex<Vec<FilteredSubscriber>>>,
71}
72
73/// A subscriber with its topic filter for fan-out matching.
74struct FilteredSubscriber {
75    filter: String,
76    tx: mpsc::Sender<RawMessage>,
77}
78
79enum ClientInner {
80    #[cfg(feature = "embedded-broker")]
81    InProcess {
82        link_tx: Arc<tokio::sync::Mutex<rumqttd::local::LinkTx>>,
83        _recv_task: JoinHandle<()>,
84    },
85    Remote {
86        mqtt: AsyncClient,
87        _event_loop: JoinHandle<()>,
88    },
89}
90
91impl Client {
92    /// Connect to the broker and start the internal event loop.
93    pub async fn connect(config: ClientConfig) -> Result<Self> {
94        let raw_subscribers: Arc<tokio::sync::Mutex<Vec<FilteredSubscriber>>> =
95            Arc::new(tokio::sync::Mutex::new(Vec::new()));
96
97        let inner = match config.target {
98            #[cfg(feature = "embedded-broker")]
99            ConnectionTarget::InProcess { link } => {
100                Self::connect_in_process(link, raw_subscribers.clone())
101            }
102            ConnectionTarget::Remote { host, port } => {
103                Self::connect_remote(
104                    &config.identity,
105                    config.client_id.as_deref(),
106                    &host,
107                    port,
108                    config.channel_capacity,
109                    config.keep_alive,
110                    raw_subscribers.clone(),
111                )
112                .await?
113            }
114        };
115
116        Ok(Self { identity: config.identity, inner, raw_subscribers })
117    }
118
119    /// Publish a typed event.
120    ///
121    /// Topic is built from the client's identity + `event_type`.
122    /// QoS is determined by policy unless overridden.
123    pub async fn publish<T: Serialize>(
124        &self,
125        event_type: &str,
126        payload: &T,
127        qos_override: QosOverride,
128    ) -> Result<()> {
129        let topic = TopicBuilder::new()
130            .operator(&self.identity.operator_id)
131            .service(&self.identity.service)
132            .instance(&self.identity.instance_id)
133            .event_type(event_type)
134            .build_publish()?;
135
136        self.publish_inner(&topic, event_type, payload, qos_override, false).await
137    }
138
139    /// Publish to an explicit topic address (for relay/proxy scenarios).
140    pub async fn publish_to<T: Serialize>(
141        &self,
142        address: &TopicAddress,
143        payload: &T,
144        qos_override: QosOverride,
145    ) -> Result<()> {
146        let topic = address.to_topic_string();
147        self.publish_inner(&topic, &address.event_type, payload, qos_override, false).await
148    }
149
150    /// Publish a retained message (late-join state snapshot).
151    pub async fn publish_retained<T: Serialize>(
152        &self,
153        event_type: &str,
154        payload: &T,
155    ) -> Result<()> {
156        let topic = TopicBuilder::new()
157            .operator(&self.identity.operator_id)
158            .service(&self.identity.service)
159            .instance(&self.identity.instance_id)
160            .event_type(event_type)
161            .build_publish()?;
162
163        self.publish_inner(&topic, event_type, payload, QosOverride::Force(QoS::AtLeastOnce), true)
164            .await
165    }
166
167    /// Subscribe to a topic filter and return a typed stream.
168    ///
169    /// Messages that fail deserialization are logged and skipped.
170    pub async fn subscribe<T: DeserializeOwned + Send + 'static>(
171        &self,
172        filter: &str,
173        qos: QoS,
174    ) -> Result<Subscription<T>> {
175        self.subscribe_mqtt(filter, qos).await?;
176
177        let (raw_tx, mut raw_rx) = mpsc::channel::<RawMessage>(256);
178        let (typed_tx, typed_rx) = mpsc::channel::<Result<Message<T>>>(256);
179
180        {
181            let mut guard = self.raw_subscribers.lock().await;
182            guard.push(FilteredSubscriber { filter: filter.to_string(), tx: raw_tx });
183        }
184
185        tokio::spawn(async move {
186            while let Some(raw) = raw_rx.recv().await {
187                let address = match TopicAddress::parse(&raw.topic) {
188                    Ok(a) => a,
189                    Err(e) => {
190                        tracing::debug!("skipping message on `{}`: {e}", raw.topic);
191                        continue;
192                    }
193                };
194
195                let meta = match decode_user_properties(&raw.user_properties) {
196                    Ok(m) => m,
197                    Err(e) => {
198                        tracing::warn!("metadata decode failed on `{}`: {e}", raw.topic);
199                        continue;
200                    }
201                };
202
203                let payload: T = match decode_payload(&raw.payload, &raw.topic) {
204                    Ok(p) => p,
205                    Err(e) => {
206                        tracing::warn!("payload decode failed on `{}`: {e}", raw.topic);
207                        continue;
208                    }
209                };
210
211                let msg = Message {
212                    envelope: Envelope { meta, payload },
213                    address,
214                    qos: raw.qos,
215                    retained: raw.retained,
216                };
217
218                if typed_tx.send(Ok(msg)).await.is_err() {
219                    break;
220                }
221            }
222        });
223
224        Ok(Subscription::new(typed_rx))
225    }
226
227    /// Subscribe with raw MQTT messages (no deserialization).
228    pub async fn subscribe_raw(
229        &self,
230        filter: &str,
231        qos: QoS,
232    ) -> Result<mpsc::Receiver<RawMessage>> {
233        self.subscribe_mqtt(filter, qos).await?;
234
235        let (tx, rx) = mpsc::channel::<RawMessage>(256);
236        {
237            let mut guard = self.raw_subscribers.lock().await;
238            guard.push(FilteredSubscriber { filter: filter.to_string(), tx });
239        }
240        Ok(rx)
241    }
242
243    /// Disconnect from the broker.
244    pub async fn disconnect(self) -> Result<()> {
245        match self.inner {
246            #[cfg(feature = "embedded-broker")]
247            ClientInner::InProcess { _recv_task, .. } => {
248                _recv_task.abort();
249            }
250            ClientInner::Remote { mqtt, _event_loop } => {
251                let _ = mqtt.disconnect().await;
252                _event_loop.abort();
253            }
254        }
255        Ok(())
256    }
257
258    // ── Connection constructors ─────────────────────────────────────────
259
260    #[cfg(feature = "embedded-broker")]
261    fn connect_in_process(
262        link: crate::broker::BrokerLink,
263        subs: Arc<tokio::sync::Mutex<Vec<FilteredSubscriber>>>,
264    ) -> ClientInner {
265        let link_tx = Arc::new(tokio::sync::Mutex::new(link.tx));
266        let mut link_rx = link.rx;
267
268        let recv_task = tokio::spawn(async move {
269            loop {
270                match link_rx.next().await {
271                    Ok(Some(notification)) => {
272                        if let rumqttd::Notification::Forward(fwd) = notification {
273                            let raw = RawMessage {
274                                topic: String::from_utf8_lossy(&fwd.publish.topic).to_string(),
275                                payload: fwd.publish.payload.to_vec(),
276                                qos: 0, // In-process links are QoS 0
277                                retained: fwd.publish.retain,
278                                user_properties: fwd
279                                    .properties
280                                    .as_ref()
281                                    .map(|p| {
282                                        p.user_properties
283                                            .iter()
284                                            .map(|(k, v)| (k.clone(), v.clone()))
285                                            .collect()
286                                    })
287                                    .unwrap_or_default(),
288                            };
289
290                            let guard = subs.lock().await;
291                            for sub in guard.iter() {
292                                if topic_matches_filter(&raw.topic, &sub.filter) {
293                                    let _ = sub.tx.try_send(raw.clone());
294                                }
295                            }
296                        }
297                    }
298                    Ok(None) => {}
299                    Err(e) => {
300                        tracing::warn!("in-process link recv error: {e}");
301                        tokio::time::sleep(Duration::from_millis(10)).await;
302                    }
303                }
304            }
305        });
306
307        ClientInner::InProcess { link_tx, _recv_task: recv_task }
308    }
309
310    async fn connect_remote(
311        identity: &ServiceIdentity,
312        client_id: Option<&str>,
313        host: &str,
314        port: u16,
315        channel_capacity: usize,
316        keep_alive: Duration,
317        subs: Arc<tokio::sync::Mutex<Vec<FilteredSubscriber>>>,
318    ) -> Result<ClientInner> {
319        let id = client_id
320            .map(str::to_owned)
321            .unwrap_or_else(|| format!("{}-{}", identity.service, identity.instance_id));
322
323        let mut opts = MqttOptions::new(&id, host, port);
324        opts.set_keep_alive(keep_alive);
325
326        let (mqtt, mut event_loop) = AsyncClient::new(opts, channel_capacity);
327
328        let loop_handle = tokio::spawn(async move {
329            use rumqttc::v5::mqttbytes::v5::Packet;
330            loop {
331                match event_loop.poll().await {
332                    Ok(rumqttc::v5::Event::Incoming(Packet::Publish(publish))) => {
333                        let raw = RawMessage {
334                            topic: String::from_utf8_lossy(&publish.topic).to_string(),
335                            payload: publish.payload.to_vec(),
336                            qos: match publish.qos {
337                                QoS::AtMostOnce => 0,
338                                QoS::AtLeastOnce => 1,
339                                QoS::ExactlyOnce => 2,
340                            },
341                            retained: publish.retain,
342                            user_properties: publish
343                                .properties
344                                .as_ref()
345                                .map(|p| {
346                                    p.user_properties
347                                        .iter()
348                                        .map(|(k, v)| (k.clone(), v.clone()))
349                                        .collect()
350                                })
351                                .unwrap_or_default(),
352                        };
353
354                        let guard = subs.lock().await;
355                        for sub in guard.iter() {
356                            if topic_matches_filter(&raw.topic, &sub.filter) {
357                                let _ = sub.tx.try_send(raw.clone());
358                            }
359                        }
360                    }
361                    Ok(_) => {}
362                    Err(e) => {
363                        tracing::warn!("MQTT event loop error: {e}");
364                        tokio::time::sleep(Duration::from_secs(1)).await;
365                    }
366                }
367            }
368        });
369
370        Ok(ClientInner::Remote { mqtt, _event_loop: loop_handle })
371    }
372
373    // ── Internal helpers ────────────────────────────────────────────────
374
375    fn build_metadata(&self) -> Metadata {
376        Metadata {
377            timestamp: chrono::Utc::now(),
378            source_service: self.identity.service.clone(),
379            source_instance: self.identity.instance_id.clone(),
380            operator_id: self.identity.operator_id.clone(),
381            schema_version: 1,
382            correlation_id: None,
383        }
384    }
385
386    async fn publish_inner<T: Serialize>(
387        &self,
388        topic: &str,
389        event_type: &str,
390        payload: &T,
391        qos_override: QosOverride,
392        retain: bool,
393    ) -> Result<()> {
394        let bytes = encode_payload(payload)?;
395        let meta = self.build_metadata();
396        let user_props = encode_user_properties(&meta);
397
398        match &self.inner {
399            #[cfg(feature = "embedded-broker")]
400            ClientInner::InProcess { link_tx, .. } => {
401                use bytes::Bytes;
402                use rumqttd::protocol::{Packet, Publish, PublishProperties};
403
404                let publish = Publish::new(
405                    Bytes::copy_from_slice(topic.as_bytes()),
406                    Bytes::from(bytes),
407                    retain,
408                );
409                let properties =
410                    PublishProperties { user_properties: user_props, ..Default::default() };
411
412                let mut tx = link_tx.lock().await;
413                tx.send(Packet::Publish(publish, Some(properties)))
414                    .await
415                    .map_err(|e| MqttError::Publish(e.to_string()))?;
416                Ok(())
417            }
418            ClientInner::Remote { mqtt, .. } => {
419                let qos = qos_override.resolve(event_type);
420                let properties = rumqttc::v5::mqttbytes::v5::PublishProperties {
421                    user_properties: user_props,
422                    ..Default::default()
423                };
424                mqtt.publish_with_properties(topic, qos, retain, bytes, properties)
425                    .await
426                    .map_err(|e| MqttError::Publish(e.to_string()))?;
427                Ok(())
428            }
429        }
430    }
431
432    async fn subscribe_mqtt(&self, filter: &str, qos: QoS) -> Result<()> {
433        match &self.inner {
434            #[cfg(feature = "embedded-broker")]
435            ClientInner::InProcess { link_tx, .. } => {
436                let mut tx = link_tx.lock().await;
437                tx.subscribe(filter).map_err(|e| MqttError::Subscribe(e.to_string()))?;
438                Ok(())
439            }
440            ClientInner::Remote { mqtt, .. } => {
441                mqtt.subscribe(filter, qos)
442                    .await
443                    .map_err(|e| MqttError::Subscribe(e.to_string()))?;
444                Ok(())
445            }
446        }
447    }
448}
449
/// MQTT topic filter matching (MQTT 5.0 §4.7).
///
/// - `+` matches exactly one topic level.
/// - `#` matches zero or more trailing levels (must be last segment),
///   including the parent level itself (`a/#` matches `a`).
/// - A filter that *starts* with a wildcard never matches a topic beginning
///   with `$` (e.g. `$SYS/...`), per §4.7.2.
/// - Shared-subscription filters `$share/{ShareName}/{filter}` (§4.8.2) are
///   matched using the inner `{filter}`, because the broker delivers those
///   messages under their original topic name.
///
/// Returns `false` for a `$share/` filter with no inner filter part.
fn topic_matches_filter(topic: &str, filter: &str) -> bool {
    let filter = match filter.strip_prefix("$share/") {
        Some(rest) => match rest.split_once('/') {
            Some((_share_name, inner)) => inner,
            None => return false,
        },
        None => filter,
    };

    // §4.7.2: leading `+`/`#` must not match server-reserved `$` topics.
    let leading_wildcard = matches!(filter.as_bytes().first(), Some(b'+' | b'#'));
    if leading_wildcard && topic.starts_with('$') {
        return false;
    }

    let mut topic_levels = topic.split('/');
    let mut filter_levels = filter.split('/');

    loop {
        match (filter_levels.next(), topic_levels.next()) {
            (Some("#"), _) => return true,
            (Some("+"), Some(_)) => {}
            (Some(f), Some(t)) if f == t => {}
            (None, None) => return true,
            _ => return false,
        }
    }
}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs every `(topic, filter, expected)` case, naming the failing pair.
    fn assert_cases(cases: &[(&str, &str, bool)]) {
        for &(topic, filter, expected) in cases {
            assert_eq!(
                topic_matches_filter(topic, filter),
                expected,
                "topic `{topic}` against filter `{filter}`"
            );
        }
    }

    #[test]
    fn topic_filter_exact_match() {
        assert_cases(&[("a/b/c", "a/b/c", true), ("a/b/c", "a/b/d", false)]);
    }

    #[test]
    fn topic_filter_single_level_wildcard() {
        assert_cases(&[
            ("a/b/c", "a/+/c", true),
            ("a/x/c", "a/+/c", true),
            ("a/b/c/d", "a/+/c", false),
        ]);
    }

    #[test]
    fn topic_filter_multi_level_wildcard() {
        assert_cases(&[
            ("a/b/c", "a/#", true),
            ("a/b/c/d", "a/#", true),
            ("a", "a/#", true),
            ("b/c", "a/#", false),
        ]);
    }

    #[test]
    fn topic_filter_combined_wildcards() {
        let filter = "styrene/op1/omegon/+/events/#";
        assert_cases(&[
            ("styrene/op1/omegon/inst-a/events/turn.started", filter, true),
            ("styrene/op1/viz/inst-a/events/turn.started", filter, false),
        ]);
    }

    #[test]
    fn topic_filter_length_mismatch() {
        assert_cases(&[("a/b", "a/b/c", false), ("a/b/c", "a/b", false)]);
    }
}