1use std::{
2 collections::{HashMap, HashSet},
3 mem,
4 sync::Arc,
5};
6
7use async_stream::stream;
8use futures::{stream::BoxStream, StreamExt};
9use serde::{de::DeserializeOwned, Serialize};
10use tokio::sync::{
11 broadcast::{
12 self,
13 error::{RecvError, TryRecvError},
14 },
15 mpsc, Mutex,
16};
17
18use crate::{BoxError, Broker, Connection};
19
20#[derive(Debug, Clone)]
21pub struct MemoryBroker {
22 broadcast: broadcast::Sender<(String, Vec<u8>)>,
23 subscribers: Arc<Mutex<HashMap<String, usize>>>,
24}
25
26impl Default for MemoryBroker {
27 fn default() -> Self {
28 Self::with_capacity(1000)
29 }
30}
31
32impl MemoryBroker {
33 pub fn with_capacity(capacity: usize) -> Self {
34 let (sender, _) = broadcast::channel(capacity);
35 Self {
36 broadcast: sender,
37 subscribers: Default::default(),
38 }
39 }
40}
41
42impl Broker for MemoryBroker {
43 type Conn = MemoryConnection;
44
45 async fn connect(&self) -> Result<MemoryConnection, BoxError> {
46 let sender = self.broadcast.clone();
47 let receiver = sender.subscribe();
48 let (events_tx, mut events_rx) = mpsc::unbounded_channel::<ConnectionEvent>();
49
50 tokio::spawn({
51 let subscribers = self.subscribers.clone();
52 async move {
53 while let Some(event) = events_rx.recv().await {
54 match event {
55 ConnectionEvent::Subscribe(channel) => {
56 subscribers
57 .lock()
58 .await
59 .entry(channel)
60 .and_modify(|count| *count += 1)
61 .or_insert(1);
62 }
63 ConnectionEvent::Unsubscribe(channel) => {
64 subscribers
65 .lock()
66 .await
67 .entry(channel)
68 .and_modify(|count| *count -= 1)
69 .or_default();
70 }
71 }
72 }
73 }
74 });
75
76 Ok(MemoryConnection {
77 sender,
78 receiver,
79 events: events_tx,
80 subs: HashSet::new(),
81 user_id: None,
82 })
83 }
84
85 async fn subscribers_count(&self, channel: &str) -> usize {
86 self.subscribers
87 .lock()
88 .await
89 .get(channel)
90 .copied()
91 .unwrap_or(0)
92 }
93
94 async fn subscriptions(&self) -> HashSet<(String, usize)> {
95 self.subscribers
96 .lock()
97 .await
98 .iter()
99 .map(|(channel, count)| (channel.clone(), *count))
100 .filter(|(_, count)| *count > 0)
101 .collect()
102 }
103
104 async fn publish(&self, channel: &str, msg: impl Serialize) -> Result<(), BoxError> {
105 self.broadcast
106 .send((channel.to_owned(), serde_json::to_vec(&msg)?))?;
107 Ok(())
108 }
109
110 fn all_messages<T: DeserializeOwned + Send + 'static>(&self) -> BoxStream<'static, T> {
111 let mut msgs = self.broadcast.clone().subscribe();
112 stream! {
113 loop {
114 match msgs.try_recv() {
115 Ok((_, msg)) => {
116 if let Ok(msg) = serde_json::from_slice(&msg) {
117 yield msg
118 }
119 }
120 Err(TryRecvError::Lagged(_)) => continue,
121 Err(_) => break,
122 }
123 }
124 }
125 .boxed()
126 }
127}
128
/// Bookkeeping message sent from a connection to the task spawned for it in `connect`, which
/// applies it to the broker's per-channel subscriber counts.
#[derive(Clone, Debug)]
enum ConnectionEvent {
    /// The connection started listening on the named channel.
    Subscribe(String),
    /// The connection stopped listening on the named channel (explicitly, or when dropped).
    Unsubscribe(String),
}
134
135#[derive(Debug)]
136pub struct MemoryConnection {
137 sender: broadcast::Sender<(String, Vec<u8>)>,
138 receiver: broadcast::Receiver<(String, Vec<u8>)>,
139 events: mpsc::UnboundedSender<ConnectionEvent>,
140 subs: HashSet<String>,
141 user_id: Option<String>,
142}
143
144impl Drop for MemoryConnection {
145 fn drop(&mut self) {
146 for channel in mem::take(&mut self.subs).into_iter() {
147 self.events
148 .send(ConnectionEvent::Unsubscribe(channel.to_owned()))
149 .ok();
150 }
151 }
152}
153
154impl Connection for MemoryConnection {
155 async fn authenticate(&mut self, user_id: &str, _data: impl Serialize) -> Result<(), BoxError> {
156 match self.user_id.as_mut() {
157 Some(current_user_id) if current_user_id != user_id => {
158 Err("Connection already authenticated".into())
159 }
160 Some(current_user_id) => {
161 *current_user_id = user_id.to_owned();
162 Ok(())
163 }
164 None => {
165 self.user_id = Some(user_id.to_string());
166 Ok(())
167 }
168 }
169 }
170
171 async fn publish(&mut self, channel: &str, msg: impl Serialize) -> Result<(), BoxError> {
172 self.sender
173 .send((channel.to_owned(), serde_json::to_vec(&msg)?))?;
174 Ok(())
175 }
176
177 async fn subscribe(&mut self, channel: &str) -> Result<(), BoxError> {
178 if self.subs.insert(channel.to_owned()) {
179 self.events
180 .send(ConnectionEvent::Subscribe(channel.to_owned()))?;
181 }
182 Ok(())
183 }
184
185 async fn unsubscribe(&mut self, channel: &str) -> Result<(), BoxError> {
186 if self.subs.remove(channel) {
187 self.events
188 .send(ConnectionEvent::Unsubscribe(channel.to_owned()))?;
189 }
190 Ok(())
191 }
192
193 async fn recv<T: DeserializeOwned>(&mut self) -> Result<T, BoxError> {
194 loop {
195 match self.receiver.recv().await {
196 Ok((channel, msg)) => match serde_json::from_slice(&msg) {
197 Ok(msg) if self.subs.contains(&channel) => return Ok(msg),
198 _ => continue,
199 },
200 Err(RecvError::Lagged(_)) => continue,
201 Err(err) => return Err(err.into()),
202 }
203 }
204 }
205
206 async fn try_recv<T: DeserializeOwned>(&mut self) -> Result<Option<T>, BoxError> {
207 loop {
208 match self.receiver.try_recv() {
209 Ok((channel, msg)) => match serde_json::from_slice(&msg) {
210 Ok(msg) if self.subs.contains(&channel) => return Ok(Some(msg)),
211 _ => return Ok(None),
212 },
213 Err(TryRecvError::Empty) => return Ok(None),
214 Err(TryRecvError::Lagged(_)) => continue,
215 Err(err) => return Err(err.into()),
216 }
217 }
218 }
219}
220
// Integration-style tests that drive `MemoryBroker` and `MemoryConnection` through the
// `Broker`/`Connection` trait methods on a tokio test runtime.
#[cfg(test)]
mod test {
    use tokio::time;

