use crate::*;
use async_trait::async_trait;
use std::collections::BTreeMap;
use std::marker::PhantomData;
/// Opaque key identifying one subscription inside a [`Broker`].
///
/// It is totally ordered so it can key the broker's `BTreeMap`, and `Debug` is derived
/// so ids can appear in diagnostics and assertion failures.
#[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Copy, Clone, Hash)]
struct SubscriptionId(usize);
/// Request asking the broker to add a subscriber's channel to its subscriber set.
struct Subscribe<T: Message<Result = ()> + Clone> {
    /// Channel through which the broker will deliver each published `T`.
    channel: Box<dyn MessageChannel<T>>,
}
/// Request asking the broker to drop the subscription identified by `id`.
struct Unsubscribe<T: Message<Result = ()> + Clone> {
    /// Key of the subscription to remove.
    id: SubscriptionId,
    /// Ties the request to the broker's message type `T` without storing a `T`.
    _marker: PhantomData<T>,
}
/// Request asking the broker to fan the wrapped `T` out to every current subscriber.
struct Publish<T>(T)
where
    T: Message<Result = ()> + Clone;
/// Handle returned by a successful subscription.
///
/// Consume it with `unsubscribe` to stop receiving messages. NOTE(review): no `Drop` impl
/// is visible here, so simply dropping the handle presumably leaves the subscription active.
pub struct Subscription<T: Message<Result = ()> + Clone> {
    /// Key the broker assigned to this subscription.
    id: SubscriptionId,
    /// Address of the broker that issued this subscription.
    broker_addr: Address<Broker<T>>,
}
impl<T> Subscription<T>
where
    T: Message<Result = ()> + Clone,
{
    /// Cancels this subscription, consuming the handle.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if the broker has no subscription with this id, or if the broker
    /// actor is disconnected and the request could not be delivered.
    pub async fn unsubscribe(self) -> Result<(), ()> {
        let Self { id, broker_addr } = self;
        broker_addr
            .send(Unsubscribe {
                id,
                _marker: PhantomData,
            })
            .await
            // Report a disconnected broker through the `Result` instead of panicking,
            // the same way `SubscribeExt::subscribe` does.
            .map_err(|_| ())?
    }
}
/// A `Subscribe` request is answered with the new subscription handle, or `Err(())`.
impl<T: Message<Result = ()> + Clone> Message for Subscribe<T> {
    type Result = Result<Subscription<T>, ()>;
}
/// An `Unsubscribe` request is answered with `Ok(())` on removal, or `Err(())` otherwise.
impl<T: Message<Result = ()> + Clone> Message for Unsubscribe<T> {
    type Result = Result<(), ()>;
}
/// A `Publish` request produces no value; the reply only signals that it was handled.
impl<T: Message<Result = ()> + Clone> Message for Publish<T> {
    type Result = ();
}
/// Actor that fans out messages of type `T` to every registered subscriber.
///
/// Interact with it through [`SubscribeExt`] / [`PublishExt`] on its address, or through
/// the registry-backed helpers [`Broker::subscribe`] and [`Broker::publish`].
pub struct Broker<T: Message<Result = ()> + Clone> {
    /// Id handed to the next subscriber; incremented by one per subscription.
    next_id: usize,
    /// Live subscriptions, keyed (and therefore iterated) in ascending id order.
    subscriptions: BTreeMap<SubscriptionId, Box<dyn MessageChannel<T>>>,
}
impl<T> Broker<T>
where
T: Message<Result = ()> + Clone,
{
pub fn new() -> Self {
Self {
next_id: 0,
subscriptions: BTreeMap::new(),
}
}
pub async fn subscribe<A: Actor + Handler<T>>(
subscriber: Address<A>,
) -> Result<Subscription<T>, ()> {
let broker = Self::from_registry().await;
subscriber.subscribe(broker).await
}
pub async fn publish(message: T) -> Result<(), xtra::Disconnected> {
let broker = Self::from_registry().await;
broker.publish(message).await
}
}
impl<T> Default for Broker<T>
where
T: Message<Result = ()> + Clone,
{
fn default() -> Self {
Self::new()
}
}
/// The broker overrides none of the `Actor` hooks; it only reacts to its request messages.
#[async_trait]
impl<T> Actor for Broker<T>
where
    T: Message<Result = ()> + Clone,
{
}
#[async_trait]
impl<T> Handler<Subscribe<T>> for Broker<T>
where
T: Message<Result = ()> + Clone,
{
async fn handle(
&mut self,
message: Subscribe<T>,
ctx: &mut Context<Self>,
) -> <Subscribe<T> as Message>::Result {
let broker_addr = ctx.address().map_err(|_| ())?;
let id = SubscriptionId(self.next_id);
self.next_id += 1;
self.subscriptions.insert(id, message.channel);
Ok(Subscription { id, broker_addr })
}
}
#[async_trait]
impl<T> Handler<Unsubscribe<T>> for Broker<T>
where
T: Message<Result = ()> + Clone,
{
async fn handle(
&mut self,
message: Unsubscribe<T>,
_ctx: &mut Context<Self>,
) -> <Unsubscribe<T> as Message>::Result {
match self.subscriptions.remove(&message.id) {
Some(_) => Ok(()),
None => Err(()),
}
}
}
#[async_trait]
impl<T> Handler<Publish<T>> for Broker<T>
where
    T: Message<Result = ()> + Clone,
{
    /// Sends a clone of the published message to every subscriber via the non-async
    /// `do_send`, visiting subscribers in ascending subscription-id order.
    ///
    /// Subscribers whose `do_send` fails (reports `xtra::Disconnected`) are removed from
    /// the subscription table, so later publishes do not try them again.
    async fn handle(&mut self, Publish(message): Publish<T>, _ctx: &mut Context<Self>) {
        // A single `retain` pass both delivers the message and prunes dead subscribers.
        // This replaces collecting their ids into a temporary `Vec` and removing them
        // one by one afterwards.
        self.subscriptions
            .retain(|_id, subscriber| subscriber.do_send(message.clone()).is_ok());
    }
}
/// Extension trait that lets an actor's [`Address`] subscribe itself to a [`Broker`].
#[async_trait]
pub trait SubscribeExt<M: Message<Result = ()> + Clone> {
    /// Registers this address with `broker` so it receives every `M` published there.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if the broker cannot be reached or the subscription fails.
    async fn subscribe(&self, broker: Address<Broker<M>>) -> Result<Subscription<M>, ()>;
}
#[async_trait]
impl<T, M> SubscribeExt<M> for Address<T>
where
    T: Handler<M>,
    M: Message<Result = ()> + Clone,
{
    /// Hands a clone of this address to `broker` as a subscriber channel.
    ///
    /// A disconnected broker is reported as `Err(())`. Otherwise the broker's own answer
    /// to the `Subscribe` request is returned unchanged.
    async fn subscribe(&self, broker: Address<Broker<M>>) -> Result<Subscription<M>, ()> {
        let outcome = broker
            .send(Subscribe {
                channel: Box::new(self.clone()),
            })
            .await;
        match outcome {
            Ok(broker_reply) => broker_reply,
            Err(_) => Err(()),
        }
    }
}
/// Extension trait for publishing messages through a broker's [`Address`].
#[async_trait]
pub trait PublishExt<M: Message<Result = ()> + Clone> {
    /// Publishes `message` to all current subscribers of the broker.
    ///
    /// # Errors
    ///
    /// Returns `xtra::Disconnected` if the broker cannot be reached.
    async fn publish(&self, message: M) -> Result<(), xtra::Disconnected>;
}
#[async_trait]
impl<M: Message<Result = ()> + Clone> PublishExt<M> for Address<Broker<M>> {
    /// Wraps `message` in a `Publish` request and awaits the broker's reply to it.
    async fn publish(&self, message: M) -> Result<(), xtra::Disconnected> {
        let request = Publish(message);
        self.send(request).await
    }
}
#[cfg(test)]
#[async_std::test]
async fn test_broker() {
    use xtra::spawn::AsyncStd;
    #[derive(Clone)]
    struct Msg {
        msg: String,
    }
    impl Message for Msg {
        type Result = ();
    }
    /// Drains and returns everything the `Collector` has received so far.
    struct RetrieveMessages;
    impl Message for RetrieveMessages {
        type Result = Vec<String>;
    }
    /// Sink actor that records every `Msg` forwarded to it by the subscribers.
    struct Collector {
        messages: Vec<String>,
    }
    impl Collector {
        fn new() -> Self {
            Self { messages: vec![] }
        }
    }
    impl Actor for Collector {}
    struct SubscriberA {
        collector: Address<Collector>,
    }
    impl Actor for SubscriberA {}
    struct SubscriberB {
        collector: Address<Collector>,
    }
    impl Actor for SubscriberB {}
    #[async_trait]
    impl Handler<Msg> for Collector {
        async fn handle(&mut self, Msg { msg }: Msg, _ctx: &mut Context<Self>) {
            self.messages.push(msg);
        }
    }
    #[async_trait]
    impl Handler<RetrieveMessages> for Collector {
        async fn handle(
            &mut self,
            _: RetrieveMessages,
            _ctx: &mut Context<Self>,
        ) -> <RetrieveMessages as Message>::Result {
            let mut messages = vec![];
            std::mem::swap(&mut self.messages, &mut messages);
            messages
        }
    }
    #[async_trait]
    impl Handler<Msg> for SubscriberA {
        async fn handle(&mut self, msg: Msg, _ctx: &mut Context<Self>) {
            self.collector
                .do_send(Msg {
                    msg: format!("{} from SubscriberA", msg.msg),
                })
                .unwrap();
        }
    }
    #[async_trait]
    impl Handler<Msg> for SubscriberB {
        async fn handle(&mut self, msg: Msg, _ctx: &mut Context<Self>) {
            self.collector
                .do_send(Msg {
                    msg: format!("{} from SubscriberB", msg.msg),
                })
                .unwrap();
        }
    }
    let broker = Broker::<Msg>::new().create(None).spawn(&mut AsyncStd);
    let collector = Collector::new().create(None).spawn(&mut AsyncStd);
    let subscriber_a = SubscriberA {
        collector: collector.clone(),
    }
    .create(None)
    .spawn(&mut AsyncStd);
    let subscriber_b = SubscriberB {
        collector: collector.clone(),
    }
    .create(None)
    .spawn(&mut AsyncStd);
    assert!(collector.send(RetrieveMessages).await.unwrap().is_empty());
    let subscription_a = subscriber_a.subscribe(broker.clone()).await.unwrap();
    let subscription_b = subscriber_b.subscribe(broker.clone()).await.unwrap();
    broker
        .publish(Msg {
            msg: "1".to_string(),
        })
        .await
        .unwrap();
    // NOTE(review): the fixed sleeps give the subscribers time to forward to the
    // collector; they may be flaky on a heavily loaded machine.
    async_std::task::sleep(std::time::Duration::from_millis(100)).await;
    // The two subscribers forward independently, so their arrival order at the collector
    // is not fixed. Sort before comparing instead of branching on `msgs[0]`, which would
    // also panic on an empty result.
    let mut msgs = collector.send(RetrieveMessages).await.unwrap();
    msgs.sort();
    assert_eq!(
        msgs,
        vec![
            "1 from SubscriberA".to_string(),
            "1 from SubscriberB".to_string(),
        ]
    );
    subscription_b.unsubscribe().await.unwrap();
    broker
        .publish(Msg {
            msg: "2".to_string(),
        })
        .await
        .unwrap();
    async_std::task::sleep(std::time::Duration::from_millis(100)).await;
    assert_eq!(
        collector.send(RetrieveMessages).await.unwrap(),
        vec!["2 from SubscriberA".to_string()]
    );
    subscription_a.unsubscribe().await.unwrap();
    broker
        .publish(Msg {
            msg: "3".to_string(),
        })
        .await
        .unwrap();
    async_std::task::sleep(std::time::Duration::from_millis(100)).await;
    assert_eq!(
        collector.send(RetrieveMessages).await.unwrap(),
        Vec::<String>::new(),
    );
}
#[cfg(test)]
#[async_std::test]
async fn test_broker_using_registry() {
    #[derive(Clone)]
    struct Msg {
        msg: String,
    }
    impl Message for Msg {
        type Result = ();
    }

    /// Subscriber that ignores everything; it only proves the round trip works.
    #[derive(Default)]
    struct MyActor;
    impl Actor for MyActor {}

    #[async_trait]
    impl Handler<Msg> for MyActor {
        async fn handle(&mut self, _: Msg, _ctx: &mut Context<Self>) {}
    }

    // Use the registry-backed associated functions rather than an explicit broker address.
    let actor_addr = MyActor::from_registry().await;
    let handle = Broker::<Msg>::subscribe(actor_addr).await.unwrap();
    let outgoing = Msg {
        msg: String::from("123"),
    };
    Broker::<Msg>::publish(outgoing).await.unwrap();
    handle.unsubscribe().await.unwrap();
}