// slop_futures/pipeline.rs

1//! # Async Pipeline
2//!
3//! A flexible and efficient asynchronous task execution pipeline for Rust.
4//!
5//! This module provides a framework for building composable asynchronous pipelines
6//! that can process tasks through multiple stages with worker pools and capacity management.
7//!
8//! ## Features
9//!
10//! - **Worker Pools**: Manage pools of workers that can execute tasks concurrently
11//! - **Capacity Control**: Use semaphore-based permits to limit concurrent task execution
12//! - **Pipeline Composition**: Chain multiple pipelines together to create complex workflows
13//! - **Task Weighting**: Support for weighted tasks that consume multiple permits
14//! - **Error Handling**: Comprehensive error types for different failure scenarios
15//!
16//! ## Example
17//!
18//! ```ignore
19//! use std::sync::Arc;
20//! use tokio::sync::Semaphore;
21//!
22//! // Create workers and engine
23//! let workers = vec![MyWorker::new(); 4];
24//! let permits = Arc::new(Semaphore::new(10));
25//! let engine = AsyncEngine::new(workers, permits);
26//!
27//! // Submit a task
28//! let handle = engine.submit(my_task).await?;
29//! let result = handle.await?;
30//! ```
31
32use core::marker::PhantomData;
33use std::{
34    fmt,
35    future::Future,
36    pin::Pin,
37    sync::Arc,
38    task::{Context, Poll},
39};
40
41use tracing::Instrument;
42
43use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
44
45use thiserror::Error;
46
47use crate::queue::{self, AcquireWorkerError, WorkerQueue};
48
/// Marker trait for values that can be fed into a pipeline as task input.
///
/// An input must be thread-safe (`Send + Sync`) and free of borrowed data (`'static`).
/// A blanket implementation covers every type satisfying those bounds, so this trait
/// never needs to be implemented by hand.
pub trait TaskInput: Send + Sync + 'static {}

