use derive_error::Error;
use futures::{
channel::{
mpsc::{self, SendError},
oneshot,
},
ready,
stream::FusedStream,
task::Context,
Future,
FutureExt,
Stream,
StreamExt,
};
use std::{pin::Pin, task::Poll};
use tower_service::Service;
/// Creates a connected request/response pair.
///
/// Requests sent through the returned [`SenderService`] travel over an unbounded `mpsc`
/// channel to the returned [`Receiver`]. Each request carries its own `oneshot` sender, which
/// the receiving side uses to deliver the response.
pub fn unbounded<TReq, TResp>() -> (SenderService<TReq, TResp>, Receiver<TReq, TResp>) {
    let (request_tx, request_rx) = mpsc::unbounded();
    let sender = SenderService::new(request_tx);
    let receiver = Receiver::new(request_rx);
    (sender, receiver)
}
/// Receiving half of a request channel.
///
/// Each item pairs a request with the `oneshot` sender on which its response is returned.
pub type Rx<Request, Response> = mpsc::UnboundedReceiver<(Request, oneshot::Sender<Response>)>;
/// Sending half of a request channel; the counterpart of [`Rx`].
pub type Tx<Request, Response> = mpsc::UnboundedSender<(Request, oneshot::Sender<Response>)>;
/// A [`Service`] that forwards each request over an unbounded channel and resolves to the
/// response sent back by whoever holds the matching receiving half.
pub struct SenderService<TReq, TRes> {
    tx: Tx<TReq, TRes>,
}

impl<TReq, TRes> SenderService<TReq, TRes> {
    /// Wraps the sending half of a request channel.
    pub fn new(tx: Tx<TReq, TRes>) -> Self {
        SenderService { tx }
    }
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would require `TReq: Clone` and
// `TRes: Clone`, although only the channel sender actually needs to be cloned.
impl<TReq, TRes> Clone for SenderService<TReq, TRes> {
    fn clone(&self) -> Self {
        let tx = Clone::clone(&self.tx);
        SenderService { tx }
    }
}
impl<TReq, TRes> Service<TReq> for SenderService<TReq, TRes> {
    type Error = TransportChannelError;
    type Future = TransportResponseFuture<TRes>;
    type Response = TRes;

    /// Checks whether the request channel can accept another request.
    ///
    /// A disconnected channel is reported as [`TransportChannelError::ChannelClosed`]. The
    /// channel is unbounded, so any other send error is treated as impossible.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.tx.poll_ready(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(err)) if err.is_disconnected() => {
                Poll::Ready(Err(TransportChannelError::ChannelClosed))
            },
            Poll::Ready(Err(_)) => unreachable!("unbounded channels can never be full"),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Sends `request`, paired with a fresh reply channel, and returns a future for the reply.
    ///
    /// If the request cannot be enqueued, the returned future resolves to
    /// [`TransportChannelError::ChannelClosed`].
    fn call(&mut self, request: TReq) -> Self::Future {
        let (reply_tx, reply_rx) = oneshot::channel();
        match self.tx.unbounded_send((request, reply_tx)) {
            Ok(()) => TransportResponseFuture::new(reply_rx),
            Err(_) => TransportResponseFuture::closed(),
        }
    }
}
// Errors produced by `SenderService` and `TransportResponseFuture`.
//
// NOTE(review): plain `//` comments are used on purpose. `derive_error` appears to derive each
// variant's description from its doc comment, so adding `///` docs here could change the error
// text callers see. TODO confirm against the derive_error docs before converting.
#[derive(Debug, Error, Eq, PartialEq, Clone)]
pub enum TransportChannelError {
// Not constructed by the code in this section; presumably kept for callers or `From`
// conversions. Verify before removing.
SendError(SendError),
// The reply `oneshot` channel failed before a response arrived: `TransportResponseFuture::poll`
// maps any receive error to this variant.
Canceled,
// The request channel is closed. `poll_ready` returns this when the sender is disconnected, and a
// `TransportResponseFuture` created with `closed()` resolves to it.
ChannelClosed,
}
/// Future returned by `SenderService::call` that resolves to the reply for a single request.
pub struct TransportResponseFuture<T> {
    // `None` when the request could not be sent; polling then yields `ChannelClosed`.
    rx: Option<oneshot::Receiver<T>>,
}

