1use crossbeam_channel::{Receiver, Sender, bounded};
173use lazy_static::lazy_static;
174use std::any::Any;
175use std::collections::HashMap;
176use std::sync::{Arc, Mutex};
177use std::time::Duration;
178use uuid::Uuid;
179
180lazy_static! {
182 static ref PUBSUB: Arc<PubSub> = Arc::new(PubSub::new());
183}
184
/// Per-subscriber queue settings passed when subscribing to a topic.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TopicConfig {
    /// Capacity of the subscriber's bounded channel. `0` yields a zero-capacity channel that
    /// only hands a message over while the consumer is waiting for it.
    queue_depth: usize,
    /// When `true`, publishing to a full queue evicts the oldest queued message to make room
    /// instead of dropping the new one (`try_publish`) or waiting for space (`publish`).
    overwrite: bool,
}

impl TopicConfig {
    /// Creates a configuration with the given queue depth and overflow policy.
    pub fn new(queue_depth: usize, overwrite: bool) -> Self {
        TopicConfig {
            queue_depth,
            overwrite,
        }
    }
}
199
/// Type-erased payload of one published message, shared by every subscriber that receives it.
#[derive(Clone)]
struct MessageWrapper {
    /// The published value; cloning the wrapper only bumps this reference count.
    data: Arc<dyn Any + Send + Sync>,
}
204
205#[derive(Clone)]
206struct ChannelPair {
207 sender: Sender<MessageWrapper>,
208 receiver: Receiver<MessageWrapper>,
209 config: TopicConfig,
210 subscriber_id: String,
211}
212
213impl ChannelPair {
214 fn new(
215 sender: Sender<MessageWrapper>,
216 receiver: Receiver<MessageWrapper>,
217 config: TopicConfig,
218 subscriber_id: String,
219 ) -> Self {
220 ChannelPair {
221 sender,
222 receiver,
223 config,
224 subscriber_id,
225 }
226 }
227}
228
229struct TopicData {
230 #[allow(dead_code)]
231 name: String,
232 channel_pairs: Vec<ChannelPair>,
233}
234
235struct SubscriberData {
236 topic: String,
237 #[allow(dead_code)]
238 receiver: Receiver<MessageWrapper>,
239 #[allow(dead_code)]
240 callback: Option<Arc<dyn Fn(&dyn Any) + Send + Sync>>,
241}
242
243#[derive(Clone)]
244pub struct ManualReceiver<T: 'static> {
245 receiver: Receiver<MessageWrapper>,
246 subscriber_id: String,
247 pubsub: Arc<PubSub>,
248 _marker: std::marker::PhantomData<T>,
249}
250
251impl<T: Clone + 'static> ManualReceiver<T> {
252 pub fn try_recv(&self) -> Option<T> {
253 let msg = self.receiver.try_recv().ok();
254
255 match msg {
256 Some(msg) => {
257 if let Some(data) = msg.downcast::<T>() {
258 return Some(data.to_owned());
259 }
260 None
261 }
262 None => None,
263 }
264 }
265
266 pub fn recv(&self) -> Option<T> {
267 self.recv_timeout(None)
268 }
269
270 pub fn recv_timeout(&self, timeout_ms: Option<u64>) -> Option<T> {
271 let msg = match timeout_ms {
272 Some(ms) => self.receiver.recv_timeout(Duration::from_millis(ms)).ok(),
273 None => self.receiver.recv().ok(),
274 };
275
276 match msg {
277 Some(msg) => {
278 if let Some(data) = msg.downcast::<T>() {
279 return Some(data.to_owned());
280 }
281 None
282 }
283 None => None,
284 }
285 }
286
287 pub fn unsubscribe(self) {
288 self.pubsub.unsubscribe(&self.subscriber_id);
289 }
290}
291
292impl MessageWrapper {
293 fn new<T: Send + Sync + Clone + 'static>(data: T) -> Self {
294 MessageWrapper {
295 data: Arc::new(data),
296 }
297 }
298
299 fn downcast<T: 'static>(&self) -> Option<&T> {
300 self.data.downcast_ref::<T>()
301 }
302}
303
304pub struct PubSub {
305 topics: Mutex<Vec<TopicData>>,
306 topic_map: Mutex<HashMap<String, usize>>,
307 subscribers: Mutex<HashMap<String, SubscriberData>>,
308}
309
310impl PubSub {
311 fn new() -> Self {
312 PubSub {
313 topics: Mutex::new(Vec::new()),
314 topic_map: Mutex::new(HashMap::new()),
315 subscribers: Mutex::new(HashMap::new()),
316 }
317 }
318
319 pub fn instance() -> Arc<PubSub> {
320 PUBSUB.clone()
321 }
322
323 pub fn create_publisher(&self, topic: &str) -> usize {
324 let mut topic_map = self.topic_map.lock().unwrap();
325
326 if let Some(&index) = topic_map.get(topic) {
327 return index;
328 }
329
330 let mut topics = self.topics.lock().unwrap();
331 let new_index = topics.len();
332
333 topics.push(TopicData {
334 name: topic.to_string(),
335 channel_pairs: Vec::new(),
336 });
337
338 topic_map.insert(topic.to_string(), new_index);
339
340 new_index
341 }
342
343 pub fn subscribe_manual<T: Send + Sync + Clone + 'static>(
344 &self,
345 topic: &str,
346 config: TopicConfig,
347 ) -> ManualReceiver<T>
348 where
349 T: 'static,
350 {
351 let subscriber_id = Uuid::new_v4().to_string();
352 let (tx, rx) = bounded(config.queue_depth);
353 let topic_str = topic.to_string();
354
355 let topic_index = self.create_publisher(topic);
356
357 {
358 let mut topics = self.topics.lock().unwrap();
359 topics[topic_index].channel_pairs.push(ChannelPair::new(
360 tx,
361 rx.clone(),
362 config.clone(),
363 subscriber_id.clone(),
364 ));
365 }
366
367 {
368 self.subscribers.lock().unwrap().insert(
369 subscriber_id.clone(),
370 SubscriberData {
371 topic: topic_str.clone(),
372 receiver: rx.clone(),
373 callback: None,
374 },
375 );
376 }
377
378 ManualReceiver {
379 receiver: rx,
380 subscriber_id,
381 pubsub: PubSub::instance(),
382 _marker: std::marker::PhantomData,
383 }
384 }
385
386 pub fn subscribe<T, F>(&self, topic: &str, config: TopicConfig, callback: F) -> String
387 where
388 T: Send + Sync + Clone + 'static,
389 F: Fn(&T) + Send + Sync + 'static,
390 {
391 let subscriber_id = Uuid::new_v4().to_string();
392 let (tx, rx) = bounded(config.queue_depth);
393 let topic_str = topic.to_string();
394
395 let topic_index = self.create_publisher(topic);
396
397 {
398 let mut topics = self.topics.lock().unwrap();
399 topics[topic_index].channel_pairs.push(ChannelPair::new(
400 tx,
401 rx.clone(),
402 config.clone(),
403 subscriber_id.clone(),
404 ));
405 }
406
407 let callback_wrapper: Arc<dyn Fn(&dyn Any) + Send + Sync> =
408 Arc::new(move |data: &dyn Any| {
409 if let Some(t) = data.downcast_ref::<T>() {
410 callback(t);
411 }
412 });
413
414 {
415 self.subscribers.lock().unwrap().insert(
416 subscriber_id.clone(),
417 SubscriberData {
418 topic: topic_str.clone(),
419 receiver: rx.clone(),
420 callback: Some(callback_wrapper.clone()),
421 },
422 );
423 }
424
425 let rx_clone = rx.clone();
426 let callback_for_thread = callback_wrapper.clone();
427 std::thread::spawn(move || {
428 while let Ok(msg) = rx_clone.recv() {
429 if let Some(data) = msg.downcast::<T>() {
430 callback_for_thread(data);
431 }
432 }
433 });
434
435 subscriber_id
436 }
437
438 pub fn try_publish<T: Send + Sync + Clone + 'static>(&self, topic_id: usize, message: T) {
439 let msg = MessageWrapper::new(message);
440
441 let channel_pairs = {
442 let topics = self.topics.lock().unwrap();
443
444 if topic_id >= topics.len() {
445 return;
446 }
447
448 if topics[topic_id].channel_pairs.is_empty() {
449 return;
450 }
451
452 topics[topic_id].channel_pairs.clone()
453 };
454
455 for pair in channel_pairs.iter() {
456 if pair.config.overwrite {
457 while pair.sender.is_full() {
458 let _ = pair.receiver.try_recv();
459 }
460 }
461
462 let _ = pair.sender.try_send(msg.clone());
463 }
464 }
465
466 pub fn publish<T: Send + Sync + Clone + 'static>(&self, topic_id: usize, message: T) {
467 self.publish_with_timeout(topic_id, message, None);
468 }
469
470 pub fn publish_with_timeout<T: Send + Sync + Clone + 'static>(
471 &self,
472 topic_id: usize,
473 message: T,
474 max_wait_ms: Option<u64>,
475 ) {
476 let msg = MessageWrapper::new(message);
477
478 let channel_pairs = {
479 let topics = self.topics.lock().unwrap();
480
481 if topic_id >= topics.len() {
482 return;
483 }
484
485 if topics[topic_id].channel_pairs.is_empty() {
486 return;
487 }
488
489 topics[topic_id].channel_pairs.clone()
490 };
491
492 for pair in channel_pairs.iter() {
493 if pair.config.overwrite {
494 while pair.sender.is_full() {
495 let _ = pair.receiver.try_recv();
496 }
497 let _ = pair.sender.try_send(msg.clone());
498 } else {
499 match max_wait_ms {
500 Some(ms) => {
501 let _ = pair
502 .sender
503 .send_timeout(msg.clone(), Duration::from_millis(ms));
504 }
505 None => {
506 let _ = pair.sender.send(msg.clone());
507 }
508 }
509 }
510 }
511 }
512
513 pub fn unsubscribe(&self, subscriber_id: &str) {
514 let topic_opt = {
515 let mut subscribers = self.subscribers.lock().unwrap();
516 if let Some(data) = subscribers.remove(subscriber_id) {
517 Some(data.topic)
518 } else {
519 None
520 }
521 };
522
523 if let Some(topic) = topic_opt {
524 let topic_index_opt = {
525 let topic_map = self.topic_map.lock().unwrap();
526 topic_map.get(&topic).cloned()
527 };
528
529 if let Some(topic_index) = topic_index_opt {
530 let mut topics = self.topics.lock().unwrap();
531 if let Some(topic_data) = topics.get_mut(topic_index) {
532 topic_data
533 .channel_pairs
534 .retain(|pair| pair.subscriber_id != subscriber_id);
535
536 if topic_data.channel_pairs.is_empty() {
537 let mut topic_map = self.topic_map.lock().unwrap();
538 topic_map.remove(&topic);
539 }
540 }
541 }
542 }
543 }
544}