impl<T> TaskInput for T where T: Send + Sync + 'static {}
56
/// Error returned when a task submission fails.
///
/// Produced by [`Pipeline::submit`] and [`Pipeline::blocking_submit`] when the underlying
/// permit semaphore has been closed, i.e. the engine is no longer accepting new tasks.
#[derive(Error, Debug)]
#[error("Engine closed")]
pub struct SubmitError;
63
64/// Error returned when a non-blocking task submission fails.
65///
66/// This error can occur for two reasons:
67/// - The engine has been closed
68/// - No capacity is currently available (all permits are in use)
69#[derive(Error)]
70#[error("failed to submit task")]
71pub enum TrySubmitError<T> {
72    /// The engine has been closed and is no longer accepting tasks
73    #[error("engine closed")]
74    Closed,
75    /// No capacity is currently available for new tasks
76    #[error("no capacity available")]
77    NoCapacity(T),
78}
79
80impl<T> fmt::Debug for TrySubmitError<T> {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        write!(f, "TrySubmitError<{}>", std::any::type_name::<T>())
83    }
84}
85
/// Error returned when running a task end-to-end fails.
///
/// Produced by [`Pipeline::run`], which both submits the task and awaits its result;
/// either stage can fail independently.
#[derive(Error, Debug)]
pub enum RunError {
    /// The task could not be submitted (the engine has been closed).
    #[error("failed to submit task")]
    SubmitError(#[from] SubmitError),
    /// The task was submitted but failed during execution or while awaiting a worker.
    #[error("task execution failed")]
    TaskFailed(#[from] TaskJoinError),
}
96
/// Error that can occur when waiting for a task to complete.
///
/// This error type represents various failure modes that can occur
/// during task execution or when acquiring workers from the pool.
#[derive(Error, Debug)]
pub enum TaskJoinError {
    /// The spawned tokio task failed (e.g. it panicked or was aborted).
    #[error("execution error")]
    ExecutionError(#[from] tokio::task::JoinError),
    /// Failed to acquire a worker from the pool; see [`AcquireWorkerError`] for the cause.
    #[error("failed to acquire worker")]
    PopWorker(#[from] AcquireWorkerError),
}
110
/// A handle to a running task that can be awaited for its result.
///
/// This handle is returned when a task is submitted to the pipeline
/// and can be used to:
/// - Wait for the task to complete and retrieve its result
/// - Abort the task if it's no longer needed
///
/// Note: dropping the handle also aborts the task (see the `Drop` impl below), so a
/// `TaskHandle` must be awaited — or deliberately kept alive — for its task to finish.
///
/// # Example
///
/// ```ignore
/// let handle = engine.submit(my_task).await?;
///
/// // Option 1: Wait for completion
/// match handle.await {
///     Ok(result) => println!("Task completed: {:?}", result),
///     Err(e) => eprintln!("Task failed: {}", e),
/// }
///
/// // Option 2: Abort the task
/// handle.abort();
/// ```
pub struct TaskHandle<T> {
    // The spawned tokio task. Its output nests the task-level result (which may itself
    // carry a TaskJoinError) inside tokio's join result; `poll` flattens the two layers.
    inner: tokio::task::JoinHandle<Result<T, TaskJoinError>>,
}
135
impl<T> TaskHandle<T> {
    /// Aborts the task associated with this handle.
    ///
    /// This will cause the task to stop executing as soon as possible.
    /// Any work already completed by the task will be lost. Awaiting the handle after
    /// an abort yields an error (the cancellation surfaces through the join result).
    pub fn abort(&self) {
        self.inner.abort();
    }
}
145
/// Abort-on-drop: a `TaskHandle` dropped without being awaited cancels its task, so
/// abandoned submissions do not keep consuming workers in the background.
impl<T> Drop for TaskHandle<T> {
    fn drop(&mut self) {
        self.abort();
    }
}
151
152impl<T> Future for TaskHandle<T> {
153    type Output = Result<T, TaskJoinError>;
154
155    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
156        let pin = Pin::new(&mut self.inner);
157        pin.poll(cx).map(|res| res.map_err(TaskJoinError::from)).map(|res| match res {
158            Ok(Ok(output)) => Ok(output),
159            Ok(Err(error)) => Err(error),
160            Err(error) => Err(error),
161        })
162    }
163}
164
/// The handle type returned by submitting to a pipeline `P`: a [`PipelineHandle`]
/// parameterized over that pipeline's `Resource` and `Output` types.
pub type SubmitHandle<P> = PipelineHandle<<P as Pipeline>::Resource, <P as Pipeline>::Output>;

/// A handle to a task running inside a pipeline.
///
/// Wraps a [`TaskHandle`] whose result carries both the pipeline resource `R` (e.g. the
/// worker that executed the task) and the output `O`. Awaiting a `PipelineHandle` yields
/// only the output; the resource half is discarded (dropped) internally.
pub struct PipelineHandle<R, O> {
    handle: TaskHandle<(R, O)>,
}

impl<R, O> PipelineHandle<R, O> {
    /// Wraps a raw task handle that produces a `(resource, output)` pair.
    pub fn new(handle: TaskHandle<(R, O)>) -> Self {
        Self { handle }
    }

    /// Aborts the underlying task (see [`TaskHandle::abort`]).
    pub fn abort(&self) {
        self.handle.abort();
    }

    /// Unwraps into the inner [`TaskHandle`], exposing the `(resource, output)` pair.
    /// Used by [`Chain`] to hold the first stage's resource until hand-off completes.
    fn into_inner(self) -> TaskHandle<(R, O)> {
        self.handle
    }
}
184
185impl<R, O> Future for PipelineHandle<R, O> {
186    type Output = Result<O, TaskJoinError>;
187
188    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189        let pin = Pin::new(&mut self.handle);
190        pin.poll(cx).map(|res| res.map(|(_, output)| output))
191    }
192}
193
194/// A trait representing an asynchronous processing pipeline.
195///
196/// Pipelines accept input tasks and produce output results asynchronously.
197/// They can be composed together to create complex processing workflows.
198///
199/// # Type Parameters
200///
201/// - `Input`: The type of input tasks the pipeline accepts
202/// - `Output`: The type of results the pipeline produces
203///
204/// # Required Methods
205///
206/// - `submit`: Asynchronously submit a task, waiting if necessary for capacity
207/// - `try_submit`: Try to submit a task without waiting
208pub trait Pipeline: 'static + Send + Sync {
209    /// The input type that this pipeline accepts
210    type Input: 'static + Send + Sync;
211    /// The output type that this pipeline produces
212    type Output: 'static + Send + Sync;
213    /// The resource type that this pipeline uses
214    type Resource: 'static + Send + Sync;
215
216    /// Submit a task to the pipeline, waiting if necessary for capacity.
217    ///
218    /// This method will wait until there is capacity available in the pipeline before submitting
219    /// the task.
220    fn submit(
221        &self,
222        input: Self::Input,
223    ) -> impl Future<Output = Result<SubmitHandle<Self>, SubmitError>> + Send;
224
225    /// Try to submit a task without waiting.
226    ///
227    /// This method returns immediately with an error if there is no capacity
228    /// available in the pipeline.
229    fn try_submit(
230        &self,
231        input: Self::Input,
232    ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>>;
233
234    /// Run the pipeline on an input task and wait for the output.
235    ///     
236    /// This method will submit the task to the pipeline and wait for the output.
237    fn run(
238        &self,
239        input: Self::Input,
240    ) -> impl Future<Output = Result<Self::Output, RunError>> + Send {
241        async move {
242            let handle = self.submit(input).await?;
243            let output = handle.await.map_err(RunError::from)?;
244            Ok(output)
245        }
246    }
247
248    fn blocking_submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
249        let mut last_input = input;
250        loop {
251            match self.try_submit(last_input) {
252                Ok(handle) => {
253                    return Ok(handle);
254                }
255                Err(TrySubmitError::NoCapacity(input)) => {
256                    last_input = input;
257                    std::hint::spin_loop();
258                }
259                Err(TrySubmitError::Closed) => {
260                    return Err(SubmitError);
261                }
262            }
263        }
264    }
265}
266
/// A trait for workers that can process tasks asynchronously.
///
/// Workers are the units of execution in the pipeline. They receive
/// input tasks and produce output results asynchronously. The returned
/// future must be `Send`, since workers run inside spawned tokio tasks.
///
/// # Example
///
/// ```ignore
/// #[derive(Debug, Clone)]
/// struct MyWorker {
///     config: WorkerConfig,
/// }
///
/// impl AsyncWorker<MyTask, MyResult> for MyWorker {
///     async fn call(&self, input: MyTask) -> MyResult {
///         // Process the task...
///         MyResult::new()
///     }
/// }
/// ```
pub trait AsyncWorker<Input, Output>: 'static + Send + Sync {
    /// Process an input task and produce an output result.
    ///
    /// This method is called by the engine when a worker is assigned
    /// to process a task.
    fn call(&self, input: Input) -> impl Future<Output = Output> + Send;
}
294
295/// An asynchronous execution engine that manages a pool of workers.
296///
297/// The `AsyncEngine` orchestrates task execution using:
298/// - A pool of workers that process tasks
299/// - A semaphore-based permit system for capacity control
300/// - Task weighting support for resource management
301///
302/// # Type Parameters
303///
304/// - `Input`: The task input type (must implement `TaskInput`)
305/// - `Output`: The result type produced by workers
306/// - `Worker`: The worker type that processes tasks
307///
308/// # Example
309///
310/// ```ignore
311/// use std::sync::Arc;
312/// use tokio::sync::Semaphore;
313///
314/// // Create a pool of 4 workers with capacity for 10 concurrent tasks
315/// let workers = vec![MyWorker::new(); 4];
316/// let permits = Arc::new(Semaphore::new(10));
317/// let engine = AsyncEngine::new(workers, permits);
318///
319/// // Submit tasks to the engine
320/// let handle = engine.submit(my_task).await?;
321/// let result = handle.await?;
322/// ```
323#[derive(Debug, Clone)]
324pub struct AsyncEngine<Input, Output, Worker> {
325    task_permits: Arc<Semaphore>,
326    workers: Arc<WorkerQueue<Worker>>,
327    _marker: PhantomData<(Input, Output)>,
328}
329
330impl<Input, Output, Worker> AsyncEngine<Input, Output, Worker>
331where
332    Input: TaskInput,
333    Worker: AsyncWorker<Input, Output>,
334    Output: 'static + Send + Sync,
335{
336    /// Creates a new `AsyncEngine` with the specified workers and permit semaphore.
337    ///
338    /// # Arguments
339    ///
340    /// - `workers`: A vector of workers that will process tasks
341    /// - `input_buffer_size`: The size of the input buffer
342    ///
343    /// # Example
344    ///
345    /// ```ignore
346    /// let workers = vec![MyWorker::new(); 4];
347    /// let engine = AsyncEngine::new(workers, 10);
348    /// ```
349    pub fn new(workers: Vec<Worker>, input_buffer_size: usize) -> Self {
350        Self {
351            workers: Arc::new(WorkerQueue::new(workers)),
352            task_permits: Arc::new(Semaphore::new(input_buffer_size)),
353            _marker: PhantomData,
354        }
355    }
356
357    /// Create a new `AsyncEngine` with a single permit per worker.
358    ///
359    /// # Arguments
360    ///
361    /// - `workers`: A vector of workers that will process tasks
362    ///
363    /// # Example
364    ///
365    /// ```ignore
366    /// let workers = vec![MyWorker::new(); 4];
367    /// let engine = AsyncEngine::single_permit_per_worker(workers);
368    /// ```
369    pub fn single_permit_per_worker(workers: Vec<Worker>) -> Self {
370        let num_workers = workers.len();
371        Self::new(workers, num_workers)
372    }
373
374    fn spawn(
375        &self,
376        input: Input,
377        permit: OwnedSemaphorePermit,
378    ) -> TaskHandle<(queue::Worker<Worker>, Output)> {
379        let workers = self.workers.clone();
380        let handle = tokio::spawn(
381            async move {
382                let permit = permit;
383                let worker = workers
384                    .pop()
385                    .instrument(tracing::debug_span!("waiting for a worker"))
386                    .await
387                    .map_err(TaskJoinError::from)?;
388                // Drop the permit to release the input queue task slot.
389                drop(permit);
390                // Process the task.
391                let output = worker.call(input).await;
392                // Return the worker and output.
393                Ok((worker, output))
394            }
395            .in_current_span(),
396        );
397        TaskHandle { inner: handle }
398    }
399}
400
401/// Implementation of `Pipeline` for `AsyncEngine`.
402///
403/// This allows the async engine to be used as a pipeline component,
404/// enabling it to be composed with other pipelines.
405impl<Input, Output, Worker> Pipeline for AsyncEngine<Input, Output, Worker>
406where
407    Input: TaskInput,
408    Worker: AsyncWorker<Input, Output>,
409    Output: 'static + Send + Sync,
410{
411    type Input = Input;
412    type Output = Output;
413    type Resource = queue::Worker<Worker>;
414
415    async fn submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
416        let permit = self
417            .task_permits
418            .clone()
419            .acquire_owned()
420            .instrument(tracing::debug_span!("waiting to enter input queue"))
421            .await
422            .map_err(|_| SubmitError)?;
423        Ok(PipelineHandle::new(self.spawn(input, permit)))
424    }
425
426    fn try_submit(
427        &self,
428        input: Self::Input,
429    ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>> {
430        let permit_result = self.task_permits.clone().try_acquire_owned();
431        match permit_result {
432            Ok(permit) => Ok(PipelineHandle::new(self.spawn(input, permit))),
433            Err(TryAcquireError::NoPermits) => Err(TrySubmitError::NoCapacity(input)),
434            Err(TryAcquireError::Closed) => Err(TrySubmitError::Closed),
435        }
436    }
437}
438
/// A trait for workers that process tasks synchronously.
///
/// This trait is similar to `AsyncWorker` but for synchronous blocking tasks. It can be
/// useful when wrapping blocking operations or integrating with non-async code.
/// `BlockingEngine` runs these workers on tokio's blocking thread pool.
///
/// # Example
///
/// ```ignore
/// struct MyBlockingWorker;
///
/// impl BlockingWorker<ComputeTask, ComputeResult> for MyBlockingWorker {
///     fn call(&self, input: ComputeTask) -> ComputeResult {
///         // Perform a potentially blocking calculation
///         ComputeResult::wait_for_result(input)
///     }
/// }
/// ```
pub trait BlockingWorker<Input, Output>: 'static + Send + Sync {
    /// Process an input task synchronously and produce an output result.
    fn call(&self, input: Input) -> Output;
}
460
/// A trait for workers that process CPU-intensive tasks synchronously.
///
/// This trait is similar to `AsyncWorker` but for synchronous, CPU-bound tasks.
/// `RayonEngine` runs these workers on a rayon thread pool rather than tokio's
/// blocking pool, keeping the async runtime free.
///
/// # Example
///
/// ```ignore
/// struct CpuIntensiveWorker;
///
/// impl RayonWorker<ComputeTask, ComputeResult> for CpuIntensiveWorker {
///     fn call(&self, input: ComputeTask) -> ComputeResult {
///         // Perform CPU-intensive calculation
///         ComputeResult::calculate(input)
///     }
/// }
/// ```
pub trait RayonWorker<Input, Output>: 'static + Send + Sync {
    /// Process an input task synchronously and produce an output result.
    fn call(&self, input: Input) -> Output;
}
482
483/// A blocking execution engine that manages a pool of workers for blocking tasks.
484///
485/// The `BlockingEngine` is similar to `AsyncEngine` but designed for synchronous, blocking tasks.
486/// It executes blocking tasks on the tokio runtime to avoid blocking the async runtime.
487///
488/// # Type Parameters
489///
490/// - `Input`: The task input type (must implement `TaskInput`)
491/// - `Output`: The result type produced by workers
492/// - `Worker`: The worker type that processes tasks synchronously
493///
494/// # Example
495///
496/// ```ignore
497/// use std::sync::Arc;
498/// use tokio::sync::Semaphore;
499///
500/// // Create a pool of workers for CPU-intensive tasks
501/// let workers = vec![ComputeWorker::new(); 4];
502/// let permits = Arc::new(Semaphore::new(10));
503/// let engine = BlockingEngine::new(workers, permits);
504///
505/// // Submit CPU-intensive tasks
506/// let handle = engine.submit(compute_task).await?;
507/// let result = handle.await?;
508/// ```
509#[derive(Debug, Clone)]
510pub struct BlockingEngine<Input, Output, Worker> {
511    task_permits: Arc<Semaphore>,
512    workers: Arc<WorkerQueue<Worker>>,
513    _marker: PhantomData<(Input, Output)>,
514}
515
516impl<Input, Output, Worker> BlockingEngine<Input, Output, Worker>
517where
518    Input: TaskInput,
519    Worker: BlockingWorker<Input, Output>,
520    Output: 'static + Send + Sync,
521{
522    /// Creates a new `BlockingEngine` with the specified workers and permit semaphore.
523    ///
524    /// # Arguments
525    ///
526    /// - `workers`: A vector of workers that will process tasks
527    /// - `permits`: A semaphore controlling the maximum number of concurrent tasks
528    ///
529    /// # Example
530    ///
531    /// ```ignore
532    /// let workers = vec![MyWorker::new(); 4];
533    /// let engine = BlockingEngine::new(workers, 10);
534    /// ```
535    pub fn new(workers: Vec<Worker>, input_buffer_size: usize) -> Self {
536        Self {
537            workers: Arc::new(WorkerQueue::new(workers)),
538            task_permits: Arc::new(Semaphore::new(input_buffer_size)),
539            _marker: PhantomData,
540        }
541    }
542
543    /// Create a new `BlockingEngine` with a single permit per worker.
544    ///
545    /// # Arguments
546    ///
547    /// - `workers`: A vector of workers that will process tasks
548    ///
549    /// # Example
550    ///
551    /// ```ignore
552    /// let workers = vec![MyWorker::new(); 4];
553    /// let engine = BlockingEngine::single_permit_per_worker(workers);
554    /// ```
555    pub fn single_permit_per_worker(workers: Vec<Worker>) -> Self {
556        let num_workers = workers.len();
557        Self::new(workers, num_workers)
558    }
559
560    fn spawn(
561        &self,
562        input: Input,
563        permit: OwnedSemaphorePermit,
564    ) -> TaskHandle<(queue::Worker<Worker>, Output)> {
565        let workers = self.workers.clone();
566        let handle = tokio::spawn(
567            async move {
568                let permit = permit;
569                // Wait for a worker to become available.
570                let worker = workers
571                    .pop()
572                    .instrument(tracing::debug_span!("waiting for a worker"))
573                    .await
574                    .map_err(TaskJoinError::from)?;
575                // Drop the permit to release the input queue task slot.
576                drop(permit);
577                let span = tracing::Span::current();
578                let (worker, output) = tokio::task::spawn_blocking(move || {
579                    let _guard = span.enter();
580                    let output = worker.call(input);
581                    (worker, output)
582                })
583                .await
584                .unwrap();
585                Ok((worker, output))
586            }
587            .in_current_span(),
588        );
589        TaskHandle { inner: handle }
590    }
591
592    pub fn blocking_submit(
593        &self,
594        input: Input,
595    ) -> Result<TaskHandle<(queue::Worker<Worker>, Output)>, SubmitError> {
596        let permit = loop {
597            match self.task_permits.clone().try_acquire_owned() {
598                Ok(permit) => break permit,
599                Err(TryAcquireError::NoPermits) => {
600                    std::hint::spin_loop();
601                }
602                Err(TryAcquireError::Closed) => {
603                    return Err(SubmitError);
604                }
605            }
606        };
607        Ok(self.spawn(input, permit))
608    }
609}
610
611/// Implementation of `Pipeline` for `BlockingEngine`.
612///
613/// This allows the blocking engine to be used as a pipeline component,
614/// enabling it to be composed with other pipelines.
615impl<Input, Output, Worker> Pipeline for BlockingEngine<Input, Output, Worker>
616where
617    Input: TaskInput,
618    Worker: BlockingWorker<Input, Output>,
619    Output: 'static + Send + Sync,
620{
621    type Input = Input;
622    type Output = Output;
623    type Resource = queue::Worker<Worker>;
624
625    async fn submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
626        let permit = self
627            .task_permits
628            .clone()
629            .acquire_owned()
630            .instrument(tracing::debug_span!("waiting to enter input queue"))
631            .await
632            .map_err(|_| SubmitError)?;
633        Ok(PipelineHandle::new(self.spawn(input, permit)))
634    }
635
636    fn try_submit(
637        &self,
638        input: Self::Input,
639    ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>> {
640        let permit_result = self.task_permits.clone().try_acquire_owned();
641        match permit_result {
642            Ok(permit) => Ok(PipelineHandle::new(self.spawn(input, permit))),
643            Err(TryAcquireError::NoPermits) => Err(TrySubmitError::NoCapacity(input)),
644            Err(TryAcquireError::Closed) => Err(TrySubmitError::Closed),
645        }
646    }
647}
648
649/// An execution engine that manages a pool of workers for CPU-intensive tasks using `rayon`.
650///
651/// The `RayonEngine` is similar to `BlockinEngine` but designed for synchronous, CPU-intensive
652/// workloads. It executes blocking tasks on a Rayon thread pool to avoid blocking the async
653/// runtime.
654///
655/// # Type Parameters
656///
657/// - `Input`: The task input type (must implement `TaskInput`)
658/// - `Output`: The result type produced by workers
659/// - `Worker`: The worker type that processes tasks synchronously
660///
661/// # Example
662///
663/// ```ignore
664/// use std::sync::Arc;
665/// use tokio::sync::Semaphore;
666///
667/// // Create a pool of workers for CPU-intensive tasks
668/// let workers = vec![ComputeWorker::new(); 4];
669/// let permits = Arc::new(Semaphore::new(10));
670/// let engine = BlockingEngine::new(workers, permits);
671///
672/// // Submit CPU-intensive tasks
673/// let handle = engine.submit(compute_task).await?;
674/// let result = handle.await?;
675/// ```
676#[derive(Debug, Clone)]
677pub struct RayonEngine<Input, Output, Worker> {
678    task_permits: Arc<Semaphore>,
679    workers: Arc<WorkerQueue<Worker>>,
680    _marker: PhantomData<(Input, Output)>,
681}
682
683impl<Input, Output, Worker> RayonEngine<Input, Output, Worker>
684where
685    Input: TaskInput,
686    Worker: RayonWorker<Input, Output>,
687    Output: 'static + Send + Sync,
688{
689    /// Creates a new `RayonEngine` with the specified workers and permit semaphore.
690    ///
691    /// # Arguments
692    ///
693    /// - `workers`: A vector of workers that will process tasks
694    /// - `permits`: A semaphore controlling the maximum number of concurrent tasks
695    ///
696    /// # Example
697    ///
698    /// ```ignore
699    /// let workers = vec![MyWorker::new(); 4];
700    /// let permits = Arc::new(Semaphore::new(10));
701    /// let engine = RayonEngine::new(workers, permits);
702    /// ```
703    pub fn new(workers: Vec<Worker>, permits: Arc<Semaphore>) -> Self {
704        Self {
705            workers: Arc::new(WorkerQueue::new(workers)),
706            task_permits: permits,
707            _marker: PhantomData,
708        }
709    }
710
711    /// Create a new `RayonEngine` with a single permit per worker.
712    ///
713    /// # Arguments
714    ///
715    /// - `workers`: A vector of workers that will process tasks
716    ///
717    /// # Example
718    ///
719    /// ```ignore
720    /// let workers = vec![MyWorker::new(); 4];
721    /// let engine = RayonEngine::single_permit_per_worker(workers);
722    /// ```
723    pub fn single_permit_per_worker(workers: Vec<Worker>) -> Self {
724        let num_workers = workers.len();
725        Self::new(workers, Arc::new(Semaphore::new(num_workers)))
726    }
727
728    fn spawn(
729        &self,
730        input: Input,
731        permit: OwnedSemaphorePermit,
732    ) -> TaskHandle<(queue::Worker<Worker>, Output)> {
733        let workers = self.workers.clone();
734        let handle = tokio::spawn(
735            async move {
736                let permit = permit;
737                // Wait for a worker to become available.
738                let worker = workers
739                    .pop()
740                    .instrument(tracing::debug_span!("waiting for a worker"))
741                    .await
742                    .map_err(TaskJoinError::from)?;
743                // Drop the permit to release the input queue task slot.
744                drop(permit);
745                // Spawn the blocking task on the rayon thread pool
746                let ret = crate::rayon::spawn(move || {
747                    let output = worker.call(input);
748                    (worker, output)
749                })
750                .await
751                .unwrap();
752                Ok(ret)
753            }
754            .in_current_span(),
755        );
756        TaskHandle { inner: handle }
757    }
758}
759
/// Implementation of `Pipeline` for `RayonEngine`.
///
/// This allows the rayon engine to be used as a pipeline component,
/// enabling it to be composed with other pipelines.
impl<Input, Output, Worker> Pipeline for RayonEngine<Input, Output, Worker>
where
    Input: TaskInput,
    Worker: RayonWorker<Input, Output>,
    Output: 'static + Send + Sync,
{
    type Input = Input;
    type Output = Output;
    type Resource = queue::Worker<Worker>;
    /// Waits for an input-queue permit, then spawns the task.
    async fn submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
        let permit = self
            .task_permits
            .clone()
            .acquire_owned()
            .instrument(tracing::debug_span!("waiting to enter input queue"))
            .await
            .map_err(|_| SubmitError)?;
        Ok(PipelineHandle::new(self.spawn(input, permit)))
    }

    /// Spawns the task only if a permit is immediately available; otherwise the input
    /// is handed back via [`TrySubmitError::NoCapacity`].
    fn try_submit(
        &self,
        input: Self::Input,
    ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>> {
        let permit_result = self.task_permits.clone().try_acquire_owned();
        match permit_result {
            Ok(permit) => Ok(PipelineHandle::new(self.spawn(input, permit))),
            Err(TryAcquireError::NoPermits) => Err(TrySubmitError::NoCapacity(input)),
            Err(TryAcquireError::Closed) => Err(TrySubmitError::Closed),
        }
    }
}
796
/// A composite pipeline that chains two pipelines together.
///
/// `Chain` allows you to create complex processing workflows by connecting the output of one
/// pipeline to the input of another. The output type of the first pipeline must be convertible
/// (via `Into`) to the input type of the second pipeline.
///
/// # Type Parameters
///
/// - `First`: The first pipeline in the chain
/// - `Second`: The second pipeline in the chain
///
/// # Example
///
/// ```ignore
/// // Create two pipelines
/// let preprocessing = PreprocessingPipeline::new();
/// let processing = ProcessingPipeline::new();
///
/// // Chain them together
/// let chain = Chain::new(preprocessing, processing);
///
/// // Submit tasks to the chained pipeline
/// let result = chain.submit(raw_data).await?;
/// ```
#[derive(Clone, Debug, Copy)]
pub struct Chain<First, Second> {
    // Stage 1: produces an intermediate value convertible into `Second::Input`.
    first: First,
    // Stage 2: consumes the converted output of `first`.
    second: Second,
}
826
impl<First, Second> Chain<First, Second>
where
    First: Pipeline + Clone,
    Second: Pipeline + Clone,
    First::Output: Into<Second::Input>,
{
    /// Creates a new chain from two pipelines.
    ///
    /// # Arguments
    ///
    /// - `first`: The first pipeline that will process the input
    /// - `second`: The second pipeline that will process the output from the first
    ///
    /// # Example
    ///
    /// ```ignore
    /// let chain = Chain::new(first_pipeline, second_pipeline);
    /// ```
    pub fn new(first: First, second: Second) -> Self {
        Self { first, second }
    }

    /// Get a reference to the first pipeline in the chain.
    ///
    /// This is useful for being able to submit tasks to the first pipeline directly, without having
    /// to go through the second pipeline if there is no need to.
    pub fn first(&self) -> &First {
        &self.first
    }

    /// Get a reference to the second pipeline in the chain.
    ///
    /// This is useful for being able to submit tasks to the second pipeline directly, without
    /// having to go through the first pipeline if there is no need to.
    pub fn second(&self) -> &Second {
        &self.second
    }

    /// Bridge task: awaits the first stage, converts its output, and hands it to the
    /// second stage. The first stage's resource is held until the hand-off completes.
    fn spawn(
        &self,
        first_handle: TaskHandle<(First::Resource, First::Output)>,
    ) -> TaskHandle<(Second::Resource, Second::Output)> {
        let second = self.second.clone();
        let handle = tokio::spawn(
            async move {
                let first_handle = first_handle;
                let (first_resource, first_output) = first_handle.await?;
                let second_input: Second::Input = first_output.into();
                // Submit the second task to the second pipeline. The first resource is
                // still held while waiting for second-stage capacity.
                // NOTE(review): a closed second pipeline panics this bridge task via
                // `expect`; the panic surfaces to the awaiter as an ExecutionError.
                let second_handle =
                    second.submit(second_input).await.expect("failed to submit second task");
                // Once the task is in the second pipeline, we can release the first resource.
                drop(first_resource);
                // Wait for the second task to complete with its resource.
                let second_handle = second_handle.into_inner();
                second_handle.await
            }
            .in_current_span(),
        );
        TaskHandle { inner: handle }
    }
}
889
890/// Implementation of `Pipeline` for `Chain<First, Second>`.
891///
892/// This implementation allows chains to be used as pipelines themselves,
893/// enabling further composition and nesting of processing workflows.
894impl<First, Second> Pipeline for Chain<First, Second>
895where
896    First: Pipeline + Clone,
897    Second: Pipeline + Clone,
898    First::Output: Into<Second::Input>,
899{
900    type Input = First::Input;
901    type Output = Second::Output;
902    type Resource = Second::Resource;
903    /// Submit a task to the chained pipeline.
904    ///
905    /// The task will be processed by the first pipeline, and its output
906    /// will automatically be fed as input to the second pipeline.
907    ///
908    /// # Arguments
909    ///
910    /// - `input`: The initial input for the first pipeline
911    ///
912    /// # Returns
913    ///
914    /// A handle to the final result from the second pipeline
915    async fn submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
916        let first_handle = self.first.submit(input).await?;
917        Ok(PipelineHandle::new(self.spawn(first_handle.into_inner())))
918    }
919
920    /// Try to submit a task to the chained pipeline without blocking.
921    ///
922    /// # Arguments
923    ///
924    /// - `input`: The initial input for the first pipeline
925    ///
926    /// # Returns
927    ///
928    /// - `Ok(TaskHandle)` if the task was successfully submitted to the first pipeline
929    /// - `Err(TrySubmitError)` if submission failed
930    fn try_submit(
931        &self,
932        input: Self::Input,
933    ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>> {
934        let first_handle = self.first.try_submit(input)?;
935        Ok(PipelineHandle::new(self.spawn(first_handle.into_inner())))
936    }
937}
938
939/// Implementation of `Pipeline` for `Arc<P>` where `P` implements `Pipeline`.
940///
941/// This allows pipelines to be shared across multiple threads efficiently.
942/// The `Arc` wrapper enables cheap cloning and thread-safe sharing of the
943/// underlying pipeline.
944///
945/// # Example
946///
947/// ```ignore
948/// let pipeline = MyPipeline::new();
949/// let shared_pipeline = Arc::new(pipeline);
950///
951/// // Can now clone and share across threads
952/// let pipeline_clone = shared_pipeline.clone();
953/// tokio::spawn(async move {
954///     let result = pipeline_clone.submit(task).await?;
955/// });
956/// ```
957impl<P: Pipeline> Pipeline for Arc<P> {
958    type Input = P::Input;
959    type Output = P::Output;
960    type Resource = P::Resource;
961
962    #[inline]
963    async fn submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
964        self.as_ref().submit(input).await
965    }
966
967    #[inline]
968    fn try_submit(
969        &self,
970        input: Self::Input,
971    ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>> {
972        self.as_ref().try_submit(input)
973    }
974}
975
/// Builder for composing pipelines into chains.
///
/// Start with [`PipelineBuilder::new`], append stages with `through`, and call
/// `build` to obtain the final composed pipeline.
#[derive(Debug, Clone)]
pub struct PipelineBuilder<P = ()> {
    // The pipeline composed so far (`()` until `new` supplies the first stage).
    pipeline: P,
}
980
impl PipelineBuilder {
    /// Create a builder wrapping the given initial pipeline stage.
    ///
    /// # Arguments
    ///
    /// - `pipeline`: The first stage of the pipeline being built
    ///
    /// # Returns
    ///
    /// A `PipelineBuilder` holding the provided pipeline
    pub fn new<P: Pipeline>(pipeline: P) -> PipelineBuilder<P> {
        PipelineBuilder { pipeline }
    }
}
986
987impl<P: Pipeline> PipelineBuilder<P> {
988    /// Build the pipeline.
989    ///
990    /// # Returns
991    ///
992    /// The built pipeline
993    pub fn build(self) -> P {
994        self.pipeline
995    }
996
997    /// Chain the pipeline with another pipeline.
998    ///
999    /// # Arguments
1000    ///
1001    /// - `pipeline`: The pipeline to chain with
1002    ///
1003    /// # Returns
1004    ///
1005    /// A new pipeline builder with the chained pipeline
1006    pub fn through<Q>(self, pipeline: Q) -> PipelineBuilder<Chain<P, Q>>
1007    where
1008        P: Clone,
1009        Q: Pipeline + Clone,
1010        P::Output: Into<Q::Input>,
1011    {
1012        PipelineBuilder { pipeline: Chain::new(self.pipeline, pipeline) }
1013    }
1014}
1015
#[cfg(test)]
mod tests {
    use futures::{prelude::*, stream::FuturesOrdered};
    use rand::Rng;
    use std::time::Duration;
    use tokio::task::JoinSet;