    use super::*;

    /// Every subscriber of a channel receives each message published to it (including its own
    /// publications), in publish order. `recv` skips messages for channels the connection is
    /// not subscribed to.
    #[tokio::test]
    async fn test_pubsub() {
        let broker = MemoryBroker::default();
        let mut conn1 = broker.connect().await.unwrap();
        let mut conn2 = broker.connect().await.unwrap();
        let mut conn3 = broker.connect().await.unwrap();

        conn1.subscribe("channel_all").await.unwrap();
        conn2.subscribe("channel_all").await.unwrap();
        conn3.subscribe("channel_all").await.unwrap();

        conn2.subscribe("channel2").await.unwrap();

        conn3.subscribe("channel3").await.unwrap();

        conn1.publish("channel_all", "1").await.unwrap();
        conn2.publish("channel_all", "2").await.unwrap();
        conn3.publish("channel_all", "3").await.unwrap();

        conn1.publish("channel2", "only 2").await.unwrap();
        conn1.publish("channel3", "only 3").await.unwrap();

        // conn1 is subscribed only to "channel_all", so only the three shared messages are read.
        assert_eq!("1", conn1.recv::<String>().await.unwrap());
        assert_eq!("2", conn1.recv::<String>().await.unwrap());
        assert_eq!("3", conn1.recv::<String>().await.unwrap());

        assert_eq!("1", conn2.recv::<String>().await.unwrap());
        assert_eq!("2", conn2.recv::<String>().await.unwrap());
        assert_eq!("3", conn2.recv::<String>().await.unwrap());
        assert_eq!("only 2", conn2.recv::<String>().await.unwrap());

        assert_eq!("1", conn3.recv::<String>().await.unwrap());
        assert_eq!("2", conn3.recv::<String>().await.unwrap());
        assert_eq!("3", conn3.recv::<String>().await.unwrap());
        // "only 2" (channel2) is buffered before this but is skipped: conn3 is not subscribed to it.
        assert_eq!("only 3", conn3.recv::<String>().await.unwrap());
    }

