use crate::{sink::FanoutMany, BoxSink};
use futures::{
channel::{
mpsc::{self, Receiver, Sender},
oneshot,
},
ready,
stream::BoxStream,
Future, Sink, Stream,
};
use log::error;
use pin_project_lite::pin_project;
use selium_protocol::traits::{ShutdownSink, ShutdownStream};
use selium_std::errors::Result;
use std::{
fmt::Debug,
pin::Pin,
task::{Context, Poll},
};
use tokio_stream::StreamMap;
/// Buffer size passed to `mpsc::channel` when creating the channel that
/// carries new [`Socket`]s into a [`Topic`] (see `Topic::pair`).
const SOCK_CHANNEL_SIZE: usize = 100;

/// Receiving half of a oneshot channel carrying `()`.
///
/// NOTE(review): not referenced in this part of the file; presumably used to
/// signal that a topic has shut down — confirm against its call sites.
pub type TopicShutdown = oneshot::Receiver<()>;

/// A connection that can be attached to a [`Topic`] through the sender
/// returned by `Topic::pair`.
pub enum Socket<T, E> {
    /// An inbound stream; each `Ok` item it yields is forwarded into the
    /// topic's fanout sink, and each `Err` is logged and skipped.
    Stream(BoxStream<'static, Result<T>>),
    /// An outbound sink that is added to the topic's fanout sink set.
    Sink(BoxSink<T, E>),
}
pin_project! {
    #[project = TopicProj]
    /// Future that relays items between a dynamic set of streams and sinks.
    ///
    /// Every `Ok` item read from any registered stream is sent into the
    /// fanout sink `sink`. New streams and sinks are attached by sending
    /// [`Socket`]s through the channel whose receiving half is `handle`
    /// (created by `Topic::pair`). The future completes once that channel is
    /// closed, after flushing and shutting down every stream and sink.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct Topic<T, E> {
        // Registered inbound streams, keyed by the id assigned at insertion.
        #[pin]
        stream: StreamMap<usize, BoxStream<'static, Result<T>>>,
        // Key to assign to the next stream inserted into `stream`.
        next_stream_id: usize,
        // Fanout over all registered sinks, keyed by the id assigned at insertion.
        #[pin]
        sink: FanoutMany<usize, BoxSink<T, E>>,
        // Key to assign to the next sink inserted into `sink`.
        next_sink_id: usize,
        // Receiving half of the channel used to attach new streams and sinks.
        #[pin]
        handle: Receiver<Socket<T, E>>,
        // An item read from `stream` that has not yet been accepted by `sink`
        // (held while `sink.poll_ready` is pending).
        buffered_item: Option<T>,
    }
}
impl<T, E> Topic<T, E> {
    /// Creates an empty topic together with the sending half of its socket
    /// channel.
    ///
    /// Send [`Socket::Stream`] or [`Socket::Sink`] values through the returned
    /// [`Sender`] to attach them to the topic. The channel is created with a
    /// buffer size of `SOCK_CHANNEL_SIZE`. Once every sender has been dropped,
    /// the topic future flushes, shuts down its streams and sinks, and completes.
    pub fn pair() -> (Self, Sender<Socket<T, E>>) {
        let (socket_tx, socket_rx) = mpsc::channel(SOCK_CHANNEL_SIZE);

        let topic = Self {
            stream: StreamMap::new(),
            next_stream_id: 0,
            sink: FanoutMany::new(),
            next_sink_id: 0,
            handle: socket_rx,
            buffered_item: None,
        };

        (topic, socket_tx)
    }
}
impl<T, E> Future for Topic<T, E>
where
    E: Debug + Unpin,
    T: Clone + Unpin,
{
    type Output = ();

    /// Drives the topic.
    ///
    /// On each iteration this:
    /// 1. delivers any item buffered from a previous iteration to the fanout sink;
    /// 2. attaches any newly received [`Socket`] (stream or sink);
    /// 3. reads the next item from the registered streams.
    ///
    /// Whenever the future is about to return `Pending` for lack of new input,
    /// the sink is flushed first so that already-sent items are not left
    /// sitting in sink buffers.
    ///
    /// Resolves to `()` once the socket channel is closed (all senders
    /// dropped), after flushing the sink and calling `shutdown_stream` /
    /// `shutdown_sink` on every registered stream and sink.
    ///
    /// # Panics
    ///
    /// Panics if the fanout sink returns an error from `poll_ready`,
    /// `start_send` or `poll_flush`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let TopicProj {
            mut stream,
            next_stream_id,
            mut sink,
            next_sink_id,
            mut handle,
            buffered_item,
        } = self.project();

        loop {
            // Deliver the pending item before reading another one, so at most
            // one item is ever held here. If the sink is not ready we return
            // `Pending` and keep the item for the next poll.
            if buffered_item.is_some() {
                ready!(sink.as_mut().poll_ready(cx)).unwrap();
                let item = buffered_item
                    .take()
                    .expect("buffered_item was checked to be Some above");
                sink.as_mut().start_send(item).unwrap();
            }

            // From here on `buffered_item` is always `None`: it was either empty
            // or has just been handed to the sink.
            match handle.as_mut().poll_next(cx) {
                Poll::Ready(Some(Socket::Stream(st))) => {
                    stream.as_mut().insert(*next_stream_id, st);
                    *next_stream_id += 1;
                }
                Poll::Ready(Some(Socket::Sink(si))) => {
                    sink.as_mut().insert(*next_sink_id, si);
                    *next_sink_id += 1;
                }
                Poll::Ready(None) => {
                    // Every sender is gone: flush, shut everything down, finish.
                    ready!(sink.as_mut().poll_flush(cx)).unwrap();
                    stream.iter_mut().for_each(|(_, s)| s.shutdown_stream());
                    sink.iter_mut().for_each(|(_, s)| s.shutdown_sink());
                    return Poll::Ready(());
                }
                Poll::Pending if stream.is_empty() => {
                    // Nothing to read. Keep driving the flush anyway: a flush
                    // left pending by an earlier poll (e.g. after the last
                    // stream ended) wakes this task, and skipping it here would
                    // strand buffered data in the sinks until some unrelated
                    // event arrived.
                    ready!(sink.as_mut().poll_flush(cx)).unwrap();
                    return Poll::Pending;
                }
                Poll::Pending => (),
            }

            match stream.as_mut().poll_next(cx) {
                Poll::Ready(Some((_, Ok(item)))) => *buffered_item = Some(item),
                Poll::Ready(Some((_, Err(e)))) => {
                    error!("Received invalid message from stream: {e:?}")
                }
                // All streams have ended; push out whatever the sinks hold.
                Poll::Ready(None) => ready!(sink.as_mut().poll_flush(cx)).unwrap(),
                Poll::Pending => {
                    ready!(sink.as_mut().poll_flush(cx)).unwrap();
                    return Poll::Pending;
                }
            }
        }
    }
}