    use super::*;

    /// Worker that sleeps for the task's duration and then, with the task's
    /// configured probability, spins forever (yielding) to simulate a hang.
    #[derive(Debug, Clone)]
    struct TestWorker;

    #[derive(Debug, Clone)]
    struct TestTask {
        // How long the worker sleeps before (possibly) hanging.
        time: Duration,
        // Probability in [0, 1] that the worker never completes this task.
        hanging_probability: f64,
    }

    impl AsyncWorker<TestTask, ()> for TestWorker {
        async fn call(&self, input: TestTask) {
            tokio::time::sleep(input.time).await;

            let should_hang = rand::thread_rng().gen_bool(input.hanging_probability);
            if should_hang {
                // Cooperative busy-loop: other tasks keep running, but this
                // task never finishes.
                loop {
                    tokio::task::yield_now().await;
                }
            }
        }
    }

    /// Smoke test: run more tasks than workers through the engine and print
    /// the elapsed time next to an idealized fully-parallel baseline
    /// (timings are printed, not asserted).
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_async_engine() {
        let num_workers = 5;
        let task_queue_length = 5;
        let num_tasks_spawned = 10;
        let wait_duration = Duration::from_millis(10);

        let workers = (0..num_workers).map(|_| TestWorker).collect();
        let engine = Arc::new(AsyncEngine::new(workers, task_queue_length));

        let tasks = (0..num_tasks_spawned)
            .map(|_| TestTask { time: wait_duration, hanging_probability: 0.0 })
            .collect::<Vec<_>>();

        // Submit all tasks concurrently and wait for them to complete
        let mut join_set = JoinSet::new();
        let time = tokio::time::Instant::now();
        for task in tasks {
            let e = engine.clone();
            join_set.spawn(async move { e.submit(task).await.unwrap().await.unwrap() });
        }
        join_set.join_all().await;
        let duration = time.elapsed();
        println!("Time taken for async engine: {:?}", duration);

        // Compare this with the case of complete parallelism
        let mut join_set = JoinSet::new();
        let tasks_per_worker = num_tasks_spawned / num_workers;
        let time = tokio::time::Instant::now();
        for _ in 0..num_workers {
            join_set.spawn(async move {
                for _ in 0..tasks_per_worker {
                    tokio::time::sleep(wait_duration).await;
                }
            });
        }
        join_set.join_all().await;
        let duration = time.elapsed();
        println!("Time taken for complete parallelism: {:?}", duration);
    }

