topic_stream/
topic_stream.rs1use async_broadcast::{Receiver, SendError, Sender};
2use dashmap::DashMap;
3use futures::future::select_all;
4use std::{collections::HashSet, hash::Hash, sync::Arc};
5
6#[derive(Debug, Clone)]
13pub struct TopicStream<T: Eq + Hash + Clone, M: Clone> {
14 subscribers: Arc<DashMap<T, Sender<M>>>,
16}
17
18impl<T: Eq + Hash + Clone, M: Clone> TopicStream<T, M> {
19 pub fn new(capacity: usize) -> Self {
27 Self {
28 subscribers: Arc::new(DashMap::with_capacity(capacity)),
29 }
30 }
31
32 pub fn subscribe(&self, topics: &[T]) -> MultiTopicReceiver<T, M> {
41 let mut receiver = MultiTopicReceiver::new(Arc::clone(&self.subscribers));
42 receiver.subscribe(topics);
43
44 receiver
45 }
46
47 pub async fn publish(&self, topic: &T, message: M) -> Result<(), SendError<M>> {
58 if let Some(sender) = self.subscribers.get(topic) {
59 sender.broadcast(message).await?;
60 };
61
62 Ok(())
63 }
64}
65
66#[derive(Debug)]
72pub struct MultiTopicReceiver<T: Eq + Hash + Clone, M: Clone> {
73 subscribers: Arc<DashMap<T, Sender<M>>>,
75 receivers: Vec<Receiver<M>>,
77 subscribed_topics: HashSet<T>,
79}
80
81impl<T: Eq + Hash + Clone, M: Clone> MultiTopicReceiver<T, M> {
82 pub fn new(subscribers: Arc<DashMap<T, Sender<M>>>) -> Self {
90 Self {
91 subscribers,
92 receivers: Vec::new(),
93 subscribed_topics: HashSet::new(),
94 }
95 }
96
97 pub fn subscribe(&mut self, topics: &[T]) {
103 self.receivers.extend(
104 topics
105 .iter()
106 .filter(|topic| self.subscribed_topics.insert((*topic).clone()))
107 .map(|topic| {
108 let topic = topic.clone();
109 let (sender, _receiver) =
110 async_broadcast::broadcast(self.subscribers.capacity());
111
112 self.subscribers
113 .entry(topic)
114 .or_insert_with(|| sender)
115 .new_receiver()
116 }),
117 );
118 }
119
120 pub async fn recv(&mut self) -> Option<M> {
122 self.receivers.retain(|r| !r.is_closed());
123
124 if self.receivers.is_empty() {
125 return None;
126 }
127
128 let futures = self
129 .receivers
130 .iter_mut()
131 .map(|receiver| Box::pin(receiver.recv()))
132 .collect::<Vec<_>>();
133
134 let (result, _index, _remaining) = select_all(futures).await;
135
136 result.ok() }
138}
139
140impl<T: Eq + Hash + Clone, M: Clone> Drop for MultiTopicReceiver<T, M> {
141 fn drop(&mut self) {
142 let mut to_remove = Vec::new();
143
144 for topic in &self.subscribed_topics {
145 if let Some(sender) = self.subscribers.get(topic) {
146 if sender.receiver_count() <= 1 {
147 to_remove.push(topic.clone());
148 }
149 }
150 }
151
152 to_remove.into_iter().for_each(|topic| {
153 self.subscribers.remove(&topic);
154 });
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Hash, Eq, PartialEq)]
    struct Topic(String);

    #[derive(Debug, Clone, Eq, PartialEq)]
    struct Message(String);

    #[tokio::test]
    async fn test_subscribe_and_publish_single_subscriber() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let mut receiver = publisher.subscribe(&[topic.clone()]);

        let message = Message("Hello, Subscriber!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        assert_eq!(receiver.recv().await.unwrap(), message);
    }

    #[tokio::test]
    async fn test_subscribe_multiple_subscribers() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let mut receiver1 = publisher.subscribe(&[topic.clone()]);
        let mut receiver2 = publisher.subscribe(&[topic.clone()]);

        let message = Message("Hello, Subscribers!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        assert_eq!(receiver1.recv().await.unwrap(), message);
        assert_eq!(receiver2.recv().await.unwrap(), message);
    }

    #[tokio::test]
    async fn test_publish_to_unsubscribed_topic() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let mut receiver = publisher.subscribe(&[Topic("invalid_topic".to_string())]);

        let message = Message("Hello, World!".to_string());
        publisher.publish(&topic, message).await.unwrap();

        // Nothing may arrive on the other topic; a short window is enough to tell.
        tokio::select! {
            _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {}
            _ = receiver.recv() => {
                panic!("Unexpected message received on an unsubscribed topic");
            }
        }
    }

    #[tokio::test]
    async fn test_multiple_messages_for_single_subscriber() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let mut receiver = publisher.subscribe(&[topic.clone()]);

        let message1 = Message("Message 1".to_string());
        let message2 = Message("Message 2".to_string());
        publisher.publish(&topic, message1.clone()).await.unwrap();
        publisher.publish(&topic, message2.clone()).await.unwrap();

        assert_eq!(receiver.recv().await.unwrap(), message1);
        assert_eq!(receiver.recv().await.unwrap(), message2);
    }

    #[tokio::test]
    async fn test_multiple_publishers() {
        let publisher1 = TopicStream::<Topic, Message>::new(2);
        // A clone is a second handle onto the same topics.
        let publisher2 = publisher1.clone();
        let topic = Topic("test_topic".to_string());

        let mut receiver = publisher1.subscribe(&[topic.clone()]);

        let message1 = Message("Message from Publisher 1".to_string());
        publisher1.publish(&topic, message1.clone()).await.unwrap();

        let message2 = Message("Message from Publisher 2".to_string());
        publisher2.publish(&topic, message2.clone()).await.unwrap();

        assert_eq!(receiver.recv().await.unwrap(), message1);
        assert_eq!(receiver.recv().await.unwrap(), message2);
    }

    #[tokio::test]
    async fn test_subscribe_to_different_topics() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic1 = Topic("test_topic_1".to_string());
        let topic2 = Topic("test_topic_2".to_string());

        let mut receiver1 = publisher.subscribe(&[topic1.clone()]);

        let message1 = Message("Hello, Topic 1".to_string());
        publisher.publish(&topic1, message1.clone()).await.unwrap();
        assert_eq!(receiver1.recv().await.unwrap(), message1);

        let mut receiver2 = publisher.subscribe(&[topic2.clone()]);

        let message2 = Message("Hello, Topic 2".to_string());
        publisher.publish(&topic2, message2.clone()).await.unwrap();
        assert_eq!(receiver2.recv().await.unwrap(), message2);
    }

    #[tokio::test]
    async fn test_single_receiver_multiple_topics() {
        let publisher = TopicStream::<Topic, Message>::new(2);

        let topic1 = Topic("test_topic_1".to_string());
        let topic2 = Topic("test_topic_2".to_string());
        let topic3 = Topic("test_topic_3".to_string());

        let mut receiver = publisher.subscribe(&[topic1.clone(), topic2.clone(), topic3.clone()]);

        let message1 = Message("Message for Topic 1".to_string());
        let message2 = Message("Message for Topic 2".to_string());
        let message3 = Message("Message for Topic 3".to_string());

        publisher.publish(&topic1, message1.clone()).await.unwrap();
        publisher.publish(&topic2, message2.clone()).await.unwrap();
        publisher.publish(&topic3, message3.clone()).await.unwrap();

        assert_eq!(receiver.recv().await.unwrap(), message1);
        assert_eq!(receiver.recv().await.unwrap(), message2);
        assert_eq!(receiver.recv().await.unwrap(), message3);
    }

    #[tokio::test]
    async fn test_duplicate_topics_subscribe_once() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let mut receiver = publisher.subscribe(&[topic.clone(), topic.clone()]);
        assert_eq!(receiver.receivers.len(), 1);

        let message = Message("Only once".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();
        assert_eq!(receiver.recv().await.unwrap(), message);
    }

    #[tokio::test]
    async fn test_dropping_last_receiver_removes_topic() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let receiver = publisher.subscribe(&[topic.clone()]);
        assert!(publisher.subscribers.contains_key(&topic));

        drop(receiver);
        assert!(!publisher.subscribers.contains_key(&topic));

        // Publishing to a topic without subscribers is a no-op.
        publisher
            .publish(&topic, Message("Nobody listens".to_string()))
            .await
            .unwrap();

        // Subscribing again recreates the topic and works normally.
        let mut receiver = publisher.subscribe(&[topic.clone()]);
        let message = Message("Welcome back".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();
        assert_eq!(receiver.recv().await.unwrap(), message);
    }

    #[tokio::test]
    async fn test_dropping_one_of_many_receivers_keeps_topic() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let receiver1 = publisher.subscribe(&[topic.clone()]);
        let mut receiver2 = publisher.subscribe(&[topic.clone()]);

        drop(receiver1);
        assert!(publisher.subscribers.contains_key(&topic));

        let message = Message("Still here".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();
        assert_eq!(receiver2.recv().await.unwrap(), message);
    }
}