impl<T> TransportResponseFuture<T> {
    /// Creates a future that waits for the reply to arrive on `rx`.
    pub fn new(rx: oneshot::Receiver<T>) -> Self {
        let rx = Some(rx);
        TransportResponseFuture { rx }
    }

    /// Creates a future with no reply channel.
    ///
    /// Polling it yields [`TransportChannelError::ChannelClosed`].
    pub fn closed() -> Self {
        TransportResponseFuture { rx: None }
    }
}
impl<T> Future for TransportResponseFuture<T> {
type Output = Result<T, TransportChannelError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.rx {
Some(ref mut rx) => rx.poll_unpin(cx).map_err(|_| TransportChannelError::Canceled),
None => Poll::Ready(Err(TransportChannelError::ChannelClosed)),
}
}
}
/// A received request, bundled with the `oneshot` sender used to answer it.
pub struct RequestContext<TReq, TResp> {
    reply_tx: oneshot::Sender<TResp>,
    // Becomes `None` once the request has been moved out with `take_request`.
    request: Option<TReq>,
}

impl<TReq, TResp> RequestContext<TReq, TResp> {
    /// Bundles `request` with the sender its response should be delivered to.
    pub fn new(request: TReq, reply_tx: oneshot::Sender<TResp>) -> Self {
        let request = Some(request);
        RequestContext { reply_tx, request }
    }

    /// Borrows the request, or returns `None` if it has already been taken.
    pub fn request(&self) -> Option<&TReq> {
        self.request.as_ref()
    }

    /// Moves the request out and leaves `None` behind, so later calls return `None`.
    pub fn take_request(&mut self) -> Option<TReq> {
        std::mem::take(&mut self.request)
    }

    /// Consumes the context and returns the request together with its reply sender.
    ///
    /// # Panics
    ///
    /// Panics if the request was already removed with [`take_request`](Self::take_request).
    pub fn split(self) -> (TReq, oneshot::Sender<TResp>) {
        let RequestContext { reply_tx, request } = self;
        let request = request.expect("RequestContext must be initialized with a request");
        (request, reply_tx)
    }

    /// Sends `resp` back to the requester.
    ///
    /// # Errors
    ///
    /// Returns `Err(resp)` if the response could not be delivered, e.g. because the requester
    /// dropped its response future.
    pub fn reply(self, resp: TResp) -> Result<(), TResp> {
        let RequestContext { reply_tx, .. } = self;
        reply_tx.send(resp)
    }
}
/// Stream of incoming requests, each yielded as a [`RequestContext`].
pub struct Receiver<TReq, TResp> {
    rx: Rx<TReq, TResp>,
}

// Termination state is delegated to the wrapped channel receiver.
impl<TReq, TResp> FusedStream for Receiver<TReq, TResp> {
    fn is_terminated(&self) -> bool {
        let inner = &self.rx;
        inner.is_terminated()
    }
}

impl<TReq, TResp> Receiver<TReq, TResp> {
    /// Wraps the receiving half of a request channel.
    pub fn new(rx: Rx<TReq, TResp>) -> Self {
        Receiver { rx }
    }

    /// Closes the request channel from the receiving side.
    ///
    /// Senders then see the channel as disconnected, which `SenderService` reports as
    /// [`TransportChannelError::ChannelClosed`].
    pub fn close(&mut self) {
        let inner = &mut self.rx;
        inner.close();
    }
}
impl<TReq, TResp> Stream for Receiver<TReq, TResp> {
    type Item = RequestContext<TReq, TResp>;

