streambed_patterns/
ask.rs

1//! This is an implementation of the ask pattern that sends a request to a [tokio::sync::mpsc::Sender<_>]
2//! and uses a [tokio::sync::oneshot] channel internally to convey the reply.
3
4use async_trait::async_trait;
5use std::error::Error;
6use std::fmt::{Debug, Display, Formatter};
7use tokio::sync::{mpsc, oneshot};
8
/// Failure modes of the ask pattern: either the request could not be delivered, or no
/// reply ever came back.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AskError {
    /// The request could not be sent on the request channel (for a tokio `mpsc::Sender`,
    /// this happens when the receiving half has been closed).
    SendError,
    /// The request was sent, but the reply channel produced an error instead of a value —
    /// e.g. the reply callback was dropped without ever being invoked.
    ReceiveError,
}
14
15impl Display for AskError {
16    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
17        match self {
18            AskError::SendError => write!(f, "SendError"),
19            AskError::ReceiveError => write!(f, "ReceiveError"),
20        }
21    }
22}
23
24impl Error for AskError {}
25
26#[async_trait]
27pub trait Ask<A> {
28    /// The ask pattern is a way to send a request and get a reply back.
29    async fn ask<F, R>(&self, f: F) -> Result<R, AskError>
30    where
31        F: FnOnce(Box<dyn FnOnce(R) + Send>) -> A + Send,
32        R: Send + 'static;
33}
34
35#[async_trait]
36impl<A> Ask<A> for mpsc::Sender<A>
37where
38    A: Send,
39{
40    async fn ask<F, R>(&self, f: F) -> Result<R, AskError>
41    where
42        F: FnOnce(Box<dyn FnOnce(R) + Send>) -> A + Send,
43        R: Send + 'static,
44    {
45        let (tx, rx) = oneshot::channel();
46        let reply_to = Box::new(|r| {
47            let _ = tx.send(r);
48        });
49
50        let request = f(reply_to);
51        self.send(request).await.map_err(|_| AskError::SendError)?;
52
53        rx.await.map_err(|_| AskError::ReceiveError)
54    }
55}