    /// Exercise the engine with tasks that hang half the time, bounding each
    /// wait with a timeout and counting how many completed.
    ///
    /// NOTE(review): with a single worker and hanging tasks, `submit` in the
    /// loop relies on queue capacity being freed when timed-out handles are
    /// dropped — confirm dropping a handle aborts/releases its task.
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_hanging_task_async_engine() {
        let num_workers = 1;
        let task_queue_length = 2;
        let num_tasks_spawned = 100;
        let wait_duration = Duration::from_millis(1);
        let hanging_probability = 0.5;
        let timeout = Duration::from_millis(100);

        let workers = (0..num_workers).map(|_| TestWorker).collect();
        let engine = Arc::new(AsyncEngine::new(workers, task_queue_length));

        let tasks = (0..num_tasks_spawned)
            .map(|_| TestTask { time: wait_duration, hanging_probability })
            .collect::<Vec<_>>();

        // Submit all tasks concurrently and wait for them to complete
        let mut join_set = JoinSet::new();
        let time = tokio::time::Instant::now();
        for task in tasks {
            let handle = engine.submit(task).await.unwrap();
            let future = async move { handle.await.unwrap() };
            join_set.spawn(async move { tokio::time::timeout(timeout, future).await });
        }

        // Count the tasks that finished inside the timeout window.
        let mut success_count = 0;
        while let Some(result) = join_set.join_next().await {
            let result = result.unwrap();
            if result.is_ok() {
                success_count += 1;
            }
        }
        let duration = time.elapsed();
        println!("Time taken for async engine: {:?}, success count: {success_count}", duration);
    }

