// rust_pubsub/lib.rs

1//! # Rust PubSub
2//!
3//! A thread-safe, in-memory publish-subscribe library for Rust, designed for efficient and flexible
4//! inter-thread communication. It supports both manual message receiving and callback-based subscriptions,
5//! with configurable queue depth and overwrite behavior.
6//!
7//! ## Features
8//!
9//! - **Thread-Safe**: Uses `crossbeam-channel` for safe message passing between threads.
10//! - **Flexible Subscriptions**: Supports manual message receiving and callback-based subscriptions.
11//! - **Configurable Channels**: Allows customization of queue depth and overwrite behavior per subscription.
12//! - **Type-Safe Messaging**: Supports any `Send + Sync + Clone + 'static` type for messages.
13//! - **Timeout Support**: Provides blocking, non-blocking, and timeout-based message receiving and publishing.
14//!
15//! ## 功能(中文)
16//!
17//! - **线程安全**:使用 `crossbeam-channel` 实现线程间安全消息传递。
18//! - **灵活订阅**:支持手动消息接收和基于回调的订阅模式。
19//! - **可配置通道**:允许为每个订阅配置队列深度和覆盖行为。
20//! - **类型安全消息**:支持任何满足 `Send + Sync + Clone + 'static` 的消息类型。
21//! - **超时支持**:提供阻塞、非阻塞和带超时的消息接收与发布。
22//!
23//! ## Usage Examples
24//!
25//! ### 1. Manual Subscription with Non-Blocking Receive
26//!
27//! ```rust
28//! use rust_pubsub::{PubSub, TopicConfig};
29//!
30//! let pubsub = PubSub::instance();
31//! let topic = "test_topic";
32//! let config = TopicConfig::new(10, false); // Queue depth 10, no overwrite
33//!
34//! // Create publisher
35//! let topic_id = pubsub.create_publisher(topic);
36//!
37//! // Subscribe manually
38//! let receiver = pubsub.subscribe_manual::<String>(topic, config);
39//!
40//! // Publish a message
41//! pubsub.publish(topic_id, "Hello, World!".to_string());
42//!
43//! // Try to receive the message
44//! if let Some(msg) = receiver.try_recv() {
45//!     println!("Received: {}", msg);
46//! }
47//! ```
48//!
49//! ### 2. Callback-Based Subscription
50//!
51//! ```rust
52//! use rust_pubsub::{PubSub, TopicConfig};
53//!
54//! let pubsub = PubSub::instance();
55//! let topic = "callback_topic";
56//! let config = TopicConfig::new(5, true); // Queue depth 5, overwrite enabled
57//!
58//! // Create publisher
59//! let topic_id = pubsub.create_publisher(topic);
60//!
61//! // Subscribe with a callback
62//! let subscriber_id = pubsub.subscribe::<String, _>(topic, config, |msg: &String| {
63//!     println!("Callback received: {}", msg);
64//! });
65//!
66//! // Publish a message
67//! pubsub.publish(topic_id, "Callback message".to_string());
68//!
69//! // Wait briefly to ensure callback executes
70//! std::thread::sleep(std::time::Duration::from_millis(100));
71//!
72//! // Unsubscribe
73//! pubsub.unsubscribe(&subscriber_id);
74//! ```
75//!
76//! ### 3. Publishing with Timeout
77//!
78//! ```rust
79//! use rust_pubsub::{PubSub, TopicConfig};
80//!
81//! let pubsub = PubSub::instance();
82//! let topic = "timeout_topic";
83//! let config = TopicConfig::new(1, false); // Queue depth 1, no overwrite
84//!
85//! // Create publisher
86//! let topic_id = pubsub.create_publisher(topic);
87//!
88//! // Subscribe manually
89//! let receiver = pubsub.subscribe_manual::<i32>(topic, config);
90//!
91//! // Publish with timeout (100ms)
92//! pubsub.publish_with_timeout(topic_id, 42, Some(100));
93//!
94//! // Receive with timeout (100ms)
95//! if let Some(msg) = receiver.recv_timeout(Some(100)) {
96//!     println!("Received: {}", msg);
97//! }
98//! ```
99//!
100//! ### 4. Overwrite Mode with Full Queue
101//!
102//! ```rust
103//! use rust_pubsub::{PubSub, TopicConfig};
104//!
105//! let pubsub = PubSub::instance();
106//! let topic = "overwrite_topic";
107//! let config = TopicConfig::new(2, true); // Queue depth 2, overwrite enabled
108//!
109//! // Create publisher
110//! let topic_id = pubsub.create_publisher(topic);
111//!
112//! // Subscribe manually
//! let receiver = pubsub.subscribe_manual::<String>(topic, config);
114//!
115//! // Publish multiple messages to fill queue
116//! pubsub.publish(topic_id, "Message 1".to_string());
117//! pubsub.publish(topic_id, "Message 2".to_string());
118//! pubsub.publish(topic_id, "Message 3".to_string()); // Overwrites oldest
119//!
120//! // Receive messages
121//! while let Some(msg) = receiver.try_recv() {
122//!     println!("Received: {}", msg);
123//! }
124//! ```
125//!
126//! ### 5. Multiple Subscribers
127//!
128//! ```rust
129//! use rust_pubsub::{PubSub, TopicConfig};
130//!
131//! let pubsub = PubSub::instance();
132//! let topic = "multi_subscriber_topic";
133//! let config = TopicConfig::new(10, false);
134//!
135//! // Create publisher
136//! let topic_id = pubsub.create_publisher(topic);
137//!
138//! // Subscribe multiple times
139//! let receiver1 = pubsub.subscribe_manual::<String>(topic, config.clone());
140//! let receiver2 = pubsub.subscribe_manual::<String>(topic, config.clone());
141//!
142//! // Publish a message
143//! pubsub.publish(topic_id, "Broadcast message".to_string());
144//!
145//! // Receive from both subscribers
146//! println!("Receiver 1: {:?}", receiver1.try_recv());
147//! println!("Receiver 2: {:?}", receiver2.try_recv());
148//! ```
149//!
150//! ## Installation
151//!
152//! Add the following to your `Cargo.toml`:
153//!
154//! ```toml
155//! [dependencies]
156//! rust-pubsub = "0.1.0"
157//! ```
158//!
159//! ## 安装(中文)
160//!
161//! 在您的 `Cargo.toml` 中添加以下内容:
162//!
163//! ```toml
164//! [dependencies]
165//! rust-pubsub = "0.1.0"
166//! ```
167//!
168//! ## License
169//!
170//! Licensed under either of Apache License, Version 2.0 or MIT license at your option.
171
172use crossbeam_channel::{Receiver, Sender, bounded};
173use lazy_static::lazy_static;
174use std::any::Any;
175use std::collections::HashMap;
176use std::sync::{Arc, Mutex};
177use std::time::Duration;
178use uuid::Uuid;
179
180// Your original code follows here, unchanged
181lazy_static! {
182    static ref PUBSUB: Arc<PubSub> = Arc::new(PubSub::new());
183}
184
/// Per-subscription queue settings passed to `PubSub::subscribe` and
/// `PubSub::subscribe_manual`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TopicConfig {
    /// Capacity of the subscriber's bounded channel. A depth of `0` yields a
    /// zero-capacity channel (see `crossbeam_channel::bounded`).
    queue_depth: usize,
    /// When `true`, publishers discard the oldest queued message(s) to make room
    /// in a full queue instead of waiting for the subscriber.
    overwrite: bool,
}

