workflow_core/
channel.rs

//! Re-exports of [`async_channel`] primitives plus channel shims built on top of them.
2use crate::id::Id;
3pub use async_channel::{
4    bounded, unbounded, Receiver, RecvError, SendError, Sender, TryRecvError, TrySendError,
5};
6use std::{
7    collections::HashMap,
8    marker::PhantomData,
9    sync::{Arc, Mutex},
10};
11use thiserror::Error;
12
13#[derive(Error, Debug)]
14pub enum ChannelError<T> {
15    #[error(transparent)]
16    SendError(#[from] SendError<T>),
17    #[error(transparent)]
18    RecvError(#[from] RecvError),
19    #[error(transparent)]
20    SerdeWasmBindgen(#[from] serde_wasm_bindgen::Error),
21    #[error("try_send() error during multiplexer broadcast")]
22    BroadcastTrySendError,
23}
24
25/// Creates a oneshot channel (bounded channel with a limit of 1 message)
26pub fn oneshot<T>() -> (Sender<T>, Receiver<T>) {
27    bounded(1)
28}
29
30/// [`DuplexChannel`] contains 2 channels `request` and `response`
31/// meant to provide for a request/response pattern. This is useful
32/// for any type of signaling, but especially during task termination,
33/// where you can request a task to terminate and wait for a response
34/// confirming its termination.
35#[derive(Debug, Clone)]
36pub struct DuplexChannel<T = (), R = ()> {
37    pub request: Channel<T>,
38    pub response: Channel<R>,
39}
40
41impl<T, R> DuplexChannel<T, R> {
42    pub fn unbounded() -> Self {
43        Self {
44            request: Channel::unbounded(),
45            response: Channel::unbounded(),
46        }
47    }
48
49    pub fn oneshot() -> Self {
50        Self {
51            request: Channel::oneshot(),
52            response: Channel::oneshot(),
53        }
54    }
55
56    pub async fn signal(&self, msg: T) -> std::result::Result<R, ChannelError<T>> {
57        self.request.sender.send(msg).await?;
58        self.response
59            .receiver
60            .recv()
61            .await
62            .map_err(|err| err.into())
63    }
64}
65
66/// [`Channel`] struct that combines [`async_std::channel::Sender`] and
67/// [`async_std::channel::Receiver`] into a single struct with `sender`
68/// and `receiver` members representing a single channel.
69#[derive(Debug, Clone)]
70pub struct Channel<T = ()> {
71    pub sender: Sender<T>,
72    pub receiver: Receiver<T>,
73}
74
75impl<T> Channel<T> {
76    pub fn unbounded() -> Self {
77        let (sender, receiver) = unbounded();
78        Self { sender, receiver }
79    }
80
81    pub fn bounded(cap: usize) -> Self {
82        let (sender, receiver) = bounded(cap);
83        Self { sender, receiver }
84    }
85
86    pub fn oneshot() -> Self {
87        let (sender, receiver) = bounded(1);
88        Self { sender, receiver }
89    }
90
91    pub fn drain(&self) -> std::result::Result<(), TryRecvError> {
92        while !self.receiver.is_empty() {
93            self.receiver.try_recv()?;
94        }
95        Ok(())
96    }
97
98    pub async fn recv(&self) -> Result<T, RecvError> {
99        self.receiver.recv().await
100    }
101
102    pub fn try_recv(&self) -> Result<T, TryRecvError> {
103        self.receiver.try_recv()
104    }
105
106    pub async fn send(&self, msg: T) -> Result<(), SendError<T>> {
107        self.sender.send(msg).await
108    }
109
110    pub fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
111        self.sender.try_send(msg)
112    }
113
114    pub fn len(&self) -> usize {
115        self.receiver.len()
116    }
117
118    pub fn is_empty(&self) -> bool {
119        self.receiver.is_empty()
120    }
121
122    pub fn receiver_count(&self) -> usize {
123        self.sender.receiver_count()
124    }
125
126    pub fn sender_count(&self) -> usize {
127        self.sender.sender_count()
128    }
129
130    pub fn iter(&self) -> ChannelIterator<T> {
131        ChannelIterator::new(self.receiver.clone())
132    }
133}
134
135pub struct ChannelIterator<T> {
136    receiver: Receiver<T>,
137}
138
139impl<T> ChannelIterator<T> {
140    pub fn new(receiver: Receiver<T>) -> Self {
141        ChannelIterator { receiver }
142    }
143}
144
145impl<T> Iterator for ChannelIterator<T> {
146    type Item = T;
147    fn next(&mut self) -> Option<T> {
148        if self.receiver.is_empty() {
149            None
150        } else {
151            self.receiver.try_recv().ok()
152        }
153    }
154}
155
156/// A simple MPMC (one to many) channel Multiplexer that broadcasts to
157/// multiple registered receivers.  [`Multiplexer<T>`] itself can be
158/// cloned and used to broadcast using [`Multiplexer::broadcast()`]
159/// or [`Multiplexer::try_broadcast()`].  To create a receiving channel,
160/// you can call [`MultiplexerChannel<T>::from()`] and supply the
161/// desired Multiplexer instance, or  simply call [`Multiplexer::channel()`]
162/// to create a new [`MultiplexerChannel`] instance.  The receiving channel
163/// gets unregistered when [`MultiplexerChannel`] is dropped or the
164/// underlying [`Receiver`] is closed.
165#[derive(Clone)]
166pub struct Multiplexer<T>
167where
168    T: Clone + Send + Sync + 'static,
169{
170    pub channels: Arc<Mutex<HashMap<Id, Arc<Sender<T>>>>>,
171    t: PhantomData<T>,
172}
173
174impl<T> Default for Multiplexer<T>
175where
176    T: Clone + Send + Sync + 'static,
177{
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
183impl<T> Multiplexer<T>
184where
185    T: Clone + Send + Sync + 'static,
186{
187    /// Create a new Multiplexer instance
188    pub fn new() -> Multiplexer<T> {
189        Multiplexer {
190            channels: Arc::new(Mutex::new(HashMap::default())),
191            t: PhantomData,
192        }
193    }
194
195    /// Create a new multiplexer receiving channel
196    pub fn channel(&self) -> MultiplexerChannel<T> {
197        MultiplexerChannel::from(self)
198    }
199
200    fn register_event_channel(&self) -> (Id, Sender<T>, Receiver<T>) {
201        let (sender, receiver) = unbounded();
202        let id = Id::new();
203        self.channels
204            .lock()
205            .unwrap()
206            .insert(id, Arc::new(sender.clone()));
207        (id, sender, receiver)
208    }
209
210    fn unregister_event_channel(&self, id: Id) {
211        self.channels.lock().unwrap().remove(&id);
212    }
213
214    /// Async [`Multiplexer::broadcast`] function that calls [`Sender::send()`] on all registered [`MultiplexerChannel`] instances.
215    pub async fn broadcast(&self, event: T) -> Result<(), ChannelError<T>> {
216        let mut removed = vec![];
217        let channels = self
218            .channels
219            .lock()
220            .unwrap()
221            .iter()
222            .map(|(k, v)| (*k, v.clone()))
223            .collect::<Vec<_>>();
224        for (id, sender) in channels.iter() {
225            match sender.send(event.clone()).await {
226                Ok(_) => {}
227                Err(_err) => {
228                    removed.push(*id);
229                }
230            }
231        }
232        if !removed.is_empty() {
233            let mut channels = self.channels.lock().unwrap();
234            for id in removed.iter() {
235                channels.remove(id);
236            }
237        }
238
239        Ok(())
240    }
241
242    /// A synchronous [`Multiplexer::try_broadcast`] function that calls [`Sender::try_send()`] on all registered [`MultiplexerChannel`] instances.
243    /// This function holds a mutex for the duration of the broadcast.
244    pub fn try_broadcast(&self, event: T) -> Result<(), ChannelError<T>> {
245        let mut removed = vec![];
246        let mut channels = self.channels.lock().unwrap();
247        for (id, sender) in channels.iter() {
248            match sender.try_send(event.clone()) {
249                Ok(_) => {}
250                Err(_err) => {
251                    removed.push(*id);
252                }
253            }
254        }
255        if !removed.is_empty() {
256            for id in removed.iter() {
257                channels.remove(id);
258            }
259        }
260
261        Ok(())
262    }
263}
264
265/// Receiving channel endpoint for the [`Multiplexer`].  [`MultiplexerChannel<T>`] holds a [`Sender`] and the [`Receiver`] channel endpoints.
266/// The [`Sender`] is provided for convenience, allowing internal relay within this channel instance.
267/// To process events, simply iterate over [`MultiplexerChannel::recv()`] by calling `channel.recv().await`.
268#[derive(Clone)]
269pub struct MultiplexerChannel<T>
270where
271    T: Clone + Send + Sync + 'static,
272{
273    multiplexer: Multiplexer<T>,
274    pub id: Id,
275    pub sender: Sender<T>,
276    pub receiver: Receiver<T>,
277}
278
279impl<T> MultiplexerChannel<T>
280where
281    T: Clone + Send + Sync + 'static,
282{
283    /// Close the receiving channel.  This will unregister the channel from the [`Multiplexer`].
284    pub fn close(&self) {
285        self.multiplexer.unregister_event_channel(self.id);
286    }
287
288    /// Receive an event from the channel.  This is a blocking async call.
289    pub async fn recv(&self) -> Result<T, RecvError> {
290        self.receiver.recv().await
291    }
292
293    /// Receive an event from the channel.  This is a non-blocking sync call that
294    /// follows [`Receiver::try_recv`] semantics.
295    pub fn try_recv(&self) -> Result<T, TryRecvError> {
296        self.receiver.try_recv()
297    }
298}
299
300/// Create a [`MultiplexerChannel`] from [`Multiplexer`] by reference.
301impl<T> From<&Multiplexer<T>> for MultiplexerChannel<T>
302where
303    T: Clone + Send + Sync + 'static,
304{
305    fn from(multiplexer: &Multiplexer<T>) -> Self {
306        let (id, sender, receiver) = multiplexer.register_event_channel();
307        MultiplexerChannel {
308            multiplexer: multiplexer.clone(),
309            id,
310            sender,
311            receiver,
312        }
313    }
314}
315
316impl<T> Drop for MultiplexerChannel<T>
317where
318    T: Clone + Send + Sync + 'static,
319{
320    fn drop(&mut self) {
321        self.multiplexer.unregister_event_channel(self.id);
322    }
323}