sisyphus_tasks/pipe.rs

1use tokio::sync::mpsc;
2
/// PipeError indicates that either the upstream or downstream channels have
/// closed
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipeError {
    /// Upstream senders have all closed
    InboundGone,
    /// Downstream receivers have all closed
    OutboundGone,
}

impl std::fmt::Display for PipeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Stable, user-facing description of each failure mode.
        match self {
            PipeError::InboundGone => write!(f, "all upstream senders have closed"),
            PipeError::OutboundGone => write!(f, "all downstream receivers have closed"),
        }
    }
}

impl std::error::Error for PipeError {}
11
/// A pipe that enforces process-once semantics.
///
/// This type ensures that the pipe's owner task sees a piece of data exactly
/// once, and that errors or panics in an owning task do not cause data loss
/// from the pipe.
///
/// Using a pipe has a number of advantages
/// - Data is not lost if errors occur during processing
/// - Data is not processed twice if errors occur during processing
/// - Flow control via backpressure is preserved
#[derive(Debug)]
pub struct Pipe<T>
where
    T: std::fmt::Debug + Send + Sync + 'static,
{
    // Inbound side: values arrive here from the upstream stage.
    rx: mpsc::Receiver<T>,
    // Outbound side: values are forwarded here to the downstream stage.
    tx: mpsc::Sender<T>,
    // The value currently owned by this pipe, if any. Keeping ownership in
    // the pipe (rather than in the processing task) is what preserves data
    // across task errors, panics, and cancellation.
    contents: Option<T>,
}
31
32impl<T> Pipe<T>
33where
34    T: std::fmt::Debug + Send + Sync + 'static,
35{
36    /// Instantiate a new pipe from an inbound receiver and outbound sender
37    pub fn new(rx: mpsc::Receiver<T>, tx: mpsc::Sender<T>, contents: Option<T>) -> Self {
38        Self { rx, tx, contents }
39    }
40
41    /// Creates a series of linked pipes, with an inbound channel, and an
42    /// outbound channel. This allows you to quickly instantiate an ordered
43    /// data-processing pipeline.
44    ///
45    /// Pipes should be pulled from the pipeline in order, and passed to
46    /// processing tasks.
47    ///
48    /// ## Note
49    ///
50    /// All pipes must polled in a loop. Otherwise data will not reach
51    /// downstream pipes.
52    ///
53    /// In order to avoid saturating the pipes, the `rx` must be read from.
54    /// Otherwise, processing will stop once the total capcity is reached
55    pub fn unterminated_pipeline(
56        length: usize,
57        total_capacity: Option<usize>,
58    ) -> (mpsc::Sender<T>, Vec<Pipe<T>>, mpsc::Receiver<T>) {
59        let total_capacity = total_capacity.unwrap_or(length * 20);
60        let buffer = std::cmp::max(total_capacity / length, 1);
61        let (tx, mut rx) = mpsc::channel::<T>(buffer);
62
63        let mut pipeline = Vec::with_capacity(length);
64
65        (0..length).for_each(|_| {
66            let (next_tx, mut next_rx) = mpsc::channel::<T>(buffer);
67            std::mem::swap(&mut rx, &mut next_rx);
68            pipeline.push(Pipe::new(next_rx, next_tx, None));
69        });
70
71        (tx, pipeline, rx)
72    }
73
74    /// Creates a series of linked pipes, with an inbound channel, and an
75    /// outbound channel. This allows you to quickly instantiate an ordered
76    /// data-processing pipeline.
77    ///
78    /// The outbound channel from this pipeline terminates in a simple task
79    /// that drops messages as soon as they reach it. This ensures that the
80    /// pipeline does not saturate. This task will end as soon as its upstream
81    /// pipes close.
82    ///
83    /// Pipes should be pulled from the pipeline in order, and passed to
84    /// processing tasks.
85    ///
86    /// ## Note
87    ///
88    /// All pipes must polled in a loop. Otherwise data will not reach
89    /// downstream pipes.
90    pub fn pipeline(
91        length: usize,
92        total_capacity: Option<usize>,
93    ) -> (mpsc::Sender<T>, Vec<Pipe<T>>) {
94        let (tx, pipeline, mut rx) = Self::unterminated_pipeline(length, total_capacity);
95        tokio::spawn(async move {
96            loop {
97                if rx.recv().await.is_none() {
98                    break;
99                }
100            }
101        });
102        (tx, pipeline)
103    }
104
105    /// Take the contents of the pipe, if any. This prevents them from being
106    /// sent out of the pipe, and can be used to filter a value
107    pub fn take(&mut self) -> Option<T> {
108        self.contents.take()
109    }
110
111    /// Read the value in the pipe, if any, without advancing to the next
112    /// value. Typically tasks should use `next()` to wait for the next value.
113    /// Read may be used to inspect contents without a mutable ref
114    pub fn read(&self) -> Option<&T> {
115        self.contents.as_ref()
116    }
117
118    /// Get an owned copy of the contents, without advancing to the next value
119    pub fn to_owned(&self) -> Option<<T as ToOwned>::Owned>
120    where
121        T: ToOwned,
122    {
123        self.read().map(|c| c.to_owned())
124    }
125
126    /// Reserve channel capacity, and then move the contents into the channel
127    async fn send(&mut self) -> Result<(), PipeError> {
128        if self.contents.is_some() {
129            let permit = self
130                .tx
131                .reserve()
132                .await
133                .map_err(|_| PipeError::OutboundGone)?;
134            permit.send(self.contents.take().unwrap());
135        }
136        Ok(())
137    }
138
139    /// Release the current contents of the pipe and wait for the next value to
140    /// become available.
141    ///
142    /// # Cancel Safety
143    ///
144    /// Because the pipe owns the data, and the owner of the pipe only borrows
145    /// it, data is preserved through cancellation. At any await point in the
146    /// `next()` future, one of the following is true:
147    ///
148    /// - The pipe still owns the data.
149    /// - The pipe has sent the data onwards, and is currently empty.
150    ///
151    /// If another task takes ownership of the pipe and resumes work, the
152    /// message will not be seen again.
153    pub async fn next(&mut self) -> Result<&T, PipeError> {
154        self.send().await?;
155        let next = self.rx.recv().await.ok_or(PipeError::InboundGone)?;
156        self.contents = Some(next);
157        Ok(self.read().expect("checked"))
158    }
159
160    /// Converts a pipeline to a task that simply polls `next()`
161    pub fn nop(self) {
162        self.for_each(|_| {});
163    }
164
165    /// Run a synchronous function on each element
166    ///
167    /// This will run indefinitely, until an upstream or downstream channel
168    /// closes.
169    pub fn for_each<Func>(mut self, f: Func)
170    where
171        Func: Fn(&T) + Send + 'static,
172    {
173        tokio::spawn(async move {
174            while let Ok(contents) = self.next().await {
175                f(contents);
176            }
177        });
178    }
179
180    /// Run an async function on each element. Does not distinguish between
181    /// success and failure of that function.
182    ///
183    /// This will run indefinitely, until an upstream or downstream channel
184    /// closes.
185    ///
186    /// # Note:
187    ///
188    /// Be careful not to bottleneck your pipeline :) Consider adding a timeout
189    /// wrapper to your async function
190    pub fn for_each_async<Func, Fut, Out>(mut self, f: Func)
191    where
192        Func: Fn(&T) -> Fut + Send + Sync + 'static,
193        Fut: std::future::Future<Output = Out> + Send + 'static,
194    {
195        tokio::spawn(async move {
196            while let Ok(contents) = self.next().await {
197                f(contents).await;
198            }
199        });
200    }
201}
202
203impl<T> Drop for Pipe<T>
204where
205    T: std::fmt::Debug + Send + Sync + 'static,
206{
207    fn drop(&mut self) {
208        // we attempt to empty the contents on drop, knowing that some
209        // downstream task may still be running.
210        if let Some(contents) = self.contents.take() {
211            let _ = self.tx.try_send(contents);
212        }
213    }
214}