// xet_runtime/core/par_utils.rs
1use std::future::Future;
2use std::sync::Arc;
3
4use thiserror::Error;
5use tokio::sync::{AcquireError, Semaphore};
6use tokio::task::{JoinError, JoinSet};
7
/// Error type for the constrained-concurrency helpers in this module.
///
/// Aggregates every failure mode of [`run_constrained_with_semaphore`] /
/// [`run_constrained`]: runtime join failures, semaphore acquisition
/// failures, and the task's own error type `E`.
#[derive(Debug, Error)]
pub enum ParutilsError<E: Send + Sync + 'static> {
    /// A spawned task could not be joined (it panicked or was cancelled).
    #[error(transparent)]
    Join(#[from] JoinError),
    /// The limiting semaphore was closed while a task waited for a permit.
    #[error(transparent)]
    Acquire(#[from] AcquireError),
    /// The task ran to completion but returned an error of type `E`.
    #[error(transparent)]
    Task(E),
    /// Internal invariant violation (a joined task was unaccounted for);
    /// should never occur in practice.
    #[error("Infallible, this should not be possible: {0}")]
    Infallible(String),
}
19
20/// Runs all the futures in the provided iterator with a maximum concurrency limit.
21///
22/// This function ensures that a permit is acquired from the provided semaphore before any work
23/// is done for a future, thus limiting concurrency based on the number of permits in the semaphore.
24///
25/// Each future in the iterator must return a `Result<T, E>`. If any future returns an error,
26/// or if there is a `JoinError` or failure to acquire a semaphore permit, the function will
27/// return an error as soon as possible.
28///
29/// If all tasks complete successfully, the function returns a `Vec<T>` containing the results
30/// of the successful futures, in the same order as they were produced by the iterator.
31///
32/// # Arguments
33///
34/// * `futures_it` - An iterator of futures, where each future resolves to a `Result<T, E>`.
35/// * `max_concurrent` - An `Arc<Semaphore>` that limits the number of concurrent tasks.
36///
37/// # Type Parameters
38///
39/// * `Fut` - The type of the futures in the iterator. Each future must output a `Result<T, E>`.
40/// * `T` - The type of the successful result produced by each future.
41/// * `E` - The type of the error produced by each future.
42///
43/// # Returns
44///
45/// A `Result` containing:
46/// * `Ok(Vec<T>)` - A vector of successful results if all tasks complete successfully.
47/// * `Err(ParutilsError<E>)` - An error if any task fails, a semaphore permit cannot be acquired, or a `JoinError`
48///   occurs.
49///
50/// # Errors
51///
52/// This function returns a `ParutilsError<E>` in the following cases:
53/// * A task returns an error of type `E`.
54/// * A semaphore permit cannot be acquired.
55/// * A `JoinError` occurs while waiting for a task to complete.
56///
57/// # Example
58///
59/// ```rust
60/// use std::sync::Arc;
61///
62/// use tokio::sync::Semaphore;
63/// use xet_runtime::core::par_utils::run_constrained_with_semaphore;
64///
65/// #[tokio::main]
66/// async fn main() {
67///     let semaphore = Arc::new(Semaphore::new(2)); // Limit concurrency to 2 tasks.
68///     let futures = (1..=3).map(|n| async move { Ok::<_, ()>(n) });
69///
70///     let results = run_constrained_with_semaphore(futures.into_iter(), semaphore).await;
71///     assert_eq!(results.unwrap(), vec![1, 2, 3]);
72/// }
73/// ```
74pub async fn run_constrained_with_semaphore<Fut, T, E>(
75    futures_it: impl Iterator<Item = Fut>,
76    max_concurrent: Arc<Semaphore>,
77) -> Result<Vec<T>, ParutilsError<E>>
78where
79    Fut: Future<Output = Result<T, E>> + Send + 'static,
80    T: Send + Sync + 'static,
81    E: Send + Sync + 'static,
82{
83    let handle = tokio::runtime::Handle::current();
84    let mut js: JoinSet<Result<(usize, T), ParutilsError<E>>> = JoinSet::new();
85    for (i, fut) in futures_it.enumerate() {
86        let semaphore = max_concurrent.clone();
87        js.spawn_on(
88            async move {
89                let _permit = semaphore.acquire().await?;
90                let res = fut.await.map_err(ParutilsError::Task)?;
91                Ok((i, res))
92            },
93            &handle,
94        );
95    }
96
97    let mut results: Vec<Option<T>> = Vec::with_capacity(js.len());
98    (0..js.len()).for_each(|_| results.push(None));
99    while let Some(result) = js.join_next().await {
100        let (i, res) = result??;
101        debug_assert!(results[i].is_none());
102        results[i] = Some(res);
103    }
104    debug_assert!(js.is_empty());
105    debug_assert!(results.iter().all(|r| r.is_some()));
106
107    // Convert from Vec<Option<T>> to Option<Vec<T>>
108    // Should be impossible to get back a None, that would indicate the js.join_next() should have
109    // more tasks
110    let Some(result) = results.into_iter().collect() else {
111        return Err(ParutilsError::Infallible("A task was unaccounted for when collecting result".to_string()));
112    };
113    Ok(result)
114}
115
116/// Like tokio_run_max_concurrency_fold_result_with_semaphore but callers can pass in the number
117/// of concurrent tasks they wish to allow and the semaphore is created inside this function scope
118pub async fn run_constrained<Fut, T, E>(
119    futures_it: impl Iterator<Item = Fut>,
120    max_concurrent: usize,
121) -> Result<Vec<T>, ParutilsError<E>>
122where
123    Fut: Future<Output = Result<T, E>> + Send + 'static,
124    T: Send + Sync + 'static,
125    E: Send + Sync + 'static,
126{
127    let semaphore = Arc::new(Semaphore::new(max_concurrent));
128    run_constrained_with_semaphore(futures_it, semaphore).await
129}
130
#[cfg(test)]
mod parallel_tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    /// Smoke test: results come back in input order even when tasks run in
    /// parallel with a small concurrency limit.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_simple_parallel() {
        let data: Vec<String> = (0..400).map(|i| format!("Number = {}", &i)).collect();

        // Expected output, computed sequentially for comparison.
        let data_ref: Vec<String> = data.iter().enumerate().map(|(i, s)| format!("{}{}{}", &s, ":", &i)).collect();

        let r = run_constrained(
            data.into_iter()
                .enumerate()
                .map(|(i, s)| async move { Result::<_, ()>::Ok(format!("{}{}{}", &s, ":", &i)) }),
            4,
        )
        .await
        .unwrap();

        assert_eq!(data_ref.len(), r.len());
        for i in 0..data_ref.len() {
            assert_eq!(data_ref[i], r[i]);
        }
    }

    /// Ordering test with deliberately inverted completion times: later tasks
    /// sleep *less* (401 - i ms), so they finish first — the result vector
    /// must still match input order.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_parallel_with_sleeps() {
        let data: Vec<String> = (0..400).map(|i| format!("Number = {}", &i)).collect();

        let data_ref: Vec<String> = data.iter().enumerate().map(|(i, s)| format!("{}{}{}", &s, ":", &i)).collect();

        let r = run_constrained(
            data.into_iter().enumerate().map(|(i, s)| async move {
                // i ranges over 0..400, so the sleep is 401..=2 ms — no underflow.
                tokio::time::sleep(std::time::Duration::from_millis(401 - i as u64)).await;
                Result::<_, ()>::Ok(format!("{}{}{}", &s, ":", &i))
            }),
            100,
        )
        .await
        .unwrap();

        assert_eq!(data_ref.len(), r.len());
        for i in 0..data_ref.len() {
            assert_eq!(data_ref[i], r[i]);
        }
    }

    /// Verifies the semaphore actually bounds concurrency: tracks the number
    /// of simultaneously running tasks with atomics and asserts the observed
    /// maximum equals (never exceeds) the configured limit.
    /// NOTE(review): the exact-equality assertion below is timing-dependent;
    /// with 100 tasks of >=1 ms each it should reliably saturate the limit.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_max_concurrent_constraint() {
        const NUM_TASKS: u64 = 100;
        const TASK_DURATION_BASE_MS: u64 = 100;
        const MAX_CONCURRENT: usize = 5;

        // Counters to track concurrent task execution
        let current_running = Arc::new(AtomicU32::new(0));
        let max_concurrent_observed = Arc::new(AtomicU32::new(0));

        let futures = (0..NUM_TASKS).map(|i| {
            let current_running = current_running.clone();
            let max_concurrent_observed = max_concurrent_observed.clone();

            async move {
                // Increment running counter
                let running = current_running.fetch_add(1, Ordering::SeqCst) + 1;

                // Update max observed if necessary
                max_concurrent_observed.fetch_max(running, Ordering::SeqCst);

                // Simulate work (duration shrinks from 100 ms down to 1 ms as i grows)
                tokio::time::sleep(std::time::Duration::from_millis(TASK_DURATION_BASE_MS - i)).await;

                // Decrement running counter
                current_running.fetch_sub(1, Ordering::SeqCst);

                Result::<_, ()>::Ok(i)
            }
        });

        let results = run_constrained(futures, MAX_CONCURRENT).await.unwrap();

        // Verify all tasks completed successfully
        assert_eq!(results.len(), NUM_TASKS as usize);
        for i in 0..NUM_TASKS {
            assert_eq!(results[i as usize], i);
        }

        // Verify that we never exceeded the concurrency limit
        let max_observed = max_concurrent_observed.load(Ordering::SeqCst);
        assert!(
            max_observed <= MAX_CONCURRENT as u32,
            "Max concurrent tasks observed: {}, but limit was: {}",
            max_observed,
            MAX_CONCURRENT
        );

        assert_eq!(
            max_observed, MAX_CONCURRENT as u32,
            "Expected to see exactly {} concurrent tasks, but saw {}",
            MAX_CONCURRENT, max_observed
        );

        // Ensure no tasks are still running
        assert_eq!(current_running.load(Ordering::SeqCst), 0);
    }

    /// A task returning `Err` must surface as `ParutilsError::Task`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_returns_error() {
        let futures = (0..10).map(|i| async move {
            if i == 5 {
                Result::<_, i32>::Err(5)
            } else {
                Result::<_, i32>::Ok(i)
            }
        });

        let result = run_constrained(futures, 2).await;
        assert!(matches!(result, Err(ParutilsError::Task(5))));
    }

    /// A panicking task must surface as `ParutilsError::Join` with
    /// `JoinError::is_panic() == true`, not crash the caller.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_returns_join_error_on_panic() {
        let futures = (0..10).map(|i| async move { if i == 5 { panic!("5") } else { Result::<_, i32>::Ok(i) } });

        let result = run_constrained(futures, 2).await;
        if let Err(ParutilsError::Join(e)) = result {
            assert!(e.is_panic());
        } else {
            assert!(false, "Expected to panic, but got {:?}", result);
        }
    }
}
263}