topic_stream/
topic_stream.rs1use async_broadcast::{Receiver, SendError, Sender};
2use dashmap::DashMap;
3use futures::future::select_all;
4use std::{collections::HashSet, hash::Hash, sync::Arc};
5
6#[derive(Debug, Clone)]
13pub struct TopicStream<T: Eq + Hash + Clone, M: Clone> {
14 subscribers: Arc<DashMap<T, Sender<M>>>,
16 buffer_size: usize,
18}
19
20impl<T: Eq + Hash + Clone, M: Clone> TopicStream<T, M> {
21 pub fn new(buffer_size: usize) -> Self {
29 Self {
30 subscribers: Arc::new(DashMap::new()),
31 buffer_size,
32 }
33 }
34
35 pub fn subscribe(&self, topics: &[T]) -> MultiTopicReceiver<T, M> {
44 let mut receiver = MultiTopicReceiver::new(Arc::clone(&self.subscribers), self.buffer_size);
45 receiver.subscribe(topics);
46
47 receiver
48 }
49
50 pub async fn publish(&self, topic: &T, message: M) -> Result<(), SendError<M>> {
61 if let Some(sender) = self.subscribers.get(topic) {
62 sender.broadcast(message).await?;
63 };
64
65 Ok(())
66 }
67}
68
69#[derive(Debug)]
75pub struct MultiTopicReceiver<T: Eq + Hash + Clone, M: Clone> {
76 subscribers: Arc<DashMap<T, Sender<M>>>,
78 receivers: Vec<Receiver<M>>,
80 subscribed_topics: HashSet<T>,
82 buffer_size: usize,
84}
85
86impl<T: Eq + Hash + Clone, M: Clone> MultiTopicReceiver<T, M> {
87 pub fn new(subscribers: Arc<DashMap<T, Sender<M>>>, buffer_size: usize) -> Self {
96 Self {
97 subscribers,
98 receivers: Vec::new(),
99 subscribed_topics: HashSet::new(),
100 buffer_size,
101 }
102 }
103
104 pub fn subscribe(&mut self, topics: &[T]) {
110 self.receivers.extend(
111 topics
112 .iter()
113 .filter(|topic| self.subscribed_topics.insert((*topic).clone()))
114 .map(|topic| {
115 let topic = topic.clone();
116 let (sender, _receiver) = async_broadcast::broadcast(self.buffer_size);
117
118 self.subscribers
119 .entry(topic)
120 .or_insert_with(|| sender)
121 .new_receiver()
122 }),
123 );
124 }
125
126 pub async fn recv(&mut self) -> Option<M> {
128 self.receivers.retain(|r| !r.is_closed());
129
130 if self.receivers.is_empty() {
131 return None;
132 }
133
134 let futures = self
135 .receivers
136 .iter_mut()
137 .map(|receiver| Box::pin(receiver.recv()))
138 .collect::<Vec<_>>();
139
140 let (result, _index, _remaining) = select_all(futures).await;
141
142 result.ok() }
144}
145
146impl<T: Eq + Hash + Clone, M: Clone> Drop for MultiTopicReceiver<T, M> {
147 fn drop(&mut self) {
148 let mut to_remove = Vec::new();
149
150 for topic in &self.subscribed_topics {
151 if let Some(sender) = self.subscribers.get(topic) {
152 if sender.receiver_count() <= 1 {
153 to_remove.push(topic.clone());
154 }
155 }
156 }
157
158 to_remove.into_iter().for_each(|topic| {
159 self.subscribers.remove(&topic);
160 });
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{timeout, Duration};

    #[derive(Debug, Clone, Hash, Eq, PartialEq)]
    struct Topic(String);

    #[derive(Debug, Clone, Eq, PartialEq)]
    struct Message(String);

    /// How long to wait before concluding that no message is coming.
    const QUIET_PERIOD: Duration = Duration::from_millis(100);

    /// Asserts that `receiver` yields nothing within `QUIET_PERIOD`.
    async fn assert_no_message(receiver: &mut MultiTopicReceiver<Topic, Message>) {
        let outcome = timeout(QUIET_PERIOD, receiver.recv()).await;
        assert!(outcome.is_err(), "unexpected message received: {outcome:?}");
    }

    #[tokio::test]
    async fn test_subscribe_and_publish_single_subscriber() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let mut receiver = publisher.subscribe(&[topic.clone()]);

        let message = Message("Hello, Subscriber!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        assert_eq!(receiver.recv().await.unwrap(), message);
    }

    #[tokio::test]
    async fn test_subscribe_multiple_subscribers() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let mut receiver1 = publisher.subscribe(&[topic.clone()]);
        let mut receiver2 = publisher.subscribe(&[topic.clone()]);

        let message = Message("Hello, Subscribers!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        assert_eq!(receiver1.recv().await.unwrap(), message);
        assert_eq!(receiver2.recv().await.unwrap(), message);
    }

    #[tokio::test]
    async fn test_publish_to_unsubscribed_topic() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let mut receiver = publisher.subscribe(&[Topic("invalid_topic".to_string())]);

        let message = Message("Hello, World!".to_string());
        publisher.publish(&topic, message).await.unwrap();

        assert_no_message(&mut receiver).await;
    }

    #[tokio::test]
    async fn test_multiple_messages_for_single_subscriber() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let mut receiver = publisher.subscribe(&[topic.clone()]);

        let message1 = Message("Message 1".to_string());
        let message2 = Message("Message 2".to_string());
        publisher.publish(&topic, message1.clone()).await.unwrap();
        publisher.publish(&topic, message2.clone()).await.unwrap();

        assert_eq!(receiver.recv().await.unwrap(), message1);
        assert_eq!(receiver.recv().await.unwrap(), message2);
    }

    #[tokio::test]
    async fn test_multiple_publishers() {
        // Clones share the topic map, so both handles reach the same subscriber.
        let publisher1 = TopicStream::<Topic, Message>::new(2);
        let publisher2 = publisher1.clone();
        let topic = Topic("test_topic".to_string());

        let mut receiver = publisher1.subscribe(&[topic.clone()]);

        let message1 = Message("Message from Publisher 1".to_string());
        publisher1.publish(&topic, message1.clone()).await.unwrap();

        let message2 = Message("Message from Publisher 2".to_string());
        publisher2.publish(&topic, message2.clone()).await.unwrap();

        assert_eq!(receiver.recv().await.unwrap(), message1);
        assert_eq!(receiver.recv().await.unwrap(), message2);
    }

    #[tokio::test]
    async fn test_subscribe_to_different_topics() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic1 = Topic("test_topic_1".to_string());
        let topic2 = Topic("test_topic_2".to_string());

        let mut receiver1 = publisher.subscribe(&[topic1.clone()]);

        let message1 = Message("Hello, Topic 1".to_string());
        publisher.publish(&topic1, message1.clone()).await.unwrap();

        assert_eq!(receiver1.recv().await.unwrap(), message1);

        let mut receiver2 = publisher.subscribe(&[topic2.clone()]);

        let message2 = Message("Hello, Topic 2".to_string());
        publisher.publish(&topic2, message2.clone()).await.unwrap();

        assert_eq!(receiver2.recv().await.unwrap(), message2);
    }

    #[tokio::test]
    async fn test_single_receiver_multiple_topics() {
        let publisher = TopicStream::<Topic, Message>::new(2);

        let topic1 = Topic("test_topic_1".to_string());
        let topic2 = Topic("test_topic_2".to_string());
        let topic3 = Topic("test_topic_3".to_string());

        let mut receiver = publisher.subscribe(&[topic1.clone(), topic2.clone(), topic3.clone()]);

        let message1 = Message("Message for Topic 1".to_string());
        let message2 = Message("Message for Topic 2".to_string());
        let message3 = Message("Message for Topic 3".to_string());

        publisher.publish(&topic1, message1.clone()).await.unwrap();
        publisher.publish(&topic2, message2.clone()).await.unwrap();
        publisher.publish(&topic3, message3.clone()).await.unwrap();

        assert_eq!(receiver.recv().await.unwrap(), message1);
        assert_eq!(receiver.recv().await.unwrap(), message2);
        assert_eq!(receiver.recv().await.unwrap(), message3);
    }

    #[tokio::test]
    async fn test_duplicate_topics_yield_single_copy() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let mut receiver = publisher.subscribe(&[topic.clone(), topic.clone()]);

        let message = Message("Only once".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        assert_eq!(receiver.recv().await.unwrap(), message);
        assert_no_message(&mut receiver).await;
    }

    #[tokio::test]
    async fn test_dropping_last_subscriber_removes_topic() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let receiver = publisher.subscribe(&[topic.clone()]);
        assert_eq!(publisher.subscribers.len(), 1);

        drop(receiver);
        assert!(publisher.subscribers.is_empty());

        // With no subscribers left, publishing is a successful no-op.
        let message = Message("Nobody listening".to_string());
        publisher.publish(&topic, message).await.unwrap();
    }

    #[tokio::test]
    async fn test_dropping_one_subscriber_keeps_topic_for_others() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let receiver1 = publisher.subscribe(&[topic.clone()]);
        let mut receiver2 = publisher.subscribe(&[topic.clone()]);

        drop(receiver1);
        assert_eq!(publisher.subscribers.len(), 1);

        let message = Message("Still here".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        assert_eq!(receiver2.recv().await.unwrap(), message);
    }

    #[tokio::test]
    async fn test_resubscribe_after_drop() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        drop(publisher.subscribe(&[topic.clone()]));

        let mut receiver = publisher.subscribe(&[topic.clone()]);

        let message = Message("Welcome back".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        assert_eq!(receiver.recv().await.unwrap(), message);
    }
}