1use core::marker::PhantomData;
33use std::{
34 fmt,
35 future::Future,
36 pin::Pin,
37 sync::Arc,
38 task::{Context, Poll},
39};
40
41use tracing::Instrument;
42
43use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
44
45use thiserror::Error;
46
47use crate::queue::{self, AcquireWorkerError, WorkerQueue};
48
/// Marker trait for values that can be submitted to an engine as task input.
///
/// Inputs must be `'static + Send + Sync` because they are moved into the
/// tokio task spawned to execute them.
pub trait TaskInput: 'static + Send + Sync {}
54
/// Blanket implementation: every `'static + Send + Sync` type is a valid task input.
impl<T: 'static + Send + Sync> TaskInput for T {}
56
/// Error returned when a task cannot be submitted because the engine's
/// permit semaphore has been closed.
#[derive(Error, Debug)]
#[error("Engine closed")]
pub struct SubmitError;
63
/// Error returned by a non-blocking submission attempt ([`Pipeline::try_submit`]).
#[derive(Error)]
#[error("failed to submit task")]
pub enum TrySubmitError<T> {
    /// The engine's permit semaphore has been closed; no more tasks are accepted.
    #[error("engine closed")]
    Closed,
    /// No permit was immediately available; the rejected input is handed back
    /// to the caller so it can be retried.
    #[error("no capacity available")]
    NoCapacity(T),
}
79
80impl<T> fmt::Debug for TrySubmitError<T> {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 write!(f, "TrySubmitError<{}>", std::any::type_name::<T>())
83 }
84}
85
/// Error returned by [`Pipeline::run`], covering both submission failure and
/// failure of the spawned task itself.
#[derive(Error, Debug)]
pub enum RunError {
    /// The task could not be submitted (engine closed).
    #[error("failed to submit task")]
    SubmitError(#[from] SubmitError),
    /// The task was submitted but failed while executing.
    #[error("task execution failed")]
    TaskFailed(#[from] TaskJoinError),
}
96
/// Error produced while awaiting a spawned engine task.
#[derive(Error, Debug)]
pub enum TaskJoinError {
    /// The tokio task failed to join (e.g. it panicked or was aborted).
    #[error("execution error")]
    ExecutionError(#[from] tokio::task::JoinError),
    /// A worker could not be acquired from the worker queue.
    #[error("failed to acquire worker")]
    PopWorker(#[from] AcquireWorkerError),
}
110
/// Handle to a spawned engine task that will yield a value of type `T`.
///
/// Dropping the handle aborts the underlying tokio task (see the `Drop` impl),
/// so a handle must be kept alive (or awaited) for its task to complete.
pub struct TaskHandle<T> {
    // Join handle for the spawned task; the task itself returns a Result so
    // worker-acquisition failures can be surfaced to the awaiter.
    inner: tokio::task::JoinHandle<Result<T, TaskJoinError>>,
}
135
impl<T> TaskHandle<T> {
    /// Aborts the underlying tokio task.
    pub fn abort(&self) {
        self.inner.abort();
    }
}
145
impl<T> Drop for TaskHandle<T> {
    // Abort the task when the handle is dropped, so discarded handles do not
    // leave work running detached in the background.
    fn drop(&mut self) {
        self.abort();
    }
}
151
152impl<T> Future for TaskHandle<T> {
153 type Output = Result<T, TaskJoinError>;
154
155 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
156 let pin = Pin::new(&mut self.inner);
157 pin.poll(cx).map(|res| res.map_err(TaskJoinError::from)).map(|res| match res {
158 Ok(Ok(output)) => Ok(output),
159 Ok(Err(error)) => Err(error),
160 Err(error) => Err(error),
161 })
162 }
163}
164
/// Convenience alias for the handle returned by [`Pipeline::submit`]: a
/// [`PipelineHandle`] over the pipeline's resource and output types.
pub type SubmitHandle<P> = PipelineHandle<<P as Pipeline>::Resource, <P as Pipeline>::Output>;
166
/// Handle to a submitted pipeline task.
///
/// Wraps a [`TaskHandle`] whose value is a `(resource, output)` pair; awaiting
/// the wrapper discards the resource and yields only the output.
pub struct PipelineHandle<R, O> {
    handle: TaskHandle<(R, O)>,
}
170
impl<R, O> PipelineHandle<R, O> {
    /// Wraps a raw task handle.
    pub fn new(handle: TaskHandle<(R, O)>) -> Self {
        Self { handle }
    }

    /// Aborts the underlying task.
    pub fn abort(&self) {
        self.handle.abort();
    }

    /// Unwraps the inner task handle, exposing the `(resource, output)` pair.
    fn into_inner(self) -> TaskHandle<(R, O)> {
        self.handle
    }
}
184
185impl<R, O> Future for PipelineHandle<R, O> {
186 type Output = Result<O, TaskJoinError>;
187
188 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189 let pin = Pin::new(&mut self.handle);
190 pin.poll(cx).map(|res| res.map(|(_, output)| output))
191 }
192}
193
194pub trait Pipeline: 'static + Send + Sync {
209 type Input: 'static + Send + Sync;
211 type Output: 'static + Send + Sync;
213 type Resource: 'static + Send + Sync;
215
216 fn submit(
221 &self,
222 input: Self::Input,
223 ) -> impl Future<Output = Result<SubmitHandle<Self>, SubmitError>> + Send;
224
225 fn try_submit(
230 &self,
231 input: Self::Input,
232 ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>>;
233
234 fn run(
238 &self,
239 input: Self::Input,
240 ) -> impl Future<Output = Result<Self::Output, RunError>> + Send {
241 async move {
242 let handle = self.submit(input).await?;
243 let output = handle.await.map_err(RunError::from)?;
244 Ok(output)
245 }
246 }
247
248 fn blocking_submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
249 let mut last_input = input;
250 loop {
251 match self.try_submit(last_input) {
252 Ok(handle) => {
253 return Ok(handle);
254 }
255 Err(TrySubmitError::NoCapacity(input)) => {
256 last_input = input;
257 std::hint::spin_loop();
258 }
259 Err(TrySubmitError::Closed) => {
260 return Err(SubmitError);
261 }
262 }
263 }
264 }
265}
266
/// A worker that processes inputs asynchronously on the tokio runtime.
pub trait AsyncWorker<Input, Output>: 'static + Send + Sync {
    /// Processes one input, yielding an output.
    fn call(&self, input: Input) -> impl Future<Output = Output> + Send;
}
294
/// Engine that runs [`AsyncWorker`]s pulled from a shared worker queue.
#[derive(Debug, Clone)]
pub struct AsyncEngine<Input, Output, Worker> {
    // Bounds how many tasks may be queued or running at once.
    task_permits: Arc<Semaphore>,
    // Pool of workers shared by all spawned tasks.
    workers: Arc<WorkerQueue<Worker>>,
    // Pins the Input/Output type parameters without storing values of them.
    _marker: PhantomData<(Input, Output)>,
}
329
impl<Input, Output, Worker> AsyncEngine<Input, Output, Worker>
where
    Input: TaskInput,
    Worker: AsyncWorker<Input, Output>,
    Output: 'static + Send + Sync,
{
    /// Creates an engine over `workers` that admits at most
    /// `input_buffer_size` concurrently submitted tasks.
    pub fn new(workers: Vec<Worker>, input_buffer_size: usize) -> Self {
        Self {
            workers: Arc::new(WorkerQueue::new(workers)),
            task_permits: Arc::new(Semaphore::new(input_buffer_size)),
            _marker: PhantomData,
        }
    }

    /// Creates an engine whose submission capacity equals the worker count,
    /// i.e. at most one in-flight task per worker.
    pub fn single_permit_per_worker(workers: Vec<Worker>) -> Self {
        let num_workers = workers.len();
        Self::new(workers, num_workers)
    }

    /// Spawns a task that pops a worker, runs `input` on it, and returns the
    /// worker together with the output.
    ///
    /// The submission `permit` is held only until a worker has been acquired;
    /// releasing it at that point lets the next submission enter the queue
    /// while this task is still executing.
    fn spawn(
        &self,
        input: Input,
        permit: OwnedSemaphorePermit,
    ) -> TaskHandle<(queue::Worker<Worker>, Output)> {
        let workers = self.workers.clone();
        let handle = tokio::spawn(
            async move {
                // Move the permit into the task so it stays held while queued.
                let permit = permit;
                let worker = workers
                    .pop()
                    .instrument(tracing::debug_span!("waiting for a worker"))
                    .await
                    .map_err(TaskJoinError::from)?;
                // A worker is secured: free the submission slot.
                drop(permit);
                let output = worker.call(input).await;
                Ok((worker, output))
            }
            .in_current_span(),
        );
        TaskHandle { inner: handle }
    }
}
400
401impl<Input, Output, Worker> Pipeline for AsyncEngine<Input, Output, Worker>
406where
407 Input: TaskInput,
408 Worker: AsyncWorker<Input, Output>,
409 Output: 'static + Send + Sync,
410{
411 type Input = Input;
412 type Output = Output;
413 type Resource = queue::Worker<Worker>;
414
415 async fn submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
416 let permit = self
417 .task_permits
418 .clone()
419 .acquire_owned()
420 .instrument(tracing::debug_span!("waiting to enter input queue"))
421 .await
422 .map_err(|_| SubmitError)?;
423 Ok(PipelineHandle::new(self.spawn(input, permit)))
424 }
425
426 fn try_submit(
427 &self,
428 input: Self::Input,
429 ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>> {
430 let permit_result = self.task_permits.clone().try_acquire_owned();
431 match permit_result {
432 Ok(permit) => Ok(PipelineHandle::new(self.spawn(input, permit))),
433 Err(TryAcquireError::NoPermits) => Err(TrySubmitError::NoCapacity(input)),
434 Err(TryAcquireError::Closed) => Err(TrySubmitError::Closed),
435 }
436 }
437}
438
/// A worker whose `call` blocks the current thread; executed via
/// `tokio::task::spawn_blocking` by [`BlockingEngine`].
pub trait BlockingWorker<Input, Output>: 'static + Send + Sync {
    /// Processes one input synchronously.
    fn call(&self, input: Input) -> Output;
}
460
/// A worker whose `call` runs on the rayon thread pool; executed via
/// `crate::rayon::spawn` by [`RayonEngine`].
pub trait RayonWorker<Input, Output>: 'static + Send + Sync {
    /// Processes one input synchronously on a rayon thread.
    fn call(&self, input: Input) -> Output;
}
482
/// Engine that runs [`BlockingWorker`]s on tokio's blocking thread pool.
#[derive(Debug, Clone)]
pub struct BlockingEngine<Input, Output, Worker> {
    // Bounds how many tasks may be queued or running at once.
    task_permits: Arc<Semaphore>,
    // Pool of workers shared by all spawned tasks.
    workers: Arc<WorkerQueue<Worker>>,
    // Pins the Input/Output type parameters without storing values of them.
    _marker: PhantomData<(Input, Output)>,
}
515
516impl<Input, Output, Worker> BlockingEngine<Input, Output, Worker>
517where
518 Input: TaskInput,
519 Worker: BlockingWorker<Input, Output>,
520 Output: 'static + Send + Sync,
521{
522 pub fn new(workers: Vec<Worker>, input_buffer_size: usize) -> Self {
536 Self {
537 workers: Arc::new(WorkerQueue::new(workers)),
538 task_permits: Arc::new(Semaphore::new(input_buffer_size)),
539 _marker: PhantomData,
540 }
541 }
542
543 pub fn single_permit_per_worker(workers: Vec<Worker>) -> Self {
556 let num_workers = workers.len();
557 Self::new(workers, num_workers)
558 }
559
560 fn spawn(
561 &self,
562 input: Input,
563 permit: OwnedSemaphorePermit,
564 ) -> TaskHandle<(queue::Worker<Worker>, Output)> {
565 let workers = self.workers.clone();
566 let handle = tokio::spawn(
567 async move {
568 let permit = permit;
569 let worker = workers
571 .pop()
572 .instrument(tracing::debug_span!("waiting for a worker"))
573 .await
574 .map_err(TaskJoinError::from)?;
575 drop(permit);
577 let span = tracing::Span::current();
578 let (worker, output) = tokio::task::spawn_blocking(move || {
579 let _guard = span.enter();
580 let output = worker.call(input);
581 (worker, output)
582 })
583 .await
584 .unwrap();
585 Ok((worker, output))
586 }
587 .in_current_span(),
588 );
589 TaskHandle { inner: handle }
590 }
591
592 pub fn blocking_submit(
593 &self,
594 input: Input,
595 ) -> Result<TaskHandle<(queue::Worker<Worker>, Output)>, SubmitError> {
596 let permit = loop {
597 match self.task_permits.clone().try_acquire_owned() {
598 Ok(permit) => break permit,
599 Err(TryAcquireError::NoPermits) => {
600 std::hint::spin_loop();
601 }
602 Err(TryAcquireError::Closed) => {
603 return Err(SubmitError);
604 }
605 }
606 };
607 Ok(self.spawn(input, permit))
608 }
609}
610
611impl<Input, Output, Worker> Pipeline for BlockingEngine<Input, Output, Worker>
616where
617 Input: TaskInput,
618 Worker: BlockingWorker<Input, Output>,
619 Output: 'static + Send + Sync,
620{
621 type Input = Input;
622 type Output = Output;
623 type Resource = queue::Worker<Worker>;
624
625 async fn submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
626 let permit = self
627 .task_permits
628 .clone()
629 .acquire_owned()
630 .instrument(tracing::debug_span!("waiting to enter input queue"))
631 .await
632 .map_err(|_| SubmitError)?;
633 Ok(PipelineHandle::new(self.spawn(input, permit)))
634 }
635
636 fn try_submit(
637 &self,
638 input: Self::Input,
639 ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>> {
640 let permit_result = self.task_permits.clone().try_acquire_owned();
641 match permit_result {
642 Ok(permit) => Ok(PipelineHandle::new(self.spawn(input, permit))),
643 Err(TryAcquireError::NoPermits) => Err(TrySubmitError::NoCapacity(input)),
644 Err(TryAcquireError::Closed) => Err(TrySubmitError::Closed),
645 }
646 }
647}
648
/// Engine that runs [`RayonWorker`]s on the rayon thread pool.
#[derive(Debug, Clone)]
pub struct RayonEngine<Input, Output, Worker> {
    // Bounds how many tasks may be queued or running at once. Unlike the other
    // engines, this semaphore is injected, so it can be shared across engines.
    task_permits: Arc<Semaphore>,
    // Pool of workers shared by all spawned tasks.
    workers: Arc<WorkerQueue<Worker>>,
    // Pins the Input/Output type parameters without storing values of them.
    _marker: PhantomData<(Input, Output)>,
}
682
impl<Input, Output, Worker> RayonEngine<Input, Output, Worker>
where
    Input: TaskInput,
    Worker: RayonWorker<Input, Output>,
    Output: 'static + Send + Sync,
{
    /// Creates an engine over `workers` gated by an externally supplied
    /// semaphore, allowing the permit budget to be shared with other engines.
    pub fn new(workers: Vec<Worker>, permits: Arc<Semaphore>) -> Self {
        Self {
            workers: Arc::new(WorkerQueue::new(workers)),
            task_permits: permits,
            _marker: PhantomData,
        }
    }

    /// Creates an engine whose submission capacity equals the worker count,
    /// i.e. at most one in-flight task per worker.
    pub fn single_permit_per_worker(workers: Vec<Worker>) -> Self {
        let num_workers = workers.len();
        Self::new(workers, Arc::new(Semaphore::new(num_workers)))
    }

    /// Spawns a task that pops a worker and runs `input` on it via the
    /// rayon pool, returning the worker together with the output.
    ///
    /// The submission `permit` is released as soon as a worker is acquired.
    fn spawn(
        &self,
        input: Input,
        permit: OwnedSemaphorePermit,
    ) -> TaskHandle<(queue::Worker<Worker>, Output)> {
        let workers = self.workers.clone();
        let handle = tokio::spawn(
            async move {
                // Move the permit into the task so it stays held while queued.
                let permit = permit;
                let worker = workers
                    .pop()
                    .instrument(tracing::debug_span!("waiting for a worker"))
                    .await
                    .map_err(TaskJoinError::from)?;
                // A worker is secured: free the submission slot.
                drop(permit);
                // NOTE(review): the result of `crate::rayon::spawn` is
                // unwrapped, so a failure on the rayon side panics this task
                // (surfacing as ExecutionError to the awaiter) — confirm
                // whether it should be mapped into TaskJoinError instead.
                let ret = crate::rayon::spawn(move || {
                    let output = worker.call(input);
                    (worker, output)
                })
                .await
                .unwrap();
                Ok(ret)
            }
            .in_current_span(),
        );
        TaskHandle { inner: handle }
    }
}
759
760impl<Input, Output, Worker> Pipeline for RayonEngine<Input, Output, Worker>
765where
766 Input: TaskInput,
767 Worker: RayonWorker<Input, Output>,
768 Output: 'static + Send + Sync,
769{
770 type Input = Input;
771 type Output = Output;
772 type Resource = queue::Worker<Worker>;
773 async fn submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
774 let permit = self
775 .task_permits
776 .clone()
777 .acquire_owned()
778 .instrument(tracing::debug_span!("waiting to enter input queue"))
779 .await
780 .map_err(|_| SubmitError)?;
781 Ok(PipelineHandle::new(self.spawn(input, permit)))
782 }
783
784 fn try_submit(
785 &self,
786 input: Self::Input,
787 ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>> {
788 let permit_result = self.task_permits.clone().try_acquire_owned();
789 match permit_result {
790 Ok(permit) => Ok(PipelineHandle::new(self.spawn(input, permit))),
791 Err(TryAcquireError::NoPermits) => Err(TrySubmitError::NoCapacity(input)),
792 Err(TryAcquireError::Closed) => Err(TrySubmitError::Closed),
793 }
794 }
795}
796
/// Two pipelines composed sequentially: the first's output (converted via
/// `Into`) becomes the second's input.
#[derive(Clone, Debug, Copy)]
pub struct Chain<First, Second> {
    first: First,
    second: Second,
}
826
impl<First, Second> Chain<First, Second>
where
    First: Pipeline + Clone,
    Second: Pipeline + Clone,
    First::Output: Into<Second::Input>,
{
    /// Composes `first` and `second` into a chained pipeline.
    pub fn new(first: First, second: Second) -> Self {
        Self { first, second }
    }

    /// The upstream pipeline.
    pub fn first(&self) -> &First {
        &self.first
    }

    /// The downstream pipeline.
    pub fn second(&self) -> &Second {
        &self.second
    }

    /// Spawns a task that awaits the first stage, converts its output, and
    /// submits it to the second stage.
    ///
    /// Ordering matters here: the first stage's resource is released only
    /// *after* the second stage has accepted the task, so the hand-off
    /// completes before the upstream worker becomes reusable.
    fn spawn(
        &self,
        first_handle: TaskHandle<(First::Resource, First::Output)>,
    ) -> TaskHandle<(Second::Resource, Second::Output)> {
        let second = self.second.clone();
        let handle = tokio::spawn(
            async move {
                let first_handle = first_handle;
                let (first_resource, first_output) = first_handle.await?;
                let second_input: Second::Input = first_output.into();
                // NOTE(review): this panics if the second pipeline is closed;
                // confirm whether that should surface as an error instead.
                let second_handle =
                    second.submit(second_input).await.expect("failed to submit second task");
                // Release the first stage's resource only after the hand-off.
                drop(first_resource);
                let second_handle = second_handle.into_inner();
                second_handle.await
            }
            .in_current_span(),
        );
        TaskHandle { inner: handle }
    }
}
889
impl<First, Second> Pipeline for Chain<First, Second>
where
    First: Pipeline + Clone,
    Second: Pipeline + Clone,
    First::Output: Into<Second::Input>,
{
    type Input = First::Input;
    type Output = Second::Output;
    type Resource = Second::Resource;
    /// Submits to the first stage and chains the hand-off task onto it.
    async fn submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
        let first_handle = self.first.submit(input).await?;
        Ok(PipelineHandle::new(self.spawn(first_handle.into_inner())))
    }

    /// Non-blocking variant. Only the first stage's capacity is checked here;
    /// the second stage is submitted to later, inside the spawned task.
    fn try_submit(
        &self,
        input: Self::Input,
    ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>> {
        let first_handle = self.first.try_submit(input)?;
        Ok(PipelineHandle::new(self.spawn(first_handle.into_inner())))
    }
}
938
/// Pipelines remain usable behind `Arc` by delegating to the inner value,
/// allowing cheap sharing across tasks.
impl<P: Pipeline> Pipeline for Arc<P> {
    type Input = P::Input;
    type Output = P::Output;
    type Resource = P::Resource;

