1use std::sync::Arc;
2use std::time::Duration;
3
4use rumqttc::v5::mqttbytes::QoS;
5use rumqttc::v5::{AsyncClient, MqttOptions};
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use tokio::sync::mpsc;
9use tokio::task::JoinHandle;
10
11use crate::envelope::{
12 decode_payload, decode_user_properties, encode_payload, encode_user_properties, Envelope,
13 Message, Metadata,
14};
15use crate::error::{MqttError, Result};
16use crate::qos::QosOverride;
17use crate::stream::{RawMessage, Subscription};
18use crate::topic::{TopicAddress, TopicBuilder};
19
/// Identifies the service instance a client acts on behalf of.
///
/// All three parts become levels of every topic the client publishes to on its own
/// behalf, and they are stamped into the metadata of each outgoing message.
///
/// The type is a plain value: it can be compared and hashed, so identities can be
/// used as map keys or deduplicated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServiceIdentity {
    /// Operator the service runs under; the operator level of published topics.
    pub operator_id: String,
    /// Logical service name; the service level of published topics.
    pub service: String,
    /// Identifier of this particular instance; the instance level of published topics.
    pub instance_id: String,
}
27
/// Where a client should connect.
pub enum ConnectionTarget {
    /// A broker embedded in this process, reached through a local broker link.
    #[cfg(feature = "embedded-broker")]
    InProcess { link: crate::broker::BrokerLink },
    /// A broker reached over the network at `host:port`.
    Remote { host: String, port: u16 },
}