    /// Verify the blocking engine returns each task's correct sum.
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_blocking_engine() {
        #[derive(Debug, Clone)]
        struct SummingWorker;

        #[derive(Debug, Clone)]
        struct SummingTask {
            // Values to add up; the expected result is their sum.
            summands: Vec<u32>,
        }

        impl BlockingWorker<SummingTask, u32> for SummingWorker {
            fn call(&self, input: SummingTask) -> u32 {
                input.summands.iter().sum()
            }
        }

        let num_workers = 10;
        let task_queue_length = 20;
        let num_tasks_spawned = 10;
        let max_summands = 20;

        let workers = (0..num_workers).map(|_| SummingWorker).collect();
        let engine = Arc::new(BlockingEngine::new(workers, task_queue_length));

        let mut rng = rand::thread_rng();
        let tasks = (0..num_tasks_spawned)
            .map(|_| SummingTask { summands: vec![1; rng.gen_range(1..=max_summands)] })
            .collect::<Vec<_>>();

        // Submit all tasks concurrently and wait for them to complete
        // (FuturesOrdered keeps results aligned with the submission order).
        let mut results = FuturesOrdered::new();
        for task in tasks.iter() {
            results.push_back(engine.submit(task.clone()).await.unwrap());
        }
        let results = results.collect::<Vec<_>>().await;
        for (task, result) in tasks.iter().zip(results) {
            let result = result.unwrap();
            let expected = task.summands.iter().sum();
            assert_eq!(result, expected);
        }
    }