    #[inline]
    async fn submit(&self, input: Self::Input) -> Result<SubmitHandle<Self>, SubmitError> {
        self.as_ref().submit(input).await
    }

    #[inline]
    fn try_submit(
        &self,
        input: Self::Input,
    ) -> Result<SubmitHandle<Self>, TrySubmitError<Self::Input>> {
        self.as_ref().try_submit(input)
    }
}
975
/// Builder for composing pipelines into nested [`Chain`]s via
/// [`PipelineBuilder::through`].
#[derive(Debug, Clone)]
pub struct PipelineBuilder<P = ()> {
    pipeline: P,
}
980
impl PipelineBuilder {
    /// Starts a builder from an initial pipeline stage.
    pub fn new<P: Pipeline>(pipeline: P) -> PipelineBuilder<P> {
        PipelineBuilder { pipeline }
    }
}
986
impl<P: Pipeline> PipelineBuilder<P> {
    /// Finishes the builder, returning the composed pipeline.
    pub fn build(self) -> P {
        self.pipeline
    }

    /// Appends `pipeline` as the next stage; the current pipeline's output
    /// must convert into the new stage's input.
    pub fn through<Q>(self, pipeline: Q) -> PipelineBuilder<Chain<P, Q>>
    where
        P: Clone,
        Q: Pipeline + Clone,
        P::Output: Into<Q::Input>,
    {
        PipelineBuilder { pipeline: Chain::new(self.pipeline, pipeline) }
    }
}
1015
#[cfg(test)]
mod tests {
    use futures::{prelude::*, stream::FuturesOrdered};
    use rand::Rng;
    use std::time::Duration;
    use tokio::task::JoinSet;

