Skip to main content

workflow_core/
channel.rs

1//! [`async-channel`](async_channel) re-exports and shims
2use crate::id::Id;
3pub use async_channel::{
4    Receiver, RecvError, SendError, Sender, TryRecvError, TrySendError, bounded, unbounded,
5};
6use std::{
7    collections::HashMap,
8    marker::PhantomData,
9    sync::{Arc, Mutex},
10};
11use thiserror::Error;
12
13/// Errors produced by channel operations in this module.
14#[derive(Error, Debug)]
15pub enum ChannelError<T> {
16    /// The underlying channel send operation failed (the channel is closed).
17    #[error(transparent)]
18    SendError(#[from] SendError<T>),
19    /// The underlying channel receive operation failed (the channel is closed and empty).
20    #[error(transparent)]
21    RecvError(#[from] RecvError),
22    /// Serialization or deserialization to/from a JavaScript value failed.
23    #[error(transparent)]
24    SerdeWasmBindgen(#[from] serde_wasm_bindgen::Error),
25    /// A non-blocking `try_send()` failed while broadcasting through a [`Multiplexer`].
26    #[error("try_send() error during multiplexer broadcast")]
27    BroadcastTrySendError,
28}
29
30/// Creates a oneshot channel (bounded channel with a limit of 1 message)
31pub fn oneshot<T>() -> (Sender<T>, Receiver<T>) {
32    bounded(1)
33}
34
35/// [`DuplexChannel`] contains 2 channels `request` and `response`
36/// meant to provide for a request/response pattern. This is useful
37/// for any type of signaling, but especially during task termination,
38/// where you can request a task to terminate and wait for a response
39/// confirming its termination.
40#[derive(Debug, Clone)]
41pub struct DuplexChannel<T = (), R = ()> {
42    /// Channel carrying request messages of type `T`.
43    pub request: Channel<T>,
44    /// Channel carrying response messages of type `R`.
45    pub response: Channel<R>,
46}
47
48impl<T, R> DuplexChannel<T, R> {
49    /// Creates a duplex channel whose `request` and `response` channels are both unbounded.
50    pub fn unbounded() -> Self {
51        Self {
52            request: Channel::unbounded(),
53            response: Channel::unbounded(),
54        }
55    }
56
57    /// Creates a duplex channel whose `request` and `response` channels are both oneshot.
58    pub fn oneshot() -> Self {
59        Self {
60            request: Channel::oneshot(),
61            response: Channel::oneshot(),
62        }
63    }
64
65    /// Sends `msg` on the request channel and waits for the matching response,
66    /// returning the received response value.
67    pub async fn signal(&self, msg: T) -> std::result::Result<R, ChannelError<T>> {
68        self.request.sender.send(msg).await?;
69        self.response
70            .receiver
71            .recv()
72            .await
73            .map_err(|err| err.into())
74    }
75}
76
77/// [`Channel`] struct that combines [[`async_channel::Sender`]] and
78/// [[`async_channel::Receiver`]] into a single struct with `sender`
79/// and `receiver` members representing a single channel.
80#[derive(Debug, Clone)]
81pub struct Channel<T = ()> {
82    /// Sender endpoint of the channel.
83    pub sender: Sender<T>,
84    /// Receiver endpoint of the channel.
85    pub receiver: Receiver<T>,
86}
87
88impl<T> Channel<T> {
89    /// Creates a channel with an unbounded message buffer.
90    pub fn unbounded() -> Self {
91        let (sender, receiver) = unbounded();
92        Self { sender, receiver }
93    }
94
95    /// Creates a channel with a buffer bounded to `cap` messages.
96    pub fn bounded(cap: usize) -> Self {
97        let (sender, receiver) = bounded(cap);
98        Self { sender, receiver }
99    }
100
101    /// Creates a oneshot channel (a bounded channel with a capacity of one message).
102    pub fn oneshot() -> Self {
103        let (sender, receiver) = bounded(1);
104        Self { sender, receiver }
105    }
106
107    /// Discards all currently-buffered messages from the channel.
108    pub fn drain(&self) -> std::result::Result<(), TryRecvError> {
109        while !self.receiver.is_empty() {
110            self.receiver.try_recv()?;
111        }
112        Ok(())
113    }
114
115    /// Receives a message, waiting asynchronously until one is available.
116    pub async fn recv(&self) -> Result<T, RecvError> {
117        self.receiver.recv().await
118    }
119
120    /// Attempts to receive a message without blocking, returning an error if the channel is empty.
121    pub fn try_recv(&self) -> Result<T, TryRecvError> {
122        self.receiver.try_recv()
123    }
124
125    /// Sends a message, waiting asynchronously if the channel is bounded and full.
126    pub async fn send(&self, msg: T) -> Result<(), SendError<T>> {
127        self.sender.send(msg).await
128    }
129
130    /// Attempts to send a message without blocking, returning an error if the channel is full or closed.
131    pub fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
132        self.sender.try_send(msg)
133    }
134
135    /// Returns the number of messages currently buffered in the channel.
136    pub fn len(&self) -> usize {
137        self.receiver.len()
138    }
139
140    /// Returns `true` if there are no messages currently buffered in the channel.
141    pub fn is_empty(&self) -> bool {
142        self.receiver.is_empty()
143    }
144
145    /// Returns the number of [`Receiver`] endpoints currently connected to the channel.
146    pub fn receiver_count(&self) -> usize {
147        self.sender.receiver_count()
148    }
149
150    /// Returns the number of [`Sender`] endpoints currently connected to the channel.
151    pub fn sender_count(&self) -> usize {
152        self.sender.sender_count()
153    }
154
155    /// Returns an iterator that drains all currently-buffered messages from the channel.
156    pub fn iter(&self) -> ChannelIterator<T> {
157        ChannelIterator::new(self.receiver.clone())
158    }
159}
160
161/// Iterator that drains all currently-buffered messages from a [`Channel`]'s
162/// receiver, yielding `None` once the channel is momentarily empty.
163pub struct ChannelIterator<T> {
164    receiver: Receiver<T>,
165}
166
167impl<T> ChannelIterator<T> {
168    /// Create a new iterator that drains messages from the given receiver.
169    pub fn new(receiver: Receiver<T>) -> Self {
170        ChannelIterator { receiver }
171    }
172}
173
174impl<T> Iterator for ChannelIterator<T> {
175    type Item = T;
176    fn next(&mut self) -> Option<T> {
177        if self.receiver.is_empty() {
178            None
179        } else {
180            self.receiver.try_recv().ok()
181        }
182    }
183}
184
185/// A simple MPMC (one to many) channel Multiplexer that broadcasts to
186/// multiple registered receivers.  [`Multiplexer<T>`] itself can be
187/// cloned and used to broadcast using [`Multiplexer::broadcast()`]
188/// or [`Multiplexer::try_broadcast()`].  To create a receiving channel,
189/// you can call [`MultiplexerChannel<T>::from()`] and supply the
190/// desired Multiplexer instance, or  simply call [`Multiplexer::channel()`]
191/// to create a new [`MultiplexerChannel`] instance.  The receiving channel
192/// gets unregistered when [`MultiplexerChannel`] is dropped or the
193/// underlying [`Receiver`] is closed.
194#[derive(Clone)]
195pub struct Multiplexer<T>
196where
197    T: Clone + Send + Sync + 'static,
198{
199    /// Map of currently registered receiver channels keyed by their unique id.
200    pub channels: Arc<Mutex<HashMap<Id, Arc<Sender<T>>>>>,
201    t: PhantomData<T>,
202}
203
204impl<T> Default for Multiplexer<T>
205where
206    T: Clone + Send + Sync + 'static,
207{
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213impl<T> Multiplexer<T>
214where
215    T: Clone + Send + Sync + 'static,
216{
217    /// Create a new Multiplexer instance
218    pub fn new() -> Multiplexer<T> {
219        Multiplexer {
220            channels: Arc::new(Mutex::new(HashMap::default())),
221            t: PhantomData,
222        }
223    }
224
225    /// Create a new multiplexer receiving channel
226    pub fn channel(&self) -> MultiplexerChannel<T> {
227        MultiplexerChannel::from(self)
228    }
229
230    fn register_event_channel(&self) -> (Id, Sender<T>, Receiver<T>) {
231        let (sender, receiver) = unbounded();
232        let id = Id::new();
233        self.channels
234            .lock()
235            .unwrap()
236            .insert(id, Arc::new(sender.clone()));
237        (id, sender, receiver)
238    }
239
240    fn unregister_event_channel(&self, id: Id) {
241        self.channels.lock().unwrap().remove(&id);
242    }
243
244    /// Async [`Multiplexer::broadcast`] function that calls [`Sender::send()`] on all registered [`MultiplexerChannel`] instances.
245    pub async fn broadcast(&self, event: T) -> Result<(), ChannelError<T>> {
246        let mut removed = vec![];
247        let channels = self
248            .channels
249            .lock()
250            .unwrap()
251            .iter()
252            .map(|(k, v)| (*k, v.clone()))
253            .collect::<Vec<_>>();
254        for (id, sender) in channels.iter() {
255            match sender.send(event.clone()).await {
256                Ok(_) => {}
257                Err(_err) => {
258                    removed.push(*id);
259                }
260            }
261        }
262        if !removed.is_empty() {
263            let mut channels = self.channels.lock().unwrap();
264            for id in removed.iter() {
265                channels.remove(id);
266            }
267        }
268
269        Ok(())
270    }
271
272    /// A synchronous [`Multiplexer::try_broadcast`] function that calls [`Sender::try_send()`] on all registered [`MultiplexerChannel`] instances.
273    /// This function holds a mutex for the duration of the broadcast.
274    pub fn try_broadcast(&self, event: T) -> Result<(), ChannelError<T>> {
275        let mut removed = vec![];
276        let mut channels = self.channels.lock().unwrap();
277        for (id, sender) in channels.iter() {
278            match sender.try_send(event.clone()) {
279                Ok(_) => {}
280                Err(_err) => {
281                    removed.push(*id);
282                }
283            }
284        }
285        if !removed.is_empty() {
286            for id in removed.iter() {
287                channels.remove(id);
288            }
289        }
290
291        Ok(())
292    }
293}
294
295/// Receiving channel endpoint for the [`Multiplexer`].  [`MultiplexerChannel<T>`] holds a [`Sender`] and the [`Receiver`] channel endpoints.
296/// The [`Sender`] is provided for convenience, allowing internal relay within this channel instance.
297/// To process events, simply iterate over [`MultiplexerChannel::recv()`] by calling `channel.recv().await`.
298#[derive(Clone)]
299pub struct MultiplexerChannel<T>
300where
301    T: Clone + Send + Sync + 'static,
302{
303    multiplexer: Multiplexer<T>,
304    /// Unique id identifying this channel within the parent [`Multiplexer`].
305    pub id: Id,
306    /// Sender endpoint, provided for convenient internal relay within this channel.
307    pub sender: Sender<T>,
308    /// Receiver endpoint from which broadcast events are consumed.
309    pub receiver: Receiver<T>,
310}
311
312impl<T> MultiplexerChannel<T>
313where
314    T: Clone + Send + Sync + 'static,
315{
316    /// Close the receiving channel.  This will unregister the channel from the [`Multiplexer`].
317    pub fn close(&self) {
318        self.multiplexer.unregister_event_channel(self.id);
319    }
320
321    /// Receive an event from the channel.  This is a blocking async call.
322    pub async fn recv(&self) -> Result<T, RecvError> {
323        self.receiver.recv().await
324    }
325
326    /// Receive an event from the channel.  This is a non-blocking sync call that
327    /// follows [`Receiver::try_recv`] semantics.
328    pub fn try_recv(&self) -> Result<T, TryRecvError> {
329        self.receiver.try_recv()
330    }
331}
332
333/// Create a [`MultiplexerChannel`] from [`Multiplexer`] by reference.
334impl<T> From<&Multiplexer<T>> for MultiplexerChannel<T>
335where
336    T: Clone + Send + Sync + 'static,
337{
338    fn from(multiplexer: &Multiplexer<T>) -> Self {
339        let (id, sender, receiver) = multiplexer.register_event_channel();
340        MultiplexerChannel {
341            multiplexer: multiplexer.clone(),
342            id,
343            sender,
344            receiver,
345        }
346    }
347}
348
349impl<T> Drop for MultiplexerChannel<T>
350where
351    T: Clone + Send + Sync + 'static,
352{
353    fn drop(&mut self) {
354        self.multiplexer.unregister_event_channel(self.id);
355    }
356}