// topic_stream/topic_stream.rs

use async_broadcast::{Receiver, RecvError, SendError, Sender};
use dashmap::{mapref::entry::Entry, DashMap};
use futures::future::select_all;
use std::{collections::HashSet, hash::Hash, sync::Arc};

6/// A topic-based publish-subscribe stream that allows multiple subscribers
7/// to listen to messages associated with specific topics.
8///
9/// # Type Parameters
10/// - T: The type representing a topic. Must be hashable, comparable, and clonable.
11/// - M: The message type that will be published and received. Must be clonable.
12#[derive(Debug, Clone)]
13pub struct TopicStream<T: Eq + Hash + Clone, M: Clone> {
14    /// Stores the active subscribers for each topic.
15    subscribers: Arc<DashMap<T, Sender<M>>>,
16}
17
18impl<T: Eq + Hash + Clone, M: Clone> TopicStream<T, M> {
19    /// Creates a new TopicStream instance with the specified capacity.
20    ///
21    /// # Arguments
22    /// - capacity: The maximum number of messages each topic can hold in its buffer.
23    ///
24    /// # Returns
25    /// A new TopicStream instance.
26    pub fn new(capacity: usize) -> Self {
27        Self {
28            subscribers: Arc::new(DashMap::with_capacity(capacity)),
29        }
30    }
31
32    /// Subscribes to a list of topics and returns a MultiTopicReceiver
33    /// that can receive messages from them.
34    ///
35    /// # Arguments
36    /// - topics: A slice of topics to subscribe to.
37    ///
38    /// # Returns
39    /// A MultiTopicReceiver that listens to the specified topics.
40    pub fn subscribe(&self, topics: &[T]) -> MultiTopicReceiver<T, M> {
41        let mut receiver = MultiTopicReceiver::new(Arc::clone(&self.subscribers));
42        receiver.subscribe(topics);
43
44        receiver
45    }
46
47    /// Publishes a message to a specific topic. If the topic has no subscribers,
48    /// the message is ignored.
49    ///
50    /// # Arguments
51    /// - topic: The topic to publish the message to.
52    /// - message: The message to send.
53    ///
54    /// # Returns
55    /// - Ok(()): If the message was successfully sent or there were no subscribers.
56    /// - Err(SendError<M>): If there was an error sending the message.
57    pub async fn publish(&self, topic: &T, message: M) -> Result<(), SendError<M>> {
58        if let Some(sender) = self.subscribers.get(topic) {
59            sender.broadcast(message).await?;
60        };
61
62        Ok(())
63    }
64}
65
66/// A multi-topic receiver that listens to messages from multiple topics.
67///
68/// # Type Parameters
69/// - T: The type representing a topic.
70/// - M: The message type being received.
71#[derive(Debug)]
72pub struct MultiTopicReceiver<T: Eq + Hash + Clone, M: Clone> {
73    /// A reference to the associated TopicStream.
74    subscribers: Arc<DashMap<T, Sender<M>>>,
75    /// The list of active message receivers for the subscribed topics.
76    receivers: Vec<Receiver<M>>,
77    /// Tracks the topics this receiver is currently subscribed to.
78    subscribed_topics: HashSet<T>,
79}
80
81impl<T: Eq + Hash + Clone, M: Clone> MultiTopicReceiver<T, M> {
82    /// Creates a new MultiTopicReceiver for the given TopicStream.
83    ///
84    /// # Arguments
85    /// - subscribers: A reference to the DashMap containing the active subscribers.
86    ///
87    /// # Returns
88    /// A new MultiTopicReceiver instance.
89    pub fn new(subscribers: Arc<DashMap<T, Sender<M>>>) -> Self {
90        Self {
91            subscribers,
92            receivers: Vec::new(),
93            subscribed_topics: HashSet::new(),
94        }
95    }
96
97    /// Subscribes to the given list of topics. If already subscribed to a topic,
98    /// it is ignored.
99    ///
100    /// # Arguments
101    /// - topics: A slice of topics to subscribe to.
102    pub fn subscribe(&mut self, topics: &[T]) {
103        self.receivers.extend(
104            topics
105                .iter()
106                .filter(|topic| self.subscribed_topics.insert((*topic).clone()))
107                .map(|topic| {
108                    let topic = topic.clone();
109                    let (sender, _receiver) =
110                        async_broadcast::broadcast(self.subscribers.capacity());
111
112                    self.subscribers
113                        .entry(topic)
114                        .or_insert_with(|| sender)
115                        .new_receiver()
116                }),
117        );
118    }
119
120    /// An Option<M> containing the received message, or None if all receivers are closed.
121    pub async fn recv(&mut self) -> Option<M> {
122        self.receivers.retain(|r| !r.is_closed());
123
124        if self.receivers.is_empty() {
125            return None;
126        }
127
128        let futures = self
129            .receivers
130            .iter_mut()
131            .map(|receiver| Box::pin(receiver.recv()))
132            .collect::<Vec<_>>();
133
134        let (result, _index, _remaining) = select_all(futures).await;
135
136        result.ok() // If a message is received, return it; otherwise, return None.
137    }
138}
139
140impl<T: Eq + Hash + Clone, M: Clone> Drop for MultiTopicReceiver<T, M> {
141    fn drop(&mut self) {
142        let mut to_remove = Vec::new();
143
144        for topic in &self.subscribed_topics {
145            if let Some(sender) = self.subscribers.get(topic) {
146                if sender.receiver_count() <= 1 {
147                    to_remove.push(topic.clone());
148                }
149            }
150        }
151
152        to_remove.into_iter().for_each(|topic| {
153            self.subscribers.remove(&topic);
154        });
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use std::hash::Hash;

    /// How long the "nothing should arrive" tests wait before concluding that no
    /// message was delivered.
    const QUIET_PERIOD: std::time::Duration = std::time::Duration::from_millis(100);

    #[derive(Debug, Clone, Hash, Eq, PartialEq)]
    struct Topic(String);

    #[derive(Debug, Clone, Eq, PartialEq)]
    struct Message(String);

    #[tokio::test]
    async fn test_subscribe_and_publish_single_subscriber() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // Subscriber subscribes to the topic
        let mut receiver = publisher.subscribe(&[topic.clone()]);

        // Publisher sends a message to the topic
        let message = Message("Hello, Subscriber!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        // Subscriber receives the message
        let received_message = receiver.recv().await.unwrap();
        assert_eq!(received_message, message);
    }

    #[tokio::test]
    async fn test_subscribe_multiple_subscribers() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // Subscriber 1 subscribes to the topic
        let mut receiver1 = publisher.subscribe(&[topic.clone()]);
        // Subscriber 2 subscribes to the topic
        let mut receiver2 = publisher.subscribe(&[topic.clone()]);

        // Publisher sends a message to the topic
        let message = Message("Hello, Subscribers!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        // Subscriber 1 receives the message
        let received_message1 = receiver1.recv().await.unwrap();
        assert_eq!(received_message1, message);

        // Subscriber 2 receives the message
        let received_message2 = receiver2.recv().await.unwrap();
        assert_eq!(received_message2, message);
    }

    #[tokio::test]
    async fn test_publish_to_unsubscribed_topic() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // Subscriber listens to a different topic than the one published to
        let mut receiver = publisher.subscribe(&[Topic("invalid_topic".to_string())]);

        // Publishing to a topic with no subscribers must succeed and be a no-op
        let message = Message("Hello, World!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        // The receiver's own topic stays silent, so `recv` must still be pending
        // when the quiet period elapses.
        let outcome = tokio::time::timeout(QUIET_PERIOD, receiver.recv()).await;
        assert!(outcome.is_err(), "unexpected message received: {outcome:?}");
    }

    #[tokio::test]
    async fn test_multiple_messages_for_single_subscriber() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // Subscriber subscribes to the topic
        let mut receiver = publisher.subscribe(&[topic.clone()]);

        // Publisher sends multiple messages
        let message1 = Message("Message 1".to_string());
        let message2 = Message("Message 2".to_string());
        publisher.publish(&topic, message1.clone()).await.unwrap();
        publisher.publish(&topic, message2.clone()).await.unwrap();

        // Subscriber receives the first message
        let received_message1 = receiver.recv().await.unwrap();
        assert_eq!(received_message1, message1);

        // Subscriber receives the second message
        let received_message2 = receiver.recv().await.unwrap();
        assert_eq!(received_message2, message2);
    }

    #[tokio::test]
    async fn test_multiple_publishers() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // Subscriber subscribes to the topic
        let mut receiver = publisher.subscribe(&[topic.clone()]);

        // Publisher 1 sends a message
        let message1 = Message("Message from Publisher 1".to_string());
        publisher.publish(&topic, message1.clone()).await.unwrap();

        // Publisher 2 sends a message
        let message2 = Message("Message from Publisher 2".to_string());
        publisher.publish(&topic, message2.clone()).await.unwrap();

        // Subscriber receives the first message
        let received_message1 = receiver.recv().await.unwrap();
        assert_eq!(received_message1, message1);

        // Subscriber receives the second message
        let received_message2 = receiver.recv().await.unwrap();
        assert_eq!(received_message2, message2);
    }

    #[tokio::test]
    async fn test_subscribe_to_different_topics() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic1 = Topic("test_topic_1".to_string());
        let topic2 = Topic("test_topic_2".to_string());

        // Subscriber subscribes to topic 1
        let mut receiver1 = publisher.subscribe(&[topic1.clone()]);

        // Publisher sends a message to topic 1
        let message1 = Message("Hello, Topic 1".to_string());
        publisher.publish(&topic1, message1.clone()).await.unwrap();

        // Subscriber 1 receives the message for topic 1
        let received_message1 = receiver1.recv().await.unwrap();
        assert_eq!(received_message1, message1);

        // Subscriber subscribes to topic 2
        let mut receiver2 = publisher.subscribe(&[topic2.clone()]);

        // Publisher sends a message to topic 2
        let message2 = Message("Hello, Topic 2".to_string());
        publisher.publish(&topic2, message2.clone()).await.unwrap();

        // Subscriber 2 receives the message for topic 2
        let received_message2 = receiver2.recv().await.unwrap();
        assert_eq!(received_message2, message2);
    }

    #[tokio::test]
    async fn test_single_receiver_multiple_topics() {
        let publisher = TopicStream::<Topic, Message>::new(2);

        // Define multiple topics
        let topic1 = Topic("test_topic_1".to_string());
        let topic2 = Topic("test_topic_2".to_string());
        let topic3 = Topic("test_topic_3".to_string());

        // Subscriber subscribes to multiple topics
        let mut receiver = publisher.subscribe(&[topic1.clone(), topic2.clone(), topic3.clone()]);

        // Publisher sends messages to each topic
        let message1 = Message("Message for Topic 1".to_string());
        let message2 = Message("Message for Topic 2".to_string());
        let message3 = Message("Message for Topic 3".to_string());

        publisher.publish(&topic1, message1.clone()).await.unwrap();
        publisher.publish(&topic2, message2.clone()).await.unwrap();
        publisher.publish(&topic3, message3.clone()).await.unwrap();

        // Subscriber should receive the messages in the order they were published
        let received_message1 = receiver.recv().await.unwrap();
        assert_eq!(received_message1, message1);

        let received_message2 = receiver.recv().await.unwrap();
        assert_eq!(received_message2, message2);

        let received_message3 = receiver.recv().await.unwrap();
        assert_eq!(received_message3, message3);
    }

    #[tokio::test]
    async fn test_duplicate_topics_are_subscribed_once() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // Duplicates in one call and a repeated subscribe are both ignored
        let mut receiver = publisher.subscribe(&[topic.clone(), topic.clone()]);
        receiver.subscribe(&[topic.clone()]);

        let message = Message("Only once".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        assert_eq!(receiver.recv().await, Some(message));
        let outcome = tokio::time::timeout(QUIET_PERIOD, receiver.recv()).await;
        assert!(outcome.is_err(), "message delivered more than once: {outcome:?}");
    }

    #[tokio::test]
    async fn test_dropping_last_subscriber_removes_topic() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let receiver = publisher.subscribe(&[topic.clone()]);
        assert!(publisher.subscribers.contains_key(&topic));

        drop(receiver);
        assert!(!publisher.subscribers.contains_key(&topic));

        // Publishing to a topic without subscribers is a no-op
        publisher
            .publish(&topic, Message("Nobody listens".to_string()))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_dropping_one_subscriber_keeps_topic_for_others() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let receiver1 = publisher.subscribe(&[topic.clone()]);
        let mut receiver2 = publisher.subscribe(&[topic.clone()]);
        drop(receiver1);

        let message = Message("Still here".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();
        assert_eq!(receiver2.recv().await, Some(message));
    }
}