    use super::*;

    // Async worker that sleeps for the task's duration and may hang forever.
    #[derive(Debug, Clone)]
    struct TestWorker;

    #[derive(Debug, Clone)]
    struct TestTask {
        // How long the worker sleeps before finishing.
        time: Duration,
        // Probability that the worker loops forever instead of returning.
        hanging_probability: f64,
    }

    impl AsyncWorker<TestTask, ()> for TestWorker {
        async fn call(&self, input: TestTask) {
            tokio::time::sleep(input.time).await;

            let should_hang = rand::thread_rng().gen_bool(input.hanging_probability);
            if should_hang {
                loop {
                    tokio::task::yield_now().await;
                }
            }
        }
    }

    // Smoke test: submits more tasks than workers and compares wall time
    // against a fully parallel baseline (printed, not asserted).
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_async_engine() {
        let num_workers = 5;
        let task_queue_length = 5;
        let num_tasks_spawned = 10;
        let wait_duration = Duration::from_millis(10);

        let workers = (0..num_workers).map(|_| TestWorker).collect();
        let engine = Arc::new(AsyncEngine::new(workers, task_queue_length));

        let tasks = (0..num_tasks_spawned)
            .map(|_| TestTask { time: wait_duration, hanging_probability: 0.0 })
            .collect::<Vec<_>>();

        // Run all tasks through the engine and time the batch.
        let mut join_set = JoinSet::new();
        let time = tokio::time::Instant::now();
        for task in tasks {
            let e = engine.clone();
            join_set.spawn(async move { e.submit(task).await.unwrap().await.unwrap() });
        }
        join_set.join_all().await;
        let duration = time.elapsed();
        println!("Time taken for async engine: {:?}", duration);

        // Baseline: the same sleeps with unconstrained parallelism.
        let mut join_set = JoinSet::new();
        let tasks_per_worker = num_tasks_spawned / num_workers;
        let time = tokio::time::Instant::now();
        for _ in 0..num_workers {
            join_set.spawn(async move {
                for _ in 0..tasks_per_worker {
                    tokio::time::sleep(wait_duration).await;
                }
            });
        }
        join_set.join_all().await;
        let duration = time.elapsed();
        println!("Time taken for complete parallelism: {:?}", duration);
    }