    /// After a connection unsubscribes, messages on that channel are no longer returned to it,
    /// while other subscribers still receive them.
    #[tokio::test]
    async fn test_unsubsribe() {
        let broker = MemoryBroker::default();
        let mut conn1 = broker.connect().await.unwrap();
        let mut conn2 = broker.connect().await.unwrap();

        conn1.subscribe("channel").await.unwrap();
        conn2.subscribe("channel").await.unwrap();

        conn1.publish("channel", "1").await.unwrap();
        assert_eq!("1", conn1.recv::<String>().await.unwrap());
        assert_eq!("1", conn2.recv::<String>().await.unwrap());

        // Removes "channel" from conn1's local subscription set, so filtering changes right away.
        conn1.unsubscribe("channel").await.unwrap();

        conn2.publish("channel", "3").await.unwrap();

        assert_eq!("3", conn2.recv::<String>().await.unwrap());
        // "3" is still in conn1's receive buffer but is filtered out, so nothing is returned.
        assert_eq!(None, conn1.try_recv::<String>().await.unwrap());
    }

    /// `subscribers_count` reflects subscribe/unsubscribe calls across connections.
    #[tokio::test]
    async fn test_broker_subscribers_count() {
        // Counts are applied by the background task spawned in `connect`, so the test awaits
        // `interval.tick()` to give that task a chance to run before asserting.
        // NOTE(review): tokio documents that an interval's first tick completes immediately, so
        // this relies on the tick still yielding to the scheduler; `tokio::task::yield_now()`
        // may express the intent more directly. Confirm before relying on it elsewhere.
        let mut interval = time::interval(time::Duration::from_millis(1));
        let broker = MemoryBroker::default();
        let mut conn1 = broker.connect().await.unwrap();
        let mut conn2 = broker.connect().await.unwrap();

        conn1.subscribe("channel1").await.unwrap();
        conn1.subscribe("channel2").await.unwrap();
        conn2.subscribe("channel1").await.unwrap();
        interval.tick().await;

        assert_eq!(0, broker.subscribers_count("channel0").await);
        assert_eq!(2, broker.subscribers_count("channel1").await);
        assert_eq!(1, broker.subscribers_count("channel2").await);

        conn1.unsubscribe("channel1").await.unwrap();
        interval.tick().await;

        assert_eq!(1, broker.subscribers_count("channel1").await);
    }

    /// `subscriptions` lists only channels with a positive count, and dropping a connection
    /// releases all of its subscriptions.
    #[tokio::test]
    async fn test_subscriptions() {
        // Same scheduling caveat as in `test_broker_subscribers_count`.
        let mut interval = time::interval(time::Duration::from_millis(1));
        let broker = MemoryBroker::default();
        let mut conn1 = broker.connect().await.unwrap();
        let mut conn2 = broker.connect().await.unwrap();

        conn1.subscribe("channel1").await.unwrap();
        conn1.subscribe("channel2").await.unwrap();
        conn1.subscribe("channel3").await.unwrap();

        conn2.subscribe("channel1").await.unwrap();
        // conn2's subscribe + unsubscribe on channel3 nets to zero, leaving only conn1 there.
        conn2.subscribe("channel3").await.unwrap();
        conn2.unsubscribe("channel3").await.unwrap();

        interval.tick().await;

        assert_eq!(
            HashSet::from_iter([
                (String::from("channel1"), 2),
                (String::from("channel2"), 1),
                (String::from("channel3"), 1)
            ]),
            broker.subscriptions().await
        );

        // Dropping conn1 sends Unsubscribe events for every channel it still holds (Drop impl);
        // channels whose count reaches zero no longer appear.
        drop(conn1);
        interval.tick().await;

        assert_eq!(
            HashSet::from_iter([(String::from("channel1"), 1)]),
            broker.subscriptions().await
        );
    }
}