// simple_agents_workflow/scheduler.rs
use std::future::Future;

use futures::stream::{FuturesUnordered, StreamExt};
5#[derive(Debug, Clone, Copy, PartialEq, Eq)]
7pub struct DagScheduler {
8 max_in_flight: usize,
9}
10
11impl DagScheduler {
12 pub fn new(max_in_flight: usize) -> Self {
14 Self {
15 max_in_flight: max_in_flight.max(1),
16 }
17 }
18
19 pub fn max_in_flight(self) -> usize {
21 self.max_in_flight
22 }
23
24 pub async fn run_bounded<I, F, Fut, T, E>(
26 &self,
27 inputs: I,
28 mut task_builder: F,
29 ) -> Result<Vec<T>, E>
30 where
31 I: IntoIterator,
32 F: FnMut(I::Item) -> Fut,
33 Fut: Future<Output = Result<T, E>>,
34 {
35 let mut indexed_inputs = inputs.into_iter().enumerate();
36 let mut in_flight = FuturesUnordered::new();
37 let mut results: Vec<Option<T>> = Vec::new();
38
39 loop {
40 while in_flight.len() < self.max_in_flight {
41 let Some((index, item)) = indexed_inputs.next() else {
42 break;
43 };
44 if results.len() <= index {
45 results.resize_with(index + 1, || None);
46 }
47 let task = task_builder(item);
48 in_flight.push(async move { (index, task.await) });
49 }
50
51 let Some((index, output)) = in_flight.next().await else {
52 break;
53 };
54
55 match output {
56 Ok(value) => {
57 results[index] = Some(value);
58 }
59 Err(error) => return Err(error),
60 }
61 }
62
63 Ok(results
64 .into_iter()
65 .map(|entry| entry.expect("scheduler result slot must be filled"))
66 .collect())
67 }
68}
69
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    use tokio::time::sleep;

    use super::DagScheduler;

    /// Runs `item_count` short sleeping tasks under `limit`, checks that the
    /// doubled outputs come back in input order, and returns the highest
    /// number of tasks observed running at the same time.
    ///
    /// Concurrency is measured with counters instead of wall-clock time, so
    /// assertions built on this helper are deterministic.
    async fn observed_peak(limit: usize, item_count: usize) -> usize {
        let active = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let outputs = DagScheduler::new(limit)
            .run_bounded(0..item_count, |item| {
                let active = Arc::clone(&active);
                let peak = Arc::clone(&peak);
                async move {
                    let now_active = active.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now_active, Ordering::SeqCst);
                    // Park the task so the scheduler can start its siblings meanwhile.
                    sleep(Duration::from_millis(5)).await;
                    active.fetch_sub(1, Ordering::SeqCst);
                    Ok::<usize, ()>(item * 2)
                }
            })
            .await
            .expect("bounded scheduling should succeed");

        let expected: Vec<usize> = (0..item_count).map(|item| item * 2).collect();
        assert_eq!(outputs, expected);
        peak.load(Ordering::SeqCst)
    }

    #[test]
    fn zero_limit_is_clamped_to_one() {
        assert_eq!(DagScheduler::new(0).max_in_flight(), 1);
        assert_eq!(DagScheduler::new(3).max_in_flight(), 3);
    }

    #[tokio::test]
    async fn respects_max_in_flight_limit() {
        // Exactly 2: the bound is never exceeded, and it is actually reached.
        assert_eq!(observed_peak(2, 8).await, 2);
    }

    #[tokio::test]
    async fn runs_concurrently_when_limit_above_one() {
        assert_eq!(observed_peak(1, 4).await, 1);
        assert_eq!(observed_peak(4, 4).await, 4);
    }

    #[tokio::test]
    async fn preserves_input_order_when_tasks_finish_out_of_order() {
        let outputs = DagScheduler::new(4)
            .run_bounded(0..4u64, |item| async move {
                // Later items sleep less, so they tend to complete first.
                sleep(Duration::from_millis(5 * (4 - item))).await;
                Ok::<u64, ()>(item)
            })
            .await
            .expect("scheduler should run all tasks");

        assert_eq!(outputs, vec![0, 1, 2, 3]);
    }

    #[tokio::test]
    async fn empty_input_yields_empty_output() {
        let outputs = DagScheduler::new(3)
            .run_bounded(std::iter::empty::<usize>(), |item| async move {
                Ok::<usize, ()>(item)
            })
            .await;

        assert_eq!(outputs, Ok(Vec::new()));
    }

    #[tokio::test]
    async fn returns_first_error_and_stops_pulling_inputs() {
        let built = AtomicUsize::new(0);

        let result = DagScheduler::new(1)
            .run_bounded(0..10usize, |item| {
                built.fetch_add(1, Ordering::SeqCst);
                async move {
                    if item == 2 {
                        Err(format!("task {item} failed"))
                    } else {
                        Ok(item)
                    }
                }
            })
            .await;

        assert_eq!(result, Err("task 2 failed".to_string()));
        // With a limit of 1, items 0, 1 and 2 were built; nothing after the failure.
        assert_eq!(built.load(Ordering::SeqCst), 3);
    }
}