    // Exercises timeouts against tasks that may hang forever; counts how many
    // complete within the timeout (printed, not asserted).
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_hanging_task_async_engine() {
        let num_workers = 1;
        let task_queue_length = 2;
        let num_tasks_spawned = 100;
        let wait_duration = Duration::from_millis(1);
        let hanging_probability = 0.5;
        let timeout = Duration::from_millis(100);

        let workers = (0..num_workers).map(|_| TestWorker).collect();
        let engine = Arc::new(AsyncEngine::new(workers, task_queue_length));

        let tasks = (0..num_tasks_spawned)
            .map(|_| TestTask { time: wait_duration, hanging_probability })
            .collect::<Vec<_>>();

        // Each submitted task is raced against a timeout.
        let mut join_set = JoinSet::new();
        let time = tokio::time::Instant::now();
        for task in tasks {
            let handle = engine.submit(task).await.unwrap();
            let future = async move { handle.await.unwrap() };
            join_set.spawn(async move { tokio::time::timeout(timeout, future).await });
        }

        let mut success_count = 0;
        while let Some(result) = join_set.join_next().await {
            let result = result.unwrap();
            if result.is_ok() {
                success_count += 1;
            }
        }
        let duration = time.elapsed();
        println!("Time taken for async engine: {:?}, success count: {success_count}", duration);
    }

