streambed_patterns/
ask.rs1use async_trait::async_trait;
5use std::error::Error;
6use std::fmt::{Debug, Display, Formatter};
7use tokio::sync::{mpsc, oneshot};
8
/// Ways in which an [`Ask::ask`] round trip can fail.
///
/// The variants carry no payload, so the type is cheap to copy and can be compared
/// directly (e.g. `assert_eq!(result, Err(AskError::SendError))`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AskError {
    /// The request message could not be sent on the channel.
    SendError,
    /// The request was sent, but no reply was received for it.
    ReceiveError,
}

impl Display for AskError {
    /// Writes the variant name (`"SendError"` or `"ReceiveError"`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            AskError::SendError => "SendError",
            AskError::ReceiveError => "ReceiveError",
        };
        f.write_str(name)
    }
}

impl Error for AskError {}
25
26#[async_trait]
27pub trait Ask<A> {
28 async fn ask<F, R>(&self, f: F) -> Result<R, AskError>
30 where
31 F: FnOnce(Box<dyn FnOnce(R) + Send>) -> A + Send,
32 R: Send + 'static;
33}
34
35#[async_trait]
36impl<A> Ask<A> for mpsc::Sender<A>
37where
38 A: Send,
39{
40 async fn ask<F, R>(&self, f: F) -> Result<R, AskError>
41 where
42 F: FnOnce(Box<dyn FnOnce(R) + Send>) -> A + Send,
43 R: Send + 'static,
44 {
45 let (tx, rx) = oneshot::channel();
46 let reply_to = Box::new(|r| {
47 let _ = tx.send(r);
48 });
49
50 let request = f(reply_to);
51 self.send(request).await.map_err(|_| AskError::SendError)?;
52
53 rx.await.map_err(|_| AskError::ReceiveError)
54 }
55}