speedracer/
lib.rs

//! A crate for racing `Future`s and getting ranked results back.
//!
//! Racers are ranked in the order they finish. A racer that times out or
//! returns an error is disqualified, but it still appears in the rankings at
//! the position where it finished.
//!
//! # Example
//!
//! ```ignore
//! use tokio::time::sleep;
//! use std::time::Duration;
//!
//! use speedracer::RaceTrack;
//!
//! // `run` spawns every racer onto the Tokio runtime, so this must execute
//! // inside one (for example in a `#[tokio::main]` function).
//! let mut race_track = RaceTrack::disqualify_after(Duration::from_millis(500));
//!
//! race_track.add_racer("Racer #1", async move {
//!     println!("Racer #1 is starting");
//!     sleep(Duration::from_millis(100)).await;
//!     println!("Racer #1 is ending");
//!
//!     Ok(())
//! });
//! race_track.add_racer("Racer #2", async move {
//!     println!("Racer #2 is starting");
//!     sleep(Duration::from_millis(200)).await;
//!     println!("Racer #2 is ending");
//!
//!     Ok(())
//! });
//! race_track.add_racer("Racer #3", async move {
//!     println!("Racer #3 is starting");
//!     // Longer than the 500 ms timeout, so Racer #3 is disqualified.
//!     sleep(Duration::from_millis(700)).await;
//!     println!("Racer #3 is ending");
//!
//!     Ok(())
//! });
//!
//! race_track.run().await;
//! let rankings = race_track.rankings();
//!
//! println!("Rankings: {:?}", rankings);
//!
//! assert_eq!(rankings[0].name, "Racer #1");
//! assert_eq!(rankings[1].name, "Racer #2");
//! assert_eq!(rankings[2].name, "Racer #3");
//! assert!(rankings[2].disqualified);
//! ```
46
47use std::{collections::BTreeMap, pin::Pin, sync::Arc, time::Duration};
48
49use eyre::Report;
50use futures::{stream::FuturesUnordered, Future, StreamExt};
51
52/// A wrapper around a `Future`.
53struct Racer<T> {
54    name: String,
55    fut: Pin<Box<dyn Future<Output = Result<T, Report>> + Send>>,
56}
57
58/// The rank and disqualification status of an executed Racer.
59#[derive(Debug, Clone)]
60pub struct RaceResult<T> {
61    pub name: String,
62    pub duration: Duration,
63    pub disqualified: bool,
64    pub error: Option<Arc<Report>>,
65    pub value: Option<T>,
66}
67
68/// Race a set of `Future`s and rank them.
69pub struct RaceTrack<T> {
70    timeout: Duration,
71    racers: Vec<Racer<T>>,
72    rankings: BTreeMap<usize, RaceResult<T>>,
73}
74
75impl<T> Default for RaceTrack<T> {
76    fn default() -> Self {
77        Self {
78            timeout: Duration::from_secs(5),
79            rankings: BTreeMap::new(),
80            racers: Vec::new(),
81        }
82    }
83}
84
85impl<T> RaceTrack<T>
86where
87    T: std::fmt::Debug + Clone + Send + 'static,
88{
89    /// Create a new `RaceTrack` with specified timeout.
90    pub fn disqualify_after(timeout: Duration) -> Self {
91        Self {
92            timeout,
93            ..Default::default()
94        }
95    }
96
97    /// Add a `Future` to the `RaceTrack`.
98    pub fn add_racer<F>(&mut self, name: impl Into<String>, fut: F)
99    where
100        F: Future<Output = Result<T, Report>> + Send + 'static,
101    {
102        self.racers.push(Racer {
103            name: name.into(),
104            fut: Box::pin(fut),
105        });
106    }
107
108    /// Run the `RaceTrack` and collect the rankings.
109    pub async fn run(&mut self) {
110        let racers = std::mem::take(&mut self.racers);
111
112        // Clear the rankings from the previous run.
113        self.rankings.clear();
114
115        // Run the racers.
116        let mut tasks = FuturesUnordered::new();
117        for racer in racers {
118            let name = racer.name.clone();
119            let timeout = self.timeout;
120            tasks.push(tokio::spawn(async move {
121                // Start the racer and time it.
122                let start = std::time::Instant::now();
123                let res = tokio::time::timeout(timeout, racer.fut).await;
124                let duration = start.elapsed();
125
126                // Disqualify the racer if it timed out.
127                let mut disqualified = res.is_err();
128
129                // Do some magic on the timeout error and then split the result!
130                let result = res.unwrap_or_else(|_| Err(eyre::eyre!("Racer timed out")));
131                let (value, error) = match result {
132                    Ok(value) => (Some(value), None),
133                    Err(error) => {
134                        // Disqualify the racer if it errored.
135                        disqualified = true;
136                        (None, Some(error))
137                    },
138                };
139
140                RaceResult {
141                    name,
142                    duration,
143                    disqualified,
144                    error: error.map(Arc::new),
145                    value,
146                }
147            }));
148        }
149
150        // RaceResult em up!
151        let mut i = 0;
152        while let Some(result) = tasks.next().await {
153            self.rankings.insert(i, result.unwrap());
154            i += 1;
155        }
156    }
157
158    /// Get the rankings for the previous `RaceTrack` run.
159    pub fn rankings(&self) -> Vec<RaceResult<T>> {
160        self.rankings.values().cloned().collect()
161    }
162}
163
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use eyre::eyre;
    use tokio::time::sleep;

    use super::*;

    /// Races three racers that finish in time, one that overruns the timeout,
    /// and one that fails immediately. Then checks the finishing order,
    /// values, and disqualifications.
    ///
    /// Sleeps are spaced at least 25 ms apart and well away from the timeout.
    /// The previous 5 ms margins made the expected order depend on scheduler
    /// jitter.
    #[tokio::test]
    async fn on_your_mark_get_set_go() {
        let mut race_track = RaceTrack::disqualify_after(Duration::from_millis(100));

        race_track.add_racer("Racer #1", async move {
            sleep(Duration::from_millis(10)).await;
            Ok(1)
        });
        race_track.add_racer("Racer #2", async move {
            sleep(Duration::from_millis(35)).await;
            Ok(2)
        });
        race_track.add_racer("Racer #3", async move {
            sleep(Duration::from_millis(60)).await;
            Ok(3)
        });
        // Overruns the 100 ms timeout, so it is cut off and disqualified.
        race_track.add_racer("Racer #4", async move {
            sleep(Duration::from_millis(1_000)).await;
            Ok(4)
        });
        // Fails immediately: it finishes first but is disqualified.
        race_track.add_racer("Racer #5", async move { Err(eyre!("Racer #5 failed!")) });

        race_track.run().await;
        let rankings = race_track.rankings();

        println!("{:#?}", rankings);
        assert_eq!(rankings.len(), 5);

        assert_eq!(rankings[0].name, "Racer #5");
        assert!(rankings[0].disqualified);
        assert_eq!(
            rankings[0].error.as_ref().unwrap().to_string(),
            "Racer #5 failed!"
        );
        assert_eq!(rankings[0].value, None);

        assert_eq!(rankings[1].name, "Racer #1");
        assert!(!rankings[1].disqualified);
        assert_eq!(rankings[1].value, Some(1));

        assert_eq!(rankings[2].name, "Racer #2");
        assert!(!rankings[2].disqualified);
        assert_eq!(rankings[2].value, Some(2));

        assert_eq!(rankings[3].name, "Racer #3");
        assert!(!rankings[3].disqualified);
        assert_eq!(rankings[3].value, Some(3));

        assert_eq!(rankings[4].name, "Racer #4");
        assert!(rankings[4].disqualified);
        assert_eq!(
            rankings[4].error.as_ref().unwrap().to_string(),
            "Racer timed out"
        );
        assert_eq!(rankings[4].value, None);
    }
}