    // Verifies BlockingEngine returns correct results for CPU-style work.
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_blocking_engine() {
        #[derive(Debug, Clone)]
        struct SummingWorker;

        #[derive(Debug, Clone)]
        struct SummingTask {
            summands: Vec<u32>,
        }

        impl BlockingWorker<SummingTask, u32> for SummingWorker {
            fn call(&self, input: SummingTask) -> u32 {
                input.summands.iter().sum()
            }
        }

        let num_workers = 10;
        let task_queue_length = 20;
        let num_tasks_spawned = 10;
        let max_summands = 20;

        let workers = (0..num_workers).map(|_| SummingWorker).collect();
        let engine = Arc::new(BlockingEngine::new(workers, task_queue_length));

        let mut rng = rand::thread_rng();
        let tasks = (0..num_tasks_spawned)
            .map(|_| SummingTask { summands: vec![1; rng.gen_range(1..=max_summands)] })
            .collect::<Vec<_>>();

        // FuturesOrdered keeps results aligned with their submitted tasks.
        let mut results = FuturesOrdered::new();
        for task in tasks.iter() {
            results.push_back(engine.submit(task.clone()).await.unwrap());
        }
        let results = results.collect::<Vec<_>>().await;
        for (task, result) in tasks.iter().zip(results) {
            let result = result.unwrap();
            let expected = task.summands.iter().sum();
            assert_eq!(result, expected);
        }
    }

