1use std::{collections::BTreeMap, pin::Pin, sync::Arc, time::Duration};
48
49use eyre::Report;
50use futures::{stream::FuturesUnordered, Future, StreamExt};
51
52struct Racer<T> {
54 name: String,
55 fut: Pin<Box<dyn Future<Output = Result<T, Report>> + Send>>,
56}
57
58#[derive(Debug, Clone)]
60pub struct RaceResult<T> {
61 pub name: String,
62 pub duration: Duration,
63 pub disqualified: bool,
64 pub error: Option<Arc<Report>>,
65 pub value: Option<T>,
66}
67
68pub struct RaceTrack<T> {
70 timeout: Duration,
71 racers: Vec<Racer<T>>,
72 rankings: BTreeMap<usize, RaceResult<T>>,
73}
74
75impl<T> Default for RaceTrack<T> {
76 fn default() -> Self {
77 Self {
78 timeout: Duration::from_secs(5),
79 rankings: BTreeMap::new(),
80 racers: Vec::new(),
81 }
82 }
83}
84
85impl<T> RaceTrack<T>
86where
87 T: std::fmt::Debug + Clone + Send + 'static,
88{
89 pub fn disqualify_after(timeout: Duration) -> Self {
91 Self {
92 timeout,
93 ..Default::default()
94 }
95 }
96
97 pub fn add_racer<F>(&mut self, name: impl Into<String>, fut: F)
99 where
100 F: Future<Output = Result<T, Report>> + Send + 'static,
101 {
102 self.racers.push(Racer {
103 name: name.into(),
104 fut: Box::pin(fut),
105 });
106 }
107
108 pub async fn run(&mut self) {
110 let racers = std::mem::take(&mut self.racers);
111
112 self.rankings.clear();
114
115 let mut tasks = FuturesUnordered::new();
117 for racer in racers {
118 let name = racer.name.clone();
119 let timeout = self.timeout;
120 tasks.push(tokio::spawn(async move {
121 let start = std::time::Instant::now();
123 let res = tokio::time::timeout(timeout, racer.fut).await;
124 let duration = start.elapsed();
125
126 let mut disqualified = res.is_err();
128
129 let result = res.unwrap_or_else(|_| Err(eyre::eyre!("Racer timed out")));
131 let (value, error) = match result {
132 Ok(value) => (Some(value), None),
133 Err(error) => {
134 disqualified = true;
136 (None, Some(error))
137 },
138 };
139
140 RaceResult {
141 name,
142 duration,
143 disqualified,
144 error: error.map(Arc::new),
145 value,
146 }
147 }));
148 }
149
150 let mut i = 0;
152 while let Some(result) = tasks.next().await {
153 self.rankings.insert(i, result.unwrap());
154 i += 1;
155 }
156 }
157
158 pub fn rankings(&self) -> Vec<RaceResult<T>> {
160 self.rankings.values().cloned().collect()
161 }
162}
163
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::sleep;
    use eyre::eyre;

    use super::*;

    #[tokio::test]
    async fn on_your_mark_get_set_go() {
        // Finish times are spaced tens of milliseconds apart and well away from
        // the timeout, so scheduler jitter on a busy machine cannot reorder the
        // racers or push a legitimate finisher past the deadline.
        let mut race_track = RaceTrack::disqualify_after(Duration::from_millis(150));

        race_track.add_racer("Racer #1", async move {
            sleep(Duration::from_millis(10)).await;
            Ok(1)
        });
        race_track.add_racer("Racer #2", async move {
            sleep(Duration::from_millis(40)).await;
            Ok(2)
        });
        race_track.add_racer("Racer #3", async move {
            sleep(Duration::from_millis(70)).await;
            Ok(3)
        });
        race_track.add_racer("Racer #4", async move {
            // Far beyond the timeout; the future is dropped once it is disqualified.
            sleep(Duration::from_millis(1000)).await;
            Ok(4)
        });
        race_track.add_racer("Racer #5", async move {
            Err(eyre!("Racer #5 failed!"))
        });

        race_track.run().await;
        let rankings = race_track.rankings();

        println!("{:#?}", rankings);

        assert_eq!(rankings.len(), 5);

        assert_eq!(rankings[0].name, "Racer #5");
        assert!(rankings[0].disqualified);
        assert_eq!(
            rankings[0].error.as_ref().unwrap().to_string(),
            "Racer #5 failed!"
        );
        assert_eq!(rankings[0].value, None);

        assert_eq!(rankings[1].name, "Racer #1");
        assert!(!rankings[1].disqualified);
        assert_eq!(rankings[1].value, Some(1));

        assert_eq!(rankings[2].name, "Racer #2");
        assert!(!rankings[2].disqualified);
        assert_eq!(rankings[2].value, Some(2));

        assert_eq!(rankings[3].name, "Racer #3");
        assert!(!rankings[3].disqualified);
        assert_eq!(rankings[3].value, Some(3));

        assert_eq!(rankings[4].name, "Racer #4");
        assert!(rankings[4].disqualified);
        assert_eq!(
            rankings[4].error.as_ref().unwrap().to_string(),
            "Racer timed out"
        );
        assert_eq!(rankings[4].value, None);
    }
}