    /// Tasks above the 50ms threshold make the worker panic; awaiting such a
    /// task must propagate the panic, so the test expects to panic.
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    #[should_panic]
    async fn test_async_failing_engine() {
        #[derive(Debug, Clone)]
        struct FailingWorker;

        #[derive(Debug, Clone)]
        struct TestTask {
            time: Duration,
        }

        impl AsyncWorker<TestTask, ()> for FailingWorker {
            async fn call(&self, input: TestTask) {
                if input.time > Duration::from_millis(50) {
                    panic!("not interested to wait for this long");
                }
                tokio::time::sleep(input.time).await;
            }
        }
        let num_workers = 10;
        let task_queue_length = 20;
        let wait_duration = 100;

        let workers = (0..num_workers).map(|_| FailingWorker).collect();
        let engine = Arc::new(AsyncEngine::new(workers, task_queue_length));

        // Durations 0..100ms, so roughly half the tasks exceed the 50ms limit.
        let tasks = (0..wait_duration)
            .map(|i| TestTask { time: Duration::from_millis(i) })
            .collect::<Vec<_>>();

        // Submit all tasks concurrently and wait for them to complete
        let mut join_set = JoinSet::new();
        let time = tokio::time::Instant::now();
        for task in tasks {
            let e = engine.clone();
            join_set.spawn(async move { e.submit(task).await.unwrap().await.unwrap() });
        }
        join_set.join_all().await;
        let duration = time.elapsed();
        println!("Time taken for async engine: {:?}", duration);
    }

