// rusher_pubsub/memory.rs

1use std::{
2    collections::{HashMap, HashSet},
3    mem,
4    sync::Arc,
5};
6
7use async_stream::stream;
8use futures::{stream::BoxStream, StreamExt};
9use serde::{de::DeserializeOwned, Serialize};
10use tokio::sync::{
11    broadcast::{
12        self,
13        error::{RecvError, TryRecvError},
14    },
15    mpsc, Mutex,
16};
17
18use crate::{BoxError, Broker, Connection};
19
20#[derive(Debug, Clone)]
21pub struct MemoryBroker {
22    broadcast: broadcast::Sender<(String, Vec<u8>)>,
23    subscribers: Arc<Mutex<HashMap<String, usize>>>,
24}
25
26impl Default for MemoryBroker {
27    fn default() -> Self {
28        Self::with_capacity(1000)
29    }
30}
31
32impl MemoryBroker {
33    pub fn with_capacity(capacity: usize) -> Self {
34        let (sender, _) = broadcast::channel(capacity);
35        Self {
36            broadcast: sender,
37            subscribers: Default::default(),
38        }
39    }
40}
41
42impl Broker for MemoryBroker {
43    type Conn = MemoryConnection;
44
45    async fn connect(&self) -> Result<MemoryConnection, BoxError> {
46        let sender = self.broadcast.clone();
47        let receiver = sender.subscribe();
48        let (events_tx, mut events_rx) = mpsc::unbounded_channel::<ConnectionEvent>();
49
50        tokio::spawn({
51            let subscribers = self.subscribers.clone();
52            async move {
53                while let Some(event) = events_rx.recv().await {
54                    match event {
55                        ConnectionEvent::Subscribe(channel) => {
56                            subscribers
57                                .lock()
58                                .await
59                                .entry(channel)
60                                .and_modify(|count| *count += 1)
61                                .or_insert(1);
62                        }
63                        ConnectionEvent::Unsubscribe(channel) => {
64                            subscribers
65                                .lock()
66                                .await
67                                .entry(channel)
68                                .and_modify(|count| *count -= 1)
69                                .or_default();
70                        }
71                    }
72                }
73            }
74        });
75
76        Ok(MemoryConnection {
77            sender,
78            receiver,
79            events: events_tx,
80            subs: HashSet::new(),
81            user_id: None,
82        })
83    }
84
85    async fn subscribers_count(&self, channel: &str) -> usize {
86        self.subscribers
87            .lock()
88            .await
89            .get(channel)
90            .copied()
91            .unwrap_or(0)
92    }
93
94    async fn subscriptions(&self) -> HashSet<(String, usize)> {
95        self.subscribers
96            .lock()
97            .await
98            .iter()
99            .map(|(channel, count)| (channel.clone(), *count))
100            .filter(|(_, count)| *count > 0)
101            .collect()
102    }
103
104    async fn publish(&self, channel: &str, msg: impl Serialize) -> Result<(), BoxError> {
105        self.broadcast
106            .send((channel.to_owned(), serde_json::to_vec(&msg)?))?;
107        Ok(())
108    }
109
110    fn all_messages<T: DeserializeOwned + Send + 'static>(&self) -> BoxStream<'static, T> {
111        let mut msgs = self.broadcast.clone().subscribe();
112        stream! {
113            loop {
114                match msgs.try_recv() {
115                    Ok((_, msg)) => {
116                        if let Ok(msg) = serde_json::from_slice(&msg) {
117                            yield msg
118                        }
119                    }
120                    Err(TryRecvError::Lagged(_)) => continue,
121                    Err(_) => break,
122                }
123            }
124        }
125        .boxed()
126    }
127}
128
/// Subscription change reported by a connection to the broker's counting task.
#[derive(Clone, Debug)]
enum ConnectionEvent {
    /// The connection started listening on the named channel.
    Subscribe(String),
    /// The connection stopped listening on the named channel.
    Unsubscribe(String),
}
134
135#[derive(Debug)]
136pub struct MemoryConnection {
137    sender: broadcast::Sender<(String, Vec<u8>)>,
138    receiver: broadcast::Receiver<(String, Vec<u8>)>,
139    events: mpsc::UnboundedSender<ConnectionEvent>,
140    subs: HashSet<String>,
141    user_id: Option<String>,
142}
143
144impl Drop for MemoryConnection {
145    fn drop(&mut self) {
146        for channel in mem::take(&mut self.subs).into_iter() {
147            self.events
148                .send(ConnectionEvent::Unsubscribe(channel.to_owned()))
149                .ok();
150        }
151    }
152}
153
154impl Connection for MemoryConnection {
155    async fn authenticate(&mut self, user_id: &str, _data: impl Serialize) -> Result<(), BoxError> {
156        match self.user_id.as_mut() {
157            Some(current_user_id) if current_user_id != user_id => {
158                Err("Connection already authenticated".into())
159            }
160            Some(current_user_id) => {
161                *current_user_id = user_id.to_owned();
162                Ok(())
163            }
164            None => {
165                self.user_id = Some(user_id.to_string());
166                Ok(())
167            }
168        }
169    }
170
171    async fn publish(&mut self, channel: &str, msg: impl Serialize) -> Result<(), BoxError> {
172        self.sender
173            .send((channel.to_owned(), serde_json::to_vec(&msg)?))?;
174        Ok(())
175    }
176
177    async fn subscribe(&mut self, channel: &str) -> Result<(), BoxError> {
178        if self.subs.insert(channel.to_owned()) {
179            self.events
180                .send(ConnectionEvent::Subscribe(channel.to_owned()))?;
181        }
182        Ok(())
183    }
184
185    async fn unsubscribe(&mut self, channel: &str) -> Result<(), BoxError> {
186        if self.subs.remove(channel) {
187            self.events
188                .send(ConnectionEvent::Unsubscribe(channel.to_owned()))?;
189        }
190        Ok(())
191    }
192
193    async fn recv<T: DeserializeOwned>(&mut self) -> Result<T, BoxError> {
194        loop {
195            match self.receiver.recv().await {
196                Ok((channel, msg)) => match serde_json::from_slice(&msg) {
197                    Ok(msg) if self.subs.contains(&channel) => return Ok(msg),
198                    _ => continue,
199                },
200                Err(RecvError::Lagged(_)) => continue,
201                Err(err) => return Err(err.into()),
202            }
203        }
204    }
205
206    async fn try_recv<T: DeserializeOwned>(&mut self) -> Result<Option<T>, BoxError> {
207        loop {
208            match self.receiver.try_recv() {
209                Ok((channel, msg)) => match serde_json::from_slice(&msg) {
210                    Ok(msg) if self.subs.contains(&channel) => return Ok(Some(msg)),
211                    _ => return Ok(None),
212                },
213                Err(TryRecvError::Empty) => return Ok(None),
214                Err(TryRecvError::Lagged(_)) => continue,
215                Err(err) => return Err(err.into()),
216            }
217        }
218    }
219}
220
#[cfg(test)]
mod test {
    use tokio::time;