    // A worker panic inside a task must propagate to the awaiting caller
    // (hence #[should_panic] on the unwrap in the spawned future).
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    #[should_panic]
    async fn test_async_failing_engine() {
        #[derive(Debug, Clone)]
        struct FailingWorker;

        #[derive(Debug, Clone)]
        struct TestTask {
            time: Duration,
        }

        impl AsyncWorker<TestTask, ()> for FailingWorker {
            async fn call(&self, input: TestTask) {
                if input.time > Duration::from_millis(50) {
                    panic!("not interested to wait for this long");
                }
                tokio::time::sleep(input.time).await;
            }
        }
        let num_workers = 10;
        let task_queue_length = 20;
        let wait_duration = 100;

        let workers = (0..num_workers).map(|_| FailingWorker).collect();
        let engine = Arc::new(AsyncEngine::new(workers, task_queue_length));

        // Durations 0..100ms: those above 50ms trigger the worker panic.
        let tasks = (0..wait_duration)
            .map(|i| TestTask { time: Duration::from_millis(i) })
            .collect::<Vec<_>>();

        let mut join_set = JoinSet::new();
        let time = tokio::time::Instant::now();
        for task in tasks {
            let e = engine.clone();
            join_set.spawn(async move { e.submit(task).await.unwrap().await.unwrap() });
        }
        join_set.join_all().await;
        let duration = time.elapsed();
        println!("Time taken for async engine: {:?}", duration);
    }

