1use crate::*;
2use async_trait::async_trait;
3use std::collections::BTreeMap;
4use std::marker::PhantomData;
5
/// Opaque key the broker assigns to each subscription; used to find the
/// subscriber's channel again when it unsubscribes.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct SubscriptionId(usize);
10
11struct Subscribe<T>
12where
13 T: Message<Result = ()> + Clone,
14{
15 channel: Box<dyn MessageChannel<T>>,
16}
17
18struct Unsubscribe<T>
19where
20 T: Message<Result = ()> + Clone,
21{
22 id: SubscriptionId,
23 _marker: PhantomData<T>,
24}
25
26struct Publish<T: Message<Result = ()> + Clone>(T);
27
28pub struct Subscription<T>
29where
30 T: Message<Result = ()> + Clone,
31{
32 id: SubscriptionId,
33 broker_addr: Address<Broker<T>>,
34}
35
36impl<T> Subscription<T>
37where
38 T: Message<Result = ()> + Clone,
39{
40 pub async fn unsubscribe(self) -> Result<(), ()> {
41 let Self { id, broker_addr } = self;
42 broker_addr
43 .send(Unsubscribe {
44 id,
45 _marker: PhantomData,
46 })
47 .await
48 .unwrap()
49 }
50}
51
52impl<T> Message for Subscribe<T>
53where
54 T: Message<Result = ()> + Clone,
55{
56 type Result = Result<Subscription<T>, ()>;
57}
58
59impl<T> Message for Unsubscribe<T>
60where
61 T: Message<Result = ()> + Clone,
62{
63 type Result = Result<(), ()>;
64}
65
66impl<T> Message for Publish<T>
67where
68 T: Message<Result = ()> + Clone,
69{
70 type Result = ();
71}
72
73pub struct Broker<T>
74where
75 T: Message<Result = ()> + Clone,
76{
77 next_id: usize,
78 subscriptions: BTreeMap<SubscriptionId, Box<dyn MessageChannel<T>>>,
79}
80
81impl<T> Broker<T>
82where
83 T: Message<Result = ()> + Clone,
84{
85 pub fn new() -> Self {
86 Self {
87 next_id: 0,
88 subscriptions: BTreeMap::new(),
89 }
90 }
91
92 pub async fn subscribe<A: Actor + Handler<T>>(
93 subscriber: Address<A>,
94 ) -> Result<Subscription<T>, ()> {
95 let broker = Self::from_registry().await;
96 subscriber.subscribe(broker).await
97 }
98
99 pub async fn publish(message: T) -> Result<(), xtra::Disconnected> {
100 let broker = Self::from_registry().await;
101 broker.publish(message).await
102 }
103}
104
105impl<T> Default for Broker<T>
106where
107 T: Message<Result = ()> + Clone,
108{
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114#[async_trait]
115impl<T> Actor for Broker<T> where T: Message<Result = ()> + Clone {}
116
117#[async_trait]
118impl<T> Handler<Subscribe<T>> for Broker<T>
119where
120 T: Message<Result = ()> + Clone,
121{
122 async fn handle(
123 &mut self,
124 message: Subscribe<T>,
125 ctx: &mut Context<Self>,
126 ) -> <Subscribe<T> as Message>::Result {
127 let broker_addr = ctx.address().map_err(|_| ())?;
128
129 let id = SubscriptionId(self.next_id);
130 self.next_id += 1;
131
132 self.subscriptions.insert(id, message.channel);
133
134 Ok(Subscription { id, broker_addr })
135 }
136}
137
138#[async_trait]
139impl<T> Handler<Unsubscribe<T>> for Broker<T>
140where
141 T: Message<Result = ()> + Clone,
142{
143 async fn handle(
144 &mut self,
145 message: Unsubscribe<T>,
146 _ctx: &mut Context<Self>,
147 ) -> <Unsubscribe<T> as Message>::Result {
148 match self.subscriptions.remove(&message.id) {
149 Some(_) => Ok(()),
150 None => Err(()),
151 }
152 }
153}
154
155#[async_trait]
156impl<T> Handler<Publish<T>> for Broker<T>
157where
158 T: Message<Result = ()> + Clone,
159{
160 async fn handle(&mut self, Publish(message): Publish<T>, _ctx: &mut Context<Self>) {
161 let mut disconnected: Vec<SubscriptionId> = Vec::new();
162
163 for (&id, subscriber) in &self.subscriptions {
164 match subscriber.do_send(message.clone()) {
165 Ok(()) => {}
166 Err(xtra::Disconnected) => {
167 disconnected.push(id);
168 }
169 }
170 }
171
172 for id in disconnected {
173 self.subscriptions.remove(&id);
174 }
175 }
176}
177
178#[async_trait]
179pub trait SubscribeExt<M>
180where
181 M: Message<Result = ()> + Clone,
182{
183 async fn subscribe(&self, broker: Address<Broker<M>>) -> Result<Subscription<M>, ()>;
184}
185
186#[async_trait]
187impl<T, M> SubscribeExt<M> for Address<T>
188where
189 T: Handler<M>,
190 M: Message<Result = ()> + Clone,
191{
192 async fn subscribe(&self, broker: Address<Broker<M>>) -> Result<Subscription<M>, ()> {
193 broker
194 .send(Subscribe {
195 channel: Box::new(self.clone()),
196 })
197 .await
198 .map_err(|_| ())?
199 }
200}
201
202#[async_trait]
203pub trait PublishExt<M>
204where
205 M: Message<Result = ()> + Clone,
206{
207 async fn publish(&self, message: M) -> Result<(), xtra::Disconnected>;
208}
209
210#[async_trait]
211impl<M> PublishExt<M> for Address<Broker<M>>
212where
213 M: Message<Result = ()> + Clone,
214{
215 async fn publish(&self, message: M) -> Result<(), xtra::Disconnected> {
216 self.send(Publish(message)).await
217 }
218}
219
/// End-to-end check of subscribe / publish / unsubscribe.
///
/// Two subscriber actors forward everything they receive to a shared
/// `Collector`. After each publish the test sleeps briefly, because delivery
/// is fire-and-forget through `do_send`, and then inspects what the collector
/// gathered.
#[cfg(test)]
#[async_std::test]
async fn test_broker() {
    use xtra::spawn::AsyncStd;

    #[derive(Clone)]
    struct Msg {
        msg: String,
    }

    impl Message for Msg {
        type Result = ();
    }

    /// Drains and returns everything the collector has received so far.
    struct RetrieveMessages;

    impl Message for RetrieveMessages {
        type Result = Vec<String>;
    }

    struct Collector {
        messages: Vec<String>,
    }

    impl Collector {
        fn new() -> Self {
            Self { messages: vec![] }
        }
    }

    impl Actor for Collector {}

    struct SubscriberA {
        collector: Address<Collector>,
    }

    impl Actor for SubscriberA {}

    struct SubscriberB {
        collector: Address<Collector>,
    }

    impl Actor for SubscriberB {}

    #[async_trait]
    impl Handler<Msg> for Collector {
        async fn handle(&mut self, Msg { msg }: Msg, _ctx: &mut Context<Self>) {
            self.messages.push(msg);
        }
    }

    #[async_trait]
    impl Handler<RetrieveMessages> for Collector {
        async fn handle(
            &mut self,
            _: RetrieveMessages,
            _ctx: &mut Context<Self>,
        ) -> <RetrieveMessages as Message>::Result {
            let mut messages = vec![];

            std::mem::swap(&mut self.messages, &mut messages);

            messages
        }
    }

    #[async_trait]
    impl Handler<Msg> for SubscriberA {
        async fn handle(&mut self, msg: Msg, _ctx: &mut Context<Self>) {
            self.collector
                .do_send(Msg {
                    msg: format!("{} from SubscriberA", msg.msg),
                })
                .unwrap();
        }
    }

    #[async_trait]
    impl Handler<Msg> for SubscriberB {
        async fn handle(&mut self, msg: Msg, _ctx: &mut Context<Self>) {
            self.collector
                .do_send(Msg {
                    msg: format!("{} from SubscriberB", msg.msg),
                })
                .unwrap();
        }
    }

    let broker = Broker::<Msg>::new().create(None).spawn(&mut AsyncStd);
    let collector = Collector::new().create(None).spawn(&mut AsyncStd);

    let subscriber_a = SubscriberA {
        collector: collector.clone(),
    }
    .create(None)
    .spawn(&mut AsyncStd);
    let subscriber_b = SubscriberB {
        collector: collector.clone(),
    }
    .create(None)
    .spawn(&mut AsyncStd);

    assert!(collector.send(RetrieveMessages).await.unwrap().is_empty());

    let subscription_a = subscriber_a.subscribe(broker.clone()).await.unwrap();
    let subscription_b = subscriber_b.subscribe(broker.clone()).await.unwrap();

    broker
        .publish(Msg {
            msg: "1".to_string(),
        })
        .await
        .unwrap();

    async_std::task::sleep(std::time::Duration::from_millis(100)).await;

    // The two subscribers forward to the collector concurrently, so arrival
    // order is not deterministic. Compare order-insensitively by sorting.
    // Unlike indexing `msgs[0]`, this also gives a readable assertion failure
    // when nothing was delivered.
    let mut msgs = collector.send(RetrieveMessages).await.unwrap();
    msgs.sort();
    assert_eq!(
        msgs,
        vec![
            "1 from SubscriberA".to_string(),
            "1 from SubscriberB".to_string(),
        ]
    );

    subscription_b.unsubscribe().await.unwrap();

    broker
        .publish(Msg {
            msg: "2".to_string(),
        })
        .await
        .unwrap();

    async_std::task::sleep(std::time::Duration::from_millis(100)).await;

    assert_eq!(
        collector.send(RetrieveMessages).await.unwrap(),
        vec!["2 from SubscriberA".to_string()]
    );

    subscription_a.unsubscribe().await.unwrap();

    broker
        .publish(Msg {
            msg: "3".to_string(),
        })
        .await
        .unwrap();

    async_std::task::sleep(std::time::Duration::from_millis(100)).await;

    assert_eq!(
        collector.send(RetrieveMessages).await.unwrap(),
        Vec::<String>::new(),
    );
}
397
/// Exercises the registry-backed `Broker::subscribe` / `Broker::publish`
/// helpers. Both the subscriber and the broker come from `from_registry`
/// rather than being spawned by hand.
#[cfg(test)]
#[async_std::test]
async fn test_broker_using_registry() {
    #[derive(Clone)]
    struct Msg {
        msg: String,
    }

    impl Message for Msg {
        type Result = ();
    }

    /// Subscriber that accepts `Msg` and ignores it.
    #[derive(Default)]
    struct MyActor;

    impl Actor for MyActor {}

    #[async_trait]
    impl Handler<Msg> for MyActor {
        async fn handle(&mut self, _: Msg, _ctx: &mut Context<Self>) {}
    }

    let actor = MyActor::from_registry().await;
    let subscription = Broker::<Msg>::subscribe(actor).await.unwrap();

    let payload = Msg {
        msg: String::from("123"),
    };
    Broker::<Msg>::publish(payload).await.unwrap();

    subscription.unsubscribe().await.unwrap();
}