1use std::future::Future;
2use std::sync::Arc;
3
4use thiserror::Error;
5use tokio::sync::{AcquireError, Semaphore};
6use tokio::task::{JoinError, JoinSet};
7
/// Error type returned by the `run_constrained*` helpers.
///
/// Covers the three ways a constrained batch of futures can fail — the
/// spawned task itself (panic/cancellation), acquiring a semaphore permit,
/// or the user-supplied future returning its own error `E` — plus an
/// `Infallible` variant for internal invariant violations.
///
/// NOTE(review): the `transparent` attributes rely on thiserror inferring
/// the `Display`/`Error` bounds it needs on `E` at the derive site — the
/// declared bounds here are only `Send + Sync + 'static`.
#[derive(Debug, Error)]
pub enum ParutilsError<E: Send + Sync + 'static> {
    /// A spawned task panicked or was cancelled before completing.
    #[error(transparent)]
    Join(#[from] JoinError),
    /// The concurrency-limiting semaphore was closed while a task was
    /// waiting for a permit.
    #[error(transparent)]
    Acquire(#[from] AcquireError),
    /// The user-provided future completed with an error of type `E`.
    #[error(transparent)]
    Task(E),
    /// Internal invariant violation while assembling results; should not be
    /// reachable in practice.
    #[error("Infallible, this should not be possible: {0}")]
    Infallible(String),
}
19
20pub async fn run_constrained_with_semaphore<Fut, T, E>(
75 futures_it: impl Iterator<Item = Fut>,
76 max_concurrent: Arc<Semaphore>,
77) -> Result<Vec<T>, ParutilsError<E>>
78where
79 Fut: Future<Output = Result<T, E>> + Send + 'static,
80 T: Send + Sync + 'static,
81 E: Send + Sync + 'static,
82{
83 let handle = tokio::runtime::Handle::current();
84 let mut js: JoinSet<Result<(usize, T), ParutilsError<E>>> = JoinSet::new();
85 for (i, fut) in futures_it.enumerate() {
86 let semaphore = max_concurrent.clone();
87 js.spawn_on(
88 async move {
89 let _permit = semaphore.acquire().await?;
90 let res = fut.await.map_err(ParutilsError::Task)?;
91 Ok((i, res))
92 },
93 &handle,
94 );
95 }
96
97 let mut results: Vec<Option<T>> = Vec::with_capacity(js.len());
98 (0..js.len()).for_each(|_| results.push(None));
99 while let Some(result) = js.join_next().await {
100 let (i, res) = result??;
101 debug_assert!(results[i].is_none());
102 results[i] = Some(res);
103 }
104 debug_assert!(js.is_empty());
105 debug_assert!(results.iter().all(|r| r.is_some()));
106
107 let Some(result) = results.into_iter().collect() else {
111 return Err(ParutilsError::Infallible("A task was unaccounted for when collecting result".to_string()));
112 };
113 Ok(result)
114}
115
116pub async fn run_constrained<Fut, T, E>(
119 futures_it: impl Iterator<Item = Fut>,
120 max_concurrent: usize,
121) -> Result<Vec<T>, ParutilsError<E>>
122where
123 Fut: Future<Output = Result<T, E>> + Send + 'static,
124 T: Send + Sync + 'static,
125 E: Send + Sync + 'static,
126{
127 let semaphore = Arc::new(Semaphore::new(max_concurrent));
128 run_constrained_with_semaphore(futures_it, semaphore).await
129}
130
#[cfg(test)]
mod parallel_tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    /// Sanity check: 400 trivial futures under a small limit come back in
    /// input order with the expected values.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_simple_parallel() {
        let data: Vec<String> = (0..400).map(|i| format!("Number = {}", &i)).collect();

        // Expected output: each input string with its index appended.
        let data_ref: Vec<String> = data.iter().enumerate().map(|(i, s)| format!("{}{}{}", &s, ":", &i)).collect();

        let r = run_constrained(
            data.into_iter()
                .enumerate()
                .map(|(i, s)| async move { Result::<_, ()>::Ok(format!("{}{}{}", &s, ":", &i)) }),
            4,
        )
        .await
        .unwrap();

        assert_eq!(data_ref.len(), r.len());
        for i in 0..data_ref.len() {
            assert_eq!(data_ref[i], r[i]);
        }
    }

    /// Same as `test_simple_parallel`, but each future sleeps for a
    /// DECREASING duration (earlier tasks sleep longest), so tasks complete
    /// roughly in reverse input order — exercising the index-based
    /// reordering back into input order.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_parallel_with_sleeps() {
        let data: Vec<String> = (0..400).map(|i| format!("Number = {}", &i)).collect();

        let data_ref: Vec<String> = data.iter().enumerate().map(|(i, s)| format!("{}{}{}", &s, ":", &i)).collect();

        let r = run_constrained(
            data.into_iter().enumerate().map(|(i, s)| async move {
                // i ranges 0..400, so sleep durations are 401ms down to 2ms.
                tokio::time::sleep(std::time::Duration::from_millis(401 - i as u64)).await;
                Result::<_, ()>::Ok(format!("{}{}{}", &s, ":", &i))
            }),
            100,
        )
        .await
        .unwrap();

        assert_eq!(data_ref.len(), r.len());
        for i in 0..data_ref.len() {
            assert_eq!(data_ref[i], r[i]);
        }
    }

    /// Verifies the semaphore actually bounds concurrency: tracks the
    /// high-water mark of simultaneously running tasks via atomics and
    /// asserts it never exceeds (and, with this many tasks, exactly
    /// reaches) MAX_CONCURRENT.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_max_concurrent_constraint() {
        const NUM_TASKS: u64 = 100;
        const TASK_DURATION_BASE_MS: u64 = 100;
        const MAX_CONCURRENT: usize = 5;

        // Number of tasks currently between "started" and "finished".
        let current_running = Arc::new(AtomicU32::new(0));
        // High-water mark of `current_running`.
        let max_concurrent_observed = Arc::new(AtomicU32::new(0));

        let futures = (0..NUM_TASKS).map(|i| {
            let current_running = current_running.clone();
            let max_concurrent_observed = max_concurrent_observed.clone();

            async move {
                // fetch_add returns the PREVIOUS value, so +1 gives the
                // count including this task.
                let running = current_running.fetch_add(1, Ordering::SeqCst) + 1;

                max_concurrent_observed.fetch_max(running, Ordering::SeqCst);

                // Decreasing sleep (100ms down to 1ms) staggers completions.
                tokio::time::sleep(std::time::Duration::from_millis(TASK_DURATION_BASE_MS - i)).await;

                current_running.fetch_sub(1, Ordering::SeqCst);

                Result::<_, ()>::Ok(i)
            }
        });

        let results = run_constrained(futures, MAX_CONCURRENT).await.unwrap();

        // Results preserve input order.
        assert_eq!(results.len(), NUM_TASKS as usize);
        for i in 0..NUM_TASKS {
            assert_eq!(results[i as usize], i);
        }

        let max_observed = max_concurrent_observed.load(Ordering::SeqCst);
        // The limit must never be exceeded...
        assert!(
            max_observed <= MAX_CONCURRENT as u32,
            "Max concurrent tasks observed: {}, but limit was: {}",
            max_observed,
            MAX_CONCURRENT
        );

        // ...and with 100 tasks it should be fully saturated at some point.
        assert_eq!(
            max_observed, MAX_CONCURRENT as u32,
            "Expected to see exactly {} concurrent tasks, but saw {}",
            MAX_CONCURRENT, max_observed
        );

        // Everything that started must also have finished.
        assert_eq!(current_running.load(Ordering::SeqCst), 0);
    }

    /// A single failing future surfaces as `ParutilsError::Task` carrying
    /// the task's own error value.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_returns_error() {
        let futures = (0..10).map(|i| async move {
            if i == 5 {
                Result::<_, i32>::Err(5)
            } else {
                Result::<_, i32>::Ok(i)
            }
        });

        let result = run_constrained(futures, 2).await;
        assert!(matches!(result, Err(ParutilsError::Task(5))));
    }

    /// A panicking future surfaces as `ParutilsError::Join` with
    /// `is_panic() == true` rather than propagating the panic.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_returns_join_error_on_panic() {
        let futures = (0..10).map(|i| async move { if i == 5 { panic!("5") } else { Result::<_, i32>::Ok(i) } });

        let result = run_constrained(futures, 2).await;
        if let Err(ParutilsError::Join(e)) = result {
            assert!(e.is_panic());
        } else {
            assert!(false, "Expected to panic, but got {:?}", result);
        }
    }
}