    use super::*;

    /// Gives the per-connection background tasks time to apply their queued
    /// subscribe/unsubscribe events before a test inspects the broker's counters.
    ///
    /// `Interval::tick` is unsuitable here because its first call completes immediately.
    /// A real sleep forces this task to yield to the scheduler for a while.
    async fn settle() {
        time::sleep(time::Duration::from_millis(10)).await;
    }

    #[tokio::test]
    async fn test_pubsub() {
        let broker = MemoryBroker::default();
        let mut conn1 = broker.connect().await.unwrap();
        let mut conn2 = broker.connect().await.unwrap();
        let mut conn3 = broker.connect().await.unwrap();

        conn1.subscribe("channel_all").await.unwrap();
        conn2.subscribe("channel_all").await.unwrap();
        conn3.subscribe("channel_all").await.unwrap();

        conn2.subscribe("channel2").await.unwrap();

        conn3.subscribe("channel3").await.unwrap();

        conn1.publish("channel_all", "1").await.unwrap();
        conn2.publish("channel_all", "2").await.unwrap();
        conn3.publish("channel_all", "3").await.unwrap();

        conn1.publish("channel2", "only 2").await.unwrap();
        conn1.publish("channel3", "only 3").await.unwrap();

        assert_eq!("1", conn1.recv::<String>().await.unwrap());
        assert_eq!("2", conn1.recv::<String>().await.unwrap());
        assert_eq!("3", conn1.recv::<String>().await.unwrap());

        assert_eq!("1", conn2.recv::<String>().await.unwrap());
        assert_eq!("2", conn2.recv::<String>().await.unwrap());
        assert_eq!("3", conn2.recv::<String>().await.unwrap());
        assert_eq!("only 2", conn2.recv::<String>().await.unwrap());

        assert_eq!("1", conn3.recv::<String>().await.unwrap());
        assert_eq!("2", conn3.recv::<String>().await.unwrap());
        assert_eq!("3", conn3.recv::<String>().await.unwrap());
        assert_eq!("only 3", conn3.recv::<String>().await.unwrap());
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let broker = MemoryBroker::default();
        let mut conn1 = broker.connect().await.unwrap();
        let mut conn2 = broker.connect().await.unwrap();

        conn1.subscribe("channel").await.unwrap();
        conn2.subscribe("channel").await.unwrap();

        conn1.publish("channel", "1").await.unwrap();
        assert_eq!("1", conn1.recv::<String>().await.unwrap());
        assert_eq!("1", conn2.recv::<String>().await.unwrap());

        conn1.unsubscribe("channel").await.unwrap();

        conn2.publish("channel", "3").await.unwrap();

        assert_eq!("3", conn2.recv::<String>().await.unwrap());
        assert_eq!(None, conn1.try_recv::<String>().await.unwrap());
    }