    // End-to-end test of Chain: a blocking first stage feeding an async
    // second stage via the Into conversion.
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_chained_pipelines() {
        #[derive(Debug, Clone)]
        struct FirstTask;

        #[derive(Debug, Clone)]
        struct FirstWorker;

        impl BlockingWorker<FirstTask, SecondTask> for FirstWorker {
            fn call(&self, _input: FirstTask) -> SecondTask {
                let mut rng = rand::thread_rng();
                SecondTask { value: rng.gen_range(200..=1000) }
            }
        }

        #[derive(Debug, Clone)]
        struct SecondWorker;

        #[derive(Debug, Clone)]
        struct SecondTask {
            value: u64,
        }

        impl AsyncWorker<SecondTask, u64> for SecondWorker {
            async fn call(&self, input: SecondTask) -> u64 {
                tokio::time::sleep(Duration::from_millis(input.value)).await;
                input.value
            }
        }

        let first_workers = (0..10).map(|_| FirstWorker).collect();
        let first_pipeline = Arc::new(BlockingEngine::single_permit_per_worker(first_workers));
        let second_workers = (0..10).map(|_| SecondWorker).collect();
        let second_pipeline = Arc::new(AsyncEngine::single_permit_per_worker(second_workers));
        let chain = Chain::new(first_pipeline, second_pipeline);

        let handles = (0..10)
            .map(|_| chain.submit(FirstTask))
            .collect::<FuturesOrdered<_>>()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        for handle in handles {
            let _result = handle.await.unwrap();
        }
    }