    /// Chain a blocking stage into an async stage and check every chained
    /// submission completes end to end.
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_chained_pipelines() {
        #[derive(Debug, Clone)]
        struct FirstTask;

        #[derive(Debug, Clone)]
        struct FirstWorker;

        impl BlockingWorker<FirstTask, SecondTask> for FirstWorker {
            fn call(&self, _input: FirstTask) -> SecondTask {
                let mut rng = rand::thread_rng();
                SecondTask { value: rng.gen_range(200..=1000) }
            }
        }

        #[derive(Debug, Clone)]
        struct SecondWorker;

        #[derive(Debug, Clone)]
        struct SecondTask {
            // Sleep duration in milliseconds; echoed back as the output.
            value: u64,
        }

        impl AsyncWorker<SecondTask, u64> for SecondWorker {
            async fn call(&self, input: SecondTask) -> u64 {
                tokio::time::sleep(Duration::from_millis(input.value)).await;
                input.value
            }
        }

        let first_workers = (0..10).map(|_| FirstWorker).collect();
        let first_pipeline = Arc::new(BlockingEngine::single_permit_per_worker(first_workers));
        let second_workers = (0..10).map(|_| SecondWorker).collect();
        let second_pipeline = Arc::new(AsyncEngine::single_permit_per_worker(second_workers));
        let chain = Chain::new(first_pipeline, second_pipeline);

        let handles = (0..10)
            .map(|_| chain.submit(FirstTask))
            .collect::<FuturesOrdered<_>>()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        for handle in handles {
            let _result = handle.await.unwrap();
        }
    }