    #[tokio::test]
    async fn test_broker_subscribers_count() {
        let broker = MemoryBroker::default();
        let mut conn1 = broker.connect().await.unwrap();
        let mut conn2 = broker.connect().await.unwrap();

        conn1.subscribe("channel1").await.unwrap();
        conn1.subscribe("channel2").await.unwrap();
        conn2.subscribe("channel1").await.unwrap();
        settle().await;

        assert_eq!(0, broker.subscribers_count("channel0").await);
        assert_eq!(2, broker.subscribers_count("channel1").await);
        assert_eq!(1, broker.subscribers_count("channel2").await);

        conn1.unsubscribe("channel1").await.unwrap();
        settle().await;

        assert_eq!(1, broker.subscribers_count("channel1").await);
    }

    #[tokio::test]
    async fn test_subscriptions() {
        let broker = MemoryBroker::default();
        let mut conn1 = broker.connect().await.unwrap();
        let mut conn2 = broker.connect().await.unwrap();

        conn1.subscribe("channel1").await.unwrap();
        conn1.subscribe("channel2").await.unwrap();
        conn1.subscribe("channel3").await.unwrap();

        conn2.subscribe("channel1").await.unwrap();
        conn2.subscribe("channel3").await.unwrap();
        conn2.unsubscribe("channel3").await.unwrap();

        settle().await;

        assert_eq!(
            HashSet::from_iter([
                (String::from("channel1"), 2),
                (String::from("channel2"), 1),
                (String::from("channel3"), 1)
            ]),
            broker.subscriptions().await
        );

        drop(conn1);
        settle().await;

        assert_eq!(
            HashSet::from_iter([(String::from("channel1"), 1)]),
            broker.subscriptions().await
        );
    }
}