    /// Yields the next queued request wrapped in a [`RequestContext`].
    ///
    /// Returns `None` when the underlying channel stream ends.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let next = ready!(self.rx.poll_next_unpin(cx));
        Poll::Ready(next.map(|(request, reply_tx)| RequestContext::new(request, reply_tx)))
    }
}
// Unit tests for the request/response channel. They cover:
// - how the response future resolves;
// - a raw `SenderService` + `Rx` round trip;
// - the closed, canceled and aborted paths of the `unbounded()` pair.
#[cfg(test)]
mod test {
use super::*;
use futures::{executor::block_on, future};
use std::fmt::Debug;
use tari_test_utils::unpack_enum;
use tower::ServiceExt;
// A reply already sent on the oneshot channel resolves the future to `Ok`.
#[test]
fn await_response_future_new() {
let (tx, rx) = oneshot::channel::<Result<(), ()>>();
tx.send(Ok(())).unwrap();
block_on(TransportResponseFuture::new(rx)).unwrap().unwrap();
}
// A future built with `closed()` resolves to `ChannelClosed`.
#[test]
fn await_response_future_closed() {
let err = block_on(TransportResponseFuture::<()>::closed()).unwrap_err();
unpack_enum!(TransportChannelError::ChannelClosed = err);
}
// Helper that plays the receiving side on a raw `Rx`: it takes the first request and answers
// it with `msg`. `TResp: Debug` is needed by `unwrap()` on `send`, whose error carries the
// undelivered message.
async fn reply<TReq, TResp>(mut rx: Rx<TReq, TResp>, msg: TResp)
where TResp: Debug {
match rx.next().await {
Some((_, tx)) => {
tx.send(msg).unwrap();
},
_ => panic!("Expected receiver to have something to receive"),
}
}
// Round trip through a `SenderService` built directly on an `mpsc` channel, with the call and
// the reply driven concurrently by `join`.
#[test]
fn requestor_call() {
let (tx, rx) = mpsc::unbounded();
let mut requestor = SenderService::<_, _>::new(tx);
block_on(requestor.ready()).unwrap();
let fut = future::join(requestor.call("PING"), reply(rx, "PONG"));
let msg = block_on(fut.map(|(r, _)| r.unwrap()));
assert_eq!(msg, "PONG");
}
// After the receiving stream is closed, a request sent through `ServiceExt::oneshot` fails
// with `ChannelClosed`.
#[test]
fn requestor_channel_closed() {
let (requestor, mut request_stream) = super::unbounded::<_, ()>();
request_stream.close();
let err = block_on(requestor.oneshot(())).unwrap_err();
unpack_enum!(TransportChannelError::ChannelClosed = err);
}
// The caller drops the response future without awaiting it, so the reply receiver is gone and
// `send` hands the message back as its error. The first block sends and drops the future within
// a single poll, so the reply receiver is already gone whenever the handler sees the request.
#[test]
fn request_response_request_abort() {
let (mut requestor, mut request_stream) = super::unbounded::<_, &str>();
block_on(future::join(
async move {
requestor.ready().await.unwrap();
let _ = requestor.call("PING");
},
async move {
let a = request_stream.next().await.unwrap();
let req = a.reply_tx.send("PONG").unwrap_err();
assert_eq!(req, "PONG");
},
));
}
// The handler drops the `RequestContext`, and with it the reply sender, without answering.
// The caller's future therefore resolves to `Canceled`.
#[test]
fn request_response_response_canceled() {
let (mut requestor, mut request_stream) = super::unbounded::<_, &str>();
block_on(future::join(
async move {
requestor.ready().await.unwrap();
let err = requestor.call("PING").await.unwrap_err();
assert_eq!(err, TransportChannelError::Canceled);
},
async move {
let req = request_stream.next().await.unwrap();
drop(req);
},
));
}
// Happy path: the handler answers through `RequestContext::reply`.
#[test]
fn request_response_success() {
let (mut requestor, mut request_stream) = super::unbounded::<_, &str>();
block_on(requestor.ready()).unwrap();
let (result, _) = block_on(future::join(requestor.call("PING"), async move {
let req = request_stream.next().await.unwrap();
req.reply("PONG").unwrap();
}));
assert_eq!(result.unwrap(), "PONG");
}
}