// `Debug` is implemented by hand because the embedded broker link is an opaque
// handle that may not itself implement `Debug`; only its presence is reported.
impl std::fmt::Debug for ConnectionTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "embedded-broker")]
            Self::InProcess { .. } => f.debug_struct("InProcess").finish_non_exhaustive(),
            Self::Remote { host, port } => {
                f.debug_struct("Remote").field("host", host).field("port", port).finish()
            }
        }
    }
}
37
38pub struct ClientConfig {
40 pub identity: ServiceIdentity,
41 pub target: ConnectionTarget,
42 pub client_id: Option<String>,
44 pub channel_capacity: usize,
46 pub keep_alive: Duration,
48}
49
50impl ClientConfig {
51 pub fn new(identity: ServiceIdentity, target: ConnectionTarget) -> Self {
52 Self {
53 identity,
54 target,
55 client_id: None,
56 channel_capacity: 128,
57 keep_alive: Duration::from_secs(30),
58 }
59 }
60}
61
62pub struct Client {
68 identity: ServiceIdentity,
69 inner: ClientInner,
70 raw_subscribers: Arc<tokio::sync::Mutex<Vec<FilteredSubscriber>>>,
71}
72
73struct FilteredSubscriber {
75 filter: String,
76 tx: mpsc::Sender<RawMessage>,
77}
78
79enum ClientInner {
80 #[cfg(feature = "embedded-broker")]
81 InProcess {
82 link_tx: Arc<tokio::sync::Mutex<rumqttd::local::LinkTx>>,
83 _recv_task: JoinHandle<()>,
84 },
85 Remote {
86 mqtt: AsyncClient,
87 _event_loop: JoinHandle<()>,
88 },
89}
90
91impl Client {
92 pub async fn connect(config: ClientConfig) -> Result<Self> {
94 let raw_subscribers: Arc<tokio::sync::Mutex<Vec<FilteredSubscriber>>> =
95 Arc::new(tokio::sync::Mutex::new(Vec::new()));
96
97 let inner = match config.target {
98 #[cfg(feature = "embedded-broker")]
99 ConnectionTarget::InProcess { link } => {
100 Self::connect_in_process(link, raw_subscribers.clone())
101 }
102 ConnectionTarget::Remote { host, port } => {
103 Self::connect_remote(
104 &config.identity,
105 config.client_id.as_deref(),
106 &host,
107 port,
108 config.channel_capacity,
109 config.keep_alive,
110 raw_subscribers.clone(),
111 )
112 .await?
113 }
114 };
115
116 Ok(Self { identity: config.identity, inner, raw_subscribers })
117 }
118
119 pub async fn publish<T: Serialize>(
124 &self,
125 event_type: &str,
126 payload: &T,
127 qos_override: QosOverride,
128 ) -> Result<()> {
129 let topic = TopicBuilder::new()
130 .operator(&self.identity.operator_id)
131 .service(&self.identity.service)
132 .instance(&self.identity.instance_id)
133 .event_type(event_type)
134 .build_publish()?;
135
136 self.publish_inner(&topic, event_type, payload, qos_override, false).await
137 }
138
139 pub async fn publish_to<T: Serialize>(
141 &self,
142 address: &TopicAddress,
143 payload: &T,
144 qos_override: QosOverride,
145 ) -> Result<()> {
146 let topic = address.to_topic_string();
147 self.publish_inner(&topic, &address.event_type, payload, qos_override, false).await
148 }
149
150 pub async fn publish_retained<T: Serialize>(
152 &self,
153 event_type: &str,
154 payload: &T,
155 ) -> Result<()> {
156 let topic = TopicBuilder::new()
157 .operator(&self.identity.operator_id)
158 .service(&self.identity.service)
159 .instance(&self.identity.instance_id)
160 .event_type(event_type)
161 .build_publish()?;
162
163 self.publish_inner(&topic, event_type, payload, QosOverride::Force(QoS::AtLeastOnce), true)
164 .await
165 }
166
167 pub async fn subscribe<T: DeserializeOwned + Send + 'static>(
171 &self,
172 filter: &str,
173 qos: QoS,
174 ) -> Result<Subscription<T>> {
175 self.subscribe_mqtt(filter, qos).await?;
176
177 let (raw_tx, mut raw_rx) = mpsc::channel::<RawMessage>(256);
178 let (typed_tx, typed_rx) = mpsc::channel::<Result<Message<T>>>(256);
179
180 {
181 let mut guard = self.raw_subscribers.lock().await;
182 guard.push(FilteredSubscriber { filter: filter.to_string(), tx: raw_tx });
183 }
184
185 tokio::spawn(async move {
186 while let Some(raw) = raw_rx.recv().await {
187 let address = match TopicAddress::parse(&raw.topic) {
188 Ok(a) => a,
189 Err(e) => {
190 tracing::debug!("skipping message on `{}`: {e}", raw.topic);
191 continue;
192 }
193 };
194
195 let meta = match decode_user_properties(&raw.user_properties) {
196 Ok(m) => m,
197 Err(e) => {
198 tracing::warn!("metadata decode failed on `{}`: {e}", raw.topic);
199 continue;
200 }
201 };
202
203 let payload: T = match decode_payload(&raw.payload, &raw.topic) {
204 Ok(p) => p,
205 Err(e) => {
206 tracing::warn!("payload decode failed on `{}`: {e}", raw.topic);
207 continue;
208 }
209 };
210
211 let msg = Message {
212 envelope: Envelope { meta, payload },
213 address,
214 qos: raw.qos,
215 retained: raw.retained,
216 };
217
218 if typed_tx.send(Ok(msg)).await.is_err() {
219 break;
220 }
221 }
222 });
223
224 Ok(Subscription::new(typed_rx))
225 }
226
227 pub async fn subscribe_raw(
229 &self,
230 filter: &str,
231 qos: QoS,
232 ) -> Result<mpsc::Receiver<RawMessage>> {
233 self.subscribe_mqtt(filter, qos).await?;
234
235 let (tx, rx) = mpsc::channel::<RawMessage>(256);
236 {
237 let mut guard = self.raw_subscribers.lock().await;
238 guard.push(FilteredSubscriber { filter: filter.to_string(), tx });
239 }
240 Ok(rx)
241 }
242
243 pub async fn disconnect(self) -> Result<()> {
245 match self.inner {
246 #[cfg(feature = "embedded-broker")]
247 ClientInner::InProcess { _recv_task, .. } => {
248 _recv_task.abort();
249 }
250 ClientInner::Remote { mqtt, _event_loop } => {
251 let _ = mqtt.disconnect().await;
252 _event_loop.abort();
253 }
254 }
255 Ok(())
256 }
257
258 #[cfg(feature = "embedded-broker")]
261 fn connect_in_process(
262 link: crate::broker::BrokerLink,
263 subs: Arc<tokio::sync::Mutex<Vec<FilteredSubscriber>>>,
264 ) -> ClientInner {
265 let link_tx = Arc::new(tokio::sync::Mutex::new(link.tx));
266 let mut link_rx = link.rx;
267
268 let recv_task = tokio::spawn(async move {
269 loop {
270 match link_rx.next().await {
271 Ok(Some(notification)) => {
272 if let rumqttd::Notification::Forward(fwd) = notification {
273 let raw = RawMessage {
274 topic: String::from_utf8_lossy(&fwd.publish.topic).to_string(),
275 payload: fwd.publish.payload.to_vec(),
276 qos: 0, retained: fwd.publish.retain,
278 user_properties: fwd
279 .properties
280 .as_ref()
281 .map(|p| {
282 p.user_properties
283 .iter()
284 .map(|(k, v)| (k.clone(), v.clone()))
285 .collect()
286 })
287 .unwrap_or_default(),
288 };
289
290 let guard = subs.lock().await;
291 for sub in guard.iter() {
292 if topic_matches_filter(&raw.topic, &sub.filter) {
293 let _ = sub.tx.try_send(raw.clone());
294 }
295 }
296 }
297 }
298 Ok(None) => {}
299 Err(e) => {
300 tracing::warn!("in-process link recv error: {e}");
301 tokio::time::sleep(Duration::from_millis(10)).await;
302 }
303 }
304 }
305 });
306
307 ClientInner::InProcess { link_tx, _recv_task: recv_task }
308 }
309
310 async fn connect_remote(
311 identity: &ServiceIdentity,
312 client_id: Option<&str>,
313 host: &str,
314 port: u16,
315 channel_capacity: usize,
316 keep_alive: Duration,
317 subs: Arc<tokio::sync::Mutex<Vec<FilteredSubscriber>>>,
318 ) -> Result<ClientInner> {
319 let id = client_id
320 .map(str::to_owned)
321 .unwrap_or_else(|| format!("{}-{}", identity.service, identity.instance_id));
322
323 let mut opts = MqttOptions::new(&id, host, port);
324 opts.set_keep_alive(keep_alive);
325
326 let (mqtt, mut event_loop) = AsyncClient::new(opts, channel_capacity);
327
328 let loop_handle = tokio::spawn(async move {
329 use rumqttc::v5::mqttbytes::v5::Packet;
330 loop {
331 match event_loop.poll().await {
332 Ok(rumqttc::v5::Event::Incoming(Packet::Publish(publish))) => {
333 let raw = RawMessage {
334 topic: String::from_utf8_lossy(&publish.topic).to_string(),
335 payload: publish.payload.to_vec(),
336 qos: match publish.qos {
337 QoS::AtMostOnce => 0,
338 QoS::AtLeastOnce => 1,
339 QoS::ExactlyOnce => 2,
340 },
341 retained: publish.retain,
342 user_properties: publish
343 .properties
344 .as_ref()
345 .map(|p| {
346 p.user_properties
347 .iter()
348 .map(|(k, v)| (k.clone(), v.clone()))
349 .collect()
350 })
351 .unwrap_or_default(),
352 };
353
354 let guard = subs.lock().await;
355 for sub in guard.iter() {
356 if topic_matches_filter(&raw.topic, &sub.filter) {
357 let _ = sub.tx.try_send(raw.clone());
358 }
359 }
360 }
361 Ok(_) => {}
362 Err(e) => {
363 tracing::warn!("MQTT event loop error: {e}");
364 tokio::time::sleep(Duration::from_secs(1)).await;
365 }
366 }
367 }
368 });
369
370 Ok(ClientInner::Remote { mqtt, _event_loop: loop_handle })
371 }
372
373 fn build_metadata(&self) -> Metadata {
376 Metadata {
377 timestamp: chrono::Utc::now(),
378 source_service: self.identity.service.clone(),
379 source_instance: self.identity.instance_id.clone(),
380 operator_id: self.identity.operator_id.clone(),
381 schema_version: 1,
382 correlation_id: None,
383 }
384 }
385
386 async fn publish_inner<T: Serialize>(
387 &self,
388 topic: &str,
389 event_type: &str,
390 payload: &T,
391 qos_override: QosOverride,
392 retain: bool,
393 ) -> Result<()> {
394 let bytes = encode_payload(payload)?;
395 let meta = self.build_metadata();
396 let user_props = encode_user_properties(&meta);
397
398 match &self.inner {
399 #[cfg(feature = "embedded-broker")]
400 ClientInner::InProcess { link_tx, .. } => {
401 use bytes::Bytes;
402 use rumqttd::protocol::{Packet, Publish, PublishProperties};
403
404 let publish = Publish::new(
405 Bytes::copy_from_slice(topic.as_bytes()),
406 Bytes::from(bytes),
407 retain,
408 );
409 let properties =
410 PublishProperties { user_properties: user_props, ..Default::default() };
411
412 let mut tx = link_tx.lock().await;
413 tx.send(Packet::Publish(publish, Some(properties)))
414 .await
415 .map_err(|e| MqttError::Publish(e.to_string()))?;
416 Ok(())
417 }
418 ClientInner::Remote { mqtt, .. } => {
419 let qos = qos_override.resolve(event_type);
420 let properties = rumqttc::v5::mqttbytes::v5::PublishProperties {
421 user_properties: user_props,
422 ..Default::default()
423 };
424 mqtt.publish_with_properties(topic, qos, retain, bytes, properties)
425 .await
426 .map_err(|e| MqttError::Publish(e.to_string()))?;
427 Ok(())
428 }
429 }
430 }
431
432 async fn subscribe_mqtt(&self, filter: &str, qos: QoS) -> Result<()> {
433 match &self.inner {
434 #[cfg(feature = "embedded-broker")]
435 ClientInner::InProcess { link_tx, .. } => {
436 let mut tx = link_tx.lock().await;
437 tx.subscribe(filter).map_err(|e| MqttError::Subscribe(e.to_string()))?;
438 Ok(())
439 }
440 ClientInner::Remote { mqtt, .. } => {
441 mqtt.subscribe(filter, qos)
442 .await
443 .map_err(|e| MqttError::Subscribe(e.to_string()))?;
444 Ok(())
445 }
446 }
447 }
448}
449
/// Returns `true` if the MQTT topic name `topic` matches the subscription `filter`.
///
/// Follows the MQTT v5 matching rules:
/// - `+` matches exactly one level, including an empty one.
/// - `#` matches the parent level and any number of child levels.
/// - A filter starting with a wildcard (`+` or `#`) never matches a topic starting
///   with `$`, such as `$SYS/...` (MQTT v5 §4.7.2).
/// - A shared-subscription filter `$share/<group>/<filter>` (MQTT v5 §4.8.2) is
///   matched using its inner `<filter>`. The broker delivers such messages under
///   the original topic, so the `$share/<group>/` prefix never appears in it.
fn topic_matches_filter(topic: &str, filter: &str) -> bool {
    let filter = match filter.strip_prefix("$share/").and_then(|rest| rest.split_once('/')) {
        Some((_group, inner)) => inner,
        None => filter,
    };

    if topic.starts_with('$') && (filter.starts_with('+') || filter.starts_with('#')) {
        return false;
    }

    let mut topic_levels = topic.split('/');
    let mut filter_levels = filter.split('/');

    loop {
        match (filter_levels.next(), topic_levels.next()) {
            (Some("#"), _) => return true,
            (Some("+"), Some(_)) => {}
            (Some(f), Some(t)) if f == t => {}
            (None, None) => return true,
            _ => return false,
        }
    }
}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `(topic, filter, expected)` case against `topic_matches_filter`,
    /// naming the offending pair on failure.
    fn check(cases: &[(&str, &str, bool)]) {
        for &(topic, filter, expected) in cases {
            assert_eq!(
                topic_matches_filter(topic, filter),
                expected,
                "topic `{topic}` against filter `{filter}`"
            );
        }
    }

    #[test]
    fn topic_filter_exact_match() {
        check(&[("a/b/c", "a/b/c", true), ("a/b/c", "a/b/d", false)]);
    }

    #[test]
    fn topic_filter_single_level_wildcard() {
        check(&[
            ("a/b/c", "a/+/c", true),
            ("a/x/c", "a/+/c", true),
            ("a/b/c/d", "a/+/c", false),
        ]);
    }

    #[test]
    fn topic_filter_multi_level_wildcard() {
        check(&[
            ("a/b/c", "a/#", true),
            ("a/b/c/d", "a/#", true),
            // `#` also matches the parent level itself.
            ("a", "a/#", true),
            ("b/c", "a/#", false),
        ]);
    }

    #[test]
    fn topic_filter_combined_wildcards() {
        const FILTER: &str = "styrene/op1/omegon/+/events/#";
        check(&[
            ("styrene/op1/omegon/inst-a/events/turn.started", FILTER, true),
            ("styrene/op1/viz/inst-a/events/turn.started", FILTER, false),
        ]);
    }

    #[test]
    fn topic_filter_length_mismatch() {
        check(&[("a/b", "a/b/c", false), ("a/b/c", "a/b", false)]);
    }
}