    /// Compare the wall-clock cost of a 5-stage chain (100ms per stage) with
    /// a single 500ms task; both must return their input unchanged.
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_timing_chained_pipelines() {
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        struct SleepTask {
            duration: Duration,
        }

        #[derive(Debug, Clone)]
        struct SleepWorker;

        impl AsyncWorker<SleepTask, SleepTask> for SleepWorker {
            async fn call(&self, input: SleepTask) -> SleepTask {
                let sleep_duration = input.duration;
                tokio::time::sleep(sleep_duration).await;
                input
            }
        }

        let num_workers = 10;

        let workers = (0..num_workers).map(|_| SleepWorker).collect::<Vec<_>>();
        let make_engine =
            |workers: Vec<SleepWorker>| Arc::new(AsyncEngine::single_permit_per_worker(workers));

        // Five chained stages, each sleeping for the task's duration.
        let pipeline = PipelineBuilder::new(make_engine(workers.clone()))
            .through(make_engine(workers.clone()))
            .through(make_engine(workers.clone()))
            .through(make_engine(workers.clone()))
            .through(make_engine(workers.clone()))
            .build();

        let chain_input_task = SleepTask { duration: Duration::from_millis(100) };
        let single_input_task = SleepTask { duration: Duration::from_millis(500) };

        let time = tokio::time::Instant::now();
        let chain_result = pipeline.submit(chain_input_task).await.unwrap().await.unwrap();
        let chain_duration = time.elapsed();
        println!("Chain duration: {:?}", chain_duration);
        assert_eq!(chain_result, chain_input_task);

        let single_engine = make_engine(workers.clone());
        let time = tokio::time::Instant::now();
        let single_result = single_engine.submit(single_input_task).await.unwrap().await.unwrap();
        let single_duration = time.elapsed();
        println!("Single duration: {:?}", single_duration);
        assert_eq!(single_result, single_input_task);
    }
}