impl TopicConfig {
    /// Creates a config with the given queue depth and overwrite policy.
    pub fn new(queue_depth: usize, overwrite: bool) -> Self {
        TopicConfig {
            queue_depth,
            overwrite,
        }
    }

    /// Capacity of each subscriber queue created with this config.
    pub fn queue_depth(&self) -> usize {
        self.queue_depth
    }

    /// Whether publishing into a full queue evicts old messages rather than waiting.
    pub fn overwrite(&self) -> bool {
        self.overwrite
    }
}
199
/// Type-erased message payload.
///
/// One wrapper is created per publish call and cloned once per subscriber;
/// cloning only bumps the `Arc` reference count, so the payload itself is not
/// copied during fan-out.
#[derive(Clone)]
struct MessageWrapper {
    data: Arc<dyn Any + Send + Sync>,
}
204
205#[derive(Clone)]
206struct ChannelPair {
207    sender: Sender<MessageWrapper>,
208    receiver: Receiver<MessageWrapper>,
209    config: TopicConfig,
210    subscriber_id: String,
211}
212
213impl ChannelPair {
214    fn new(
215        sender: Sender<MessageWrapper>,
216        receiver: Receiver<MessageWrapper>,
217        config: TopicConfig,
218        subscriber_id: String,
219    ) -> Self {
220        ChannelPair {
221            sender,
222            receiver,
223            config,
224            subscriber_id,
225        }
226    }
227}
228
/// State for one topic, stored in `PubSub::topics` at the index returned by
/// `create_publisher`.
struct TopicData {
    // Written when the topic is created but not read anywhere in this file,
    // hence the `allow(dead_code)`.
    #[allow(dead_code)]
    name: String,
    /// One entry per current subscriber of this topic.
    channel_pairs: Vec<ChannelPair>,
}
234
/// Bookkeeping for one subscriber, keyed by subscriber id in `PubSub::subscribers`.
struct SubscriberData {
    /// Name of the subscribed topic; `unsubscribe` uses it to locate the
    /// subscriber's channel pair.
    topic: String,
    // Stored at subscribe time but never read in this file (hence the allow).
    #[allow(dead_code)]
    receiver: Receiver<MessageWrapper>,
    // Type-erased callback for `subscribe`, `None` for `subscribe_manual`. Not read
    // in this file: the dispatch thread spawned by `subscribe` holds its own clone.
    #[allow(dead_code)]
    callback: Option<Arc<dyn Fn(&dyn Any) + Send + Sync>>,
}
242
/// Pull-style handle for a subscription created by `PubSub::subscribe_manual`.
///
/// Clones share the same channel and subscriber id; crossbeam receivers are
/// multi-consumer, so each message is delivered to only one of the clones.
/// No `Drop` impl is defined here, so dropping a handle does not unsubscribe —
/// call `ManualReceiver::unsubscribe` explicitly.
#[derive(Clone)]
pub struct ManualReceiver<T: 'static> {
    receiver: Receiver<MessageWrapper>,
    /// Id passed to `PubSub::unsubscribe`.
    subscriber_id: String,
    /// Hub used by `unsubscribe`; `subscribe_manual` always sets the global instance.
    pubsub: Arc<PubSub>,
    _marker: std::marker::PhantomData<T>,
}
250
251impl<T: Clone + 'static> ManualReceiver<T> {
252    pub fn try_recv(&self) -> Option<T> {
253        let msg = self.receiver.try_recv().ok();
254
255        match msg {
256            Some(msg) => {
257                if let Some(data) = msg.downcast::<T>() {
258                    return Some(data.to_owned());
259                }
260                None
261            }
262            None => None,
263        }
264    }
265
266    pub fn recv(&self) -> Option<T> {
267        self.recv_timeout(None)
268    }
269
270    pub fn recv_timeout(&self, timeout_ms: Option<u64>) -> Option<T> {
271        let msg = match timeout_ms {
272            Some(ms) => self.receiver.recv_timeout(Duration::from_millis(ms)).ok(),
273            None => self.receiver.recv().ok(),
274        };
275
276        match msg {
277            Some(msg) => {
278                if let Some(data) = msg.downcast::<T>() {
279                    return Some(data.to_owned());
280                }
281                None
282            }
283            None => None,
284        }
285    }
286
287    pub fn unsubscribe(self) {
288        self.pubsub.unsubscribe(&self.subscriber_id);
289    }
290}
291
292impl MessageWrapper {
293    fn new<T: Send + Sync + Clone + 'static>(data: T) -> Self {
294        MessageWrapper {
295            data: Arc::new(data),
296        }
297    }
298
299    fn downcast<T: 'static>(&self) -> Option<&T> {
300        self.data.downcast_ref::<T>()
301    }
302}
303
/// In-memory publish/subscribe hub.
///
/// `new` is private; the shared instance is obtained with `PubSub::instance`.
pub struct PubSub {
    /// Topic table; a publisher's topic id is an index into this vector.
    topics: Mutex<Vec<TopicData>>,
    /// Topic name -> index into `topics`.
    topic_map: Mutex<HashMap<String, usize>>,
    /// Subscriber id -> bookkeeping used by `unsubscribe`.
    subscribers: Mutex<HashMap<String, SubscriberData>>,
}
309
310impl PubSub {
311    fn new() -> Self {
312        PubSub {
313            topics: Mutex::new(Vec::new()),
314            topic_map: Mutex::new(HashMap::new()),
315            subscribers: Mutex::new(HashMap::new()),
316        }
317    }
318
319    pub fn instance() -> Arc<PubSub> {
320        PUBSUB.clone()
321    }
322
323    pub fn create_publisher(&self, topic: &str) -> usize {
324        let mut topic_map = self.topic_map.lock().unwrap();
325
326        if let Some(&index) = topic_map.get(topic) {
327            return index;
328        }
329
330        let mut topics = self.topics.lock().unwrap();
331        let new_index = topics.len();
332
333        topics.push(TopicData {
334            name: topic.to_string(),
335            channel_pairs: Vec::new(),
336        });
337
338        topic_map.insert(topic.to_string(), new_index);
339
340        new_index
341    }
342
343    pub fn subscribe_manual<T: Send + Sync + Clone + 'static>(
344        &self,
345        topic: &str,
346        config: TopicConfig,
347    ) -> ManualReceiver<T>
348    where
349        T: 'static,
350    {
351        let subscriber_id = Uuid::new_v4().to_string();
352        let (tx, rx) = bounded(config.queue_depth);
353        let topic_str = topic.to_string();
354
355        let topic_index = self.create_publisher(topic);
356
357        {
358            let mut topics = self.topics.lock().unwrap();
359            topics[topic_index].channel_pairs.push(ChannelPair::new(
360                tx,
361                rx.clone(),
362                config.clone(),
363                subscriber_id.clone(),
364            ));
365        }
366
367        {
368            self.subscribers.lock().unwrap().insert(
369                subscriber_id.clone(),
370                SubscriberData {
371                    topic: topic_str.clone(),
372                    receiver: rx.clone(),
373                    callback: None,
374                },
375            );
376        }
377
378        ManualReceiver {
379            receiver: rx,
380            subscriber_id,
381            pubsub: PubSub::instance(),
382            _marker: std::marker::PhantomData,
383        }
384    }
385
386    pub fn subscribe<T, F>(&self, topic: &str, config: TopicConfig, callback: F) -> String
387    where
388        T: Send + Sync + Clone + 'static,
389        F: Fn(&T) + Send + Sync + 'static,
390    {
391        let subscriber_id = Uuid::new_v4().to_string();
392        let (tx, rx) = bounded(config.queue_depth);
393        let topic_str = topic.to_string();
394
395        let topic_index = self.create_publisher(topic);
396
397        {
398            let mut topics = self.topics.lock().unwrap();
399            topics[topic_index].channel_pairs.push(ChannelPair::new(
400                tx,
401                rx.clone(),
402                config.clone(),
403                subscriber_id.clone(),
404            ));
405        }
406
407        let callback_wrapper: Arc<dyn Fn(&dyn Any) + Send + Sync> =
408            Arc::new(move |data: &dyn Any| {
409                if let Some(t) = data.downcast_ref::<T>() {
410                    callback(t);
411                }
412            });
413
414        {
415            self.subscribers.lock().unwrap().insert(
416                subscriber_id.clone(),
417                SubscriberData {
418                    topic: topic_str.clone(),
419                    receiver: rx.clone(),
420                    callback: Some(callback_wrapper.clone()),
421                },
422            );
423        }
424
425        let rx_clone = rx.clone();
426        let callback_for_thread = callback_wrapper.clone();
427        std::thread::spawn(move || {
428            while let Ok(msg) = rx_clone.recv() {
429                if let Some(data) = msg.downcast::<T>() {
430                    callback_for_thread(data);
431                }
432            }
433        });
434
435        subscriber_id
436    }
437
438    pub fn try_publish<T: Send + Sync + Clone + 'static>(&self, topic_id: usize, message: T) {
439        let msg = MessageWrapper::new(message);
440
441        let channel_pairs = {
442            let topics = self.topics.lock().unwrap();
443
444            if topic_id >= topics.len() {
445                return;
446            }
447
448            if topics[topic_id].channel_pairs.is_empty() {
449                return;
450            }
451
452            topics[topic_id].channel_pairs.clone()
453        };
454
455        for pair in channel_pairs.iter() {
456            if pair.config.overwrite {
457                while pair.sender.is_full() {
458                    let _ = pair.receiver.try_recv();
459                }
460            }
461
462            let _ = pair.sender.try_send(msg.clone());
463        }
464    }
465
466    pub fn publish<T: Send + Sync + Clone + 'static>(&self, topic_id: usize, message: T) {
467        self.publish_with_timeout(topic_id, message, None);
468    }
469
470    pub fn publish_with_timeout<T: Send + Sync + Clone + 'static>(
471        &self,
472        topic_id: usize,
473        message: T,
474        max_wait_ms: Option<u64>,
475    ) {
476        let msg = MessageWrapper::new(message);
477
478        let channel_pairs = {
479            let topics = self.topics.lock().unwrap();
480
481            if topic_id >= topics.len() {
482                return;
483            }
484
485            if topics[topic_id].channel_pairs.is_empty() {
486                return;
487            }
488
489            topics[topic_id].channel_pairs.clone()
490        };
491
492        for pair in channel_pairs.iter() {
493            if pair.config.overwrite {
494                while pair.sender.is_full() {
495                    let _ = pair.receiver.try_recv();
496                }
497                let _ = pair.sender.try_send(msg.clone());
498            } else {
499                match max_wait_ms {
500                    Some(ms) => {
501                        let _ = pair
502                            .sender
503                            .send_timeout(msg.clone(), Duration::from_millis(ms));
504                    }
505                    None => {
506                        let _ = pair.sender.send(msg.clone());
507                    }
508                }
509            }
510        }
511    }
512
513    pub fn unsubscribe(&self, subscriber_id: &str) {
514        let topic_opt = {
515            let mut subscribers = self.subscribers.lock().unwrap();
516            if let Some(data) = subscribers.remove(subscriber_id) {
517                Some(data.topic)
518            } else {
519                None
520            }
521        };
522
523        if let Some(topic) = topic_opt {
524            let topic_index_opt = {
525                let topic_map = self.topic_map.lock().unwrap();
526                topic_map.get(&topic).cloned()
527            };
528
529            if let Some(topic_index) = topic_index_opt {
530                let mut topics = self.topics.lock().unwrap();
531                if let Some(topic_data) = topics.get_mut(topic_index) {
532                    topic_data
533                        .channel_pairs
534                        .retain(|pair| pair.subscriber_id != subscriber_id);
535
536                    if topic_data.channel_pairs.is_empty() {
537                        let mut topic_map = self.topic_map.lock().unwrap();
538                        topic_map.remove(&topic);
539                    }
540                }
541            }
542        }
543    }
544}