Skip to main content

simple_agents_workflow/
scheduler.rs

1use std::future::Future;
2
3use futures::stream::{FuturesUnordered, StreamExt};
4
/// Async scheduler for DAG-adjacent fan-out workloads that caps how many
/// tasks run at the same time.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DagScheduler {
    /// Upper bound on the number of tasks kept in flight at once.
    max_in_flight: usize,
}
10
11impl DagScheduler {
12    /// Creates a scheduler with a bounded number of concurrent tasks.
13    pub fn new(max_in_flight: usize) -> Self {
14        Self {
15            max_in_flight: max_in_flight.max(1),
16        }
17    }
18
19    /// Returns the configured maximum number of in-flight tasks.
20    pub fn max_in_flight(self) -> usize {
21        self.max_in_flight
22    }
23
24    /// Executes tasks with bounded concurrency and deterministic result ordering.
25    pub async fn run_bounded<I, F, Fut, T, E>(
26        &self,
27        inputs: I,
28        mut task_builder: F,
29    ) -> Result<Vec<T>, E>
30    where
31        I: IntoIterator,
32        F: FnMut(I::Item) -> Fut,
33        Fut: Future<Output = Result<T, E>>,
34    {
35        let mut indexed_inputs = inputs.into_iter().enumerate();
36        let mut in_flight = FuturesUnordered::new();
37        let mut results: Vec<Option<T>> = Vec::new();
38
39        loop {
40            while in_flight.len() < self.max_in_flight {
41                let Some((index, item)) = indexed_inputs.next() else {
42                    break;
43                };
44                if results.len() <= index {
45                    results.resize_with(index + 1, || None);
46                }
47                let task = task_builder(item);
48                in_flight.push(async move { (index, task.await) });
49            }
50
51            let Some((index, output)) = in_flight.next().await else {
52                break;
53            };
54
55            match output {
56                Ok(value) => {
57                    results[index] = Some(value);
58                }
59                Err(error) => return Err(error),
60            }
61        }
62
63        Ok(results
64            .into_iter()
65            .map(|entry| entry.expect("scheduler result slot must be filled"))
66            .collect())
67    }
68}
69
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use tokio::sync::Mutex;
    use tokio::time::sleep;

    use super::DagScheduler;

    /// Runs four 20 ms sleeping tasks under the given concurrency limit and
    /// returns how long the whole batch took.
    async fn time_four_sleeps(limit: usize) -> Duration {
        let scheduler = DagScheduler::new(limit);
        let started = Instant::now();

        let _ = scheduler
            .run_bounded(0..4usize, |_| async {
                sleep(Duration::from_millis(20)).await;
                Ok::<(), ()>(())
            })
            .await
            .expect("scheduler should run all tasks");

        started.elapsed()
    }

    /// The number of simultaneously active tasks must never exceed the limit,
    /// and outputs must come back in input order.
    #[tokio::test]
    async fn respects_max_in_flight_limit() {
        let scheduler = DagScheduler::new(2);
        let active_count = Arc::new(Mutex::new(0usize));
        let peak_count = Arc::new(Mutex::new(0usize));

        // The builder owns its own handles and gives every task fresh clones.
        let builder_active = Arc::clone(&active_count);
        let builder_peak = Arc::clone(&peak_count);
        let build_task = move |item: usize| {
            let active_count = Arc::clone(&builder_active);
            let peak_count = Arc::clone(&builder_peak);
            async move {
                {
                    // Register this task as active and record the high-water mark.
                    let mut active = active_count.lock().await;
                    *active += 1;
                    let mut peak = peak_count.lock().await;
                    *peak = (*peak).max(*active);
                }

                sleep(Duration::from_millis(10)).await;

                {
                    let mut active = active_count.lock().await;
                    *active = active.saturating_sub(1);
                }

                Ok::<usize, ()>(item * 2)
            }
        };

        let outputs = scheduler
            .run_bounded(0..8usize, build_task)
            .await
            .expect("bounded scheduling should succeed");

        assert_eq!(outputs, vec![0, 2, 4, 6, 8, 10, 12, 14]);
        assert!(*peak_count.lock().await <= 2);
    }

    /// A limit above one must finish the same batch faster than a limit of one.
    #[tokio::test]
    async fn runs_concurrently_when_limit_above_one() {
        let serial_elapsed = time_four_sleeps(1).await;
        let parallel_elapsed = time_four_sleeps(4).await;

        assert!(
            parallel_elapsed < serial_elapsed,
            "expected parallel scheduler to finish faster (parallel={parallel_elapsed:?}, serial={serial_elapsed:?})"
        );
    }
}