    // Compares a 5-stage chain of 100ms sleeps to a single 500ms sleep and
    // checks the task value round-trips through the PipelineBuilder chain.
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_timing_chained_pipelines() {
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        struct SleepTask {
            duration: Duration,
        }

        #[derive(Debug, Clone)]
        struct SleepWorker;

        impl AsyncWorker<SleepTask, SleepTask> for SleepWorker {
            async fn call(&self, input: SleepTask) -> SleepTask {
                let sleep_duration = input.duration;
                tokio::time::sleep(sleep_duration).await;
                input
            }
        }

        let num_workers = 10;

        let workers = (0..num_workers).map(|_| SleepWorker).collect::<Vec<_>>();
        let make_engine =
            |workers: Vec<SleepWorker>| Arc::new(AsyncEngine::single_permit_per_worker(workers));

        let pipeline = PipelineBuilder::new(make_engine(workers.clone()))
            .through(make_engine(workers.clone()))
            .through(make_engine(workers.clone()))
            .through(make_engine(workers.clone()))
            .through(make_engine(workers.clone()))
            .build();

        let chain_input_task = SleepTask { duration: Duration::from_millis(100) };
        let single_input_task = SleepTask { duration: Duration::from_millis(500) };

        let time = tokio::time::Instant::now();
        let chain_result = pipeline.submit(chain_input_task).await.unwrap().await.unwrap();
        let chain_duration = time.elapsed();
        println!("Chain duration: {:?}", chain_duration);
        assert_eq!(chain_result, chain_input_task);

        let single_engine = make_engine(workers.clone());
        let time = tokio::time::Instant::now();
        let single_result = single_engine.submit(single_input_task).await.unwrap().await.unwrap();
        let single_duration = time.elapsed();
        println!("Single duration: {:?}", single_duration);
        assert_eq!(single_result, single_input_task);
    }
}