schedwalk/
lib.rs

1#![deny(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use std::{
5    env,
6    future::Future,
7    panic::{catch_unwind, resume_unwind},
8    pin::Pin,
9    sync::atomic::{AtomicBool, Ordering},
10    task,
11};
12
13use async_task::Task;
14use context::Context;
15
16mod context;
17mod schedule;
18
/// Name of the environment variable that, when set, selects a single encoded schedule to run
/// instead of walking every schedule.
const SCHEDULE_ENV: &str = "SCHEDULE";
20
21/// A spawned future that can be awaited.
22///
23/// This is the equivalent of Tokio's `tokio::task::JoinHandle`.
24///
25/// A `JoinHandle` detaches when the handle is dropped. The underlying task will continue to run
26/// unless [`JoinHandle::abort`] was called.
27pub struct JoinHandle<T> {
28    task: Option<Task<T>>,
29    abort: AtomicBool,
30}
31
/// An error when joining a future via a [`JoinHandle`].
///
/// Currently, as panics are not handled by schedwalk, an error can only occur if
/// [`JoinHandle::abort`] is called, but this may change in the future.
///
/// The type implements [`Debug`](std::fmt::Debug), [`Display`](std::fmt::Display) and
/// [`Error`](std::error::Error), so the result of awaiting a handle can be used with `unwrap`,
/// `expect` and `?`.
#[derive(Debug)]
pub struct JoinError();

impl JoinError {
    /// Whether this error is due to cancellation.
    ///
    /// Cancellation is currently the only way a `JoinError` is produced, so this always
    /// returns `true`.
    pub fn is_cancelled(&self) -> bool {
        true
    }
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Cancellation is the only failure mode that exists today.
        f.write_str("task was cancelled")
    }
}

impl std::error::Error for JoinError {}
44
45impl<T> JoinHandle<T> {
46    fn new(task: Task<T>) -> Self {
47        JoinHandle {
48            task: Some(task),
49            abort: AtomicBool::new(false),
50        }
51    }
52}
53
54impl<T> JoinHandle<T> {
55    /// Aborts the underlying task.
56    ///
57    /// If the task is not complete, this will cause it to complete with a [`JoinError`].
58    /// Otherwise, it will not have an effect.
59    pub fn abort(&self) {
60        self.abort.store(true, Ordering::Relaxed)
61    }
62}
63
64impl<T> Drop for JoinHandle<T> {
65    fn drop(&mut self) {
66        if let Some(task) = self.task.take() {
67            task.detach()
68        }
69    }
70}
71
72impl<T> Future for JoinHandle<T> {
73    type Output = Result<T, JoinError>;
74
75    #[inline]
76    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
77        let JoinHandle { task, abort } = &mut *self;
78
79        match task {
80            Some(task) if task.is_finished() || !*abort.get_mut() => {
81                Pin::new(task).poll(cx).map(Ok)
82            }
83            _ => {
84                task.take();
85                task::Poll::Ready(Err(JoinError()))
86            }
87        }
88    }
89}
90
91/// Spawns a new asynchronous task and returns a [`JoinHandle`] to it.
92///
93/// This must be called within a context created by [`for_all_schedules`]. Failure to do so will
94/// throw an exception.
95pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
96where
97    T: Future + Send + 'static,
98    T::Output: Send + 'static,
99{
100    JoinHandle::new(spawn_task(future))
101}
102
103fn spawn_task<T>(future: T) -> Task<T::Output>
104where
105    T: Future + Send + 'static,
106    T::Output: Send + 'static,
107{
108    let (runnable, task) = async_task::spawn(future, Context::schedule);
109    runnable.schedule();
110    task
111}
112
113/// Executes the given future multiple times, each time under a new polling schedule, eventually
114/// executing it under all possible polling schedules.
115///
116/// This can be used to deterministically test for what would otherwise be asynchronous race
117/// conditions.
118///
119/// If a panic occurs when executing a schedule, it will be written to standard error. For ease of
120/// debugging, rerunning the test with `SCHEDULE` set to this string will execute that particular
121/// failing schedule only.
122///
123/// This assumes *determinism*; the spawned futures and the order they are polled in must not depend
124/// on anything external to the function such as network or thread locals. This function will panic
125/// in case non-determinism is detected, but it cannot do so reliably in all cases.
126#[inline]
127pub fn for_all_schedules<T>(mut f: impl FnMut() -> T)
128where
129    T: Future<Output = ()> + 'static + Send,
130{
131    fn walk(spawn: &mut dyn FnMut() -> Task<()>) {
132        match env::var(SCHEDULE_ENV) {
133            Ok(schedule) => walk_schedule(&schedule, spawn),
134            Err(env::VarError::NotPresent) => walk_exhaustive(&mut Vec::new(), spawn),
135            Err(env::VarError::NotUnicode(_)) => {
136                panic!(
137                    "found a schedule in {}, but it was not valid unicode",
138                    SCHEDULE_ENV
139                )
140            }
141        }
142    }
143
144    // Defer to `dyn` as quickly as possible to minimize per-test compilation overhead
145    walk(&mut || spawn_task(f()))
146}
147
148fn walk_schedule(schedule: &str, spawn: &mut dyn FnMut() -> Task<()>) {
149    let mut schedule = schedule::Decoder::new(schedule);
150    Context::init(|context| {
151        let task = spawn();
152        loop {
153            let runnable = {
154                let mut runnables = context.runnables();
155                let choices = runnables.len();
156
157                if choices == 0 {
158                    assert!(task.is_finished(), "deadlock");
159                    break;
160                } else {
161                    runnables.swap_remove(schedule.read(choices))
162                }
163            };
164
165            runnable.run();
166        }
167    })
168}
169
170fn walk_exhaustive(schedule: &mut Vec<(usize, usize)>, spawn: &mut dyn FnMut() -> Task<()>) {
171    fn advance(schedule: &mut Vec<(usize, usize)>) -> bool {
172        loop {
173            if let Some((choice, len)) = schedule.pop() {
174                let new_choice = choice + 1;
175                if new_choice < len {
176                    schedule.push((new_choice, len));
177                    return true;
178                }
179            } else {
180                return false;
181            }
182        }
183    }
184
185    Context::init(|context| 'schedules: loop {
186        let mut step = 0;
187        let task = spawn();
188
189        loop {
190            let runnable = {
191                let mut runnables = context.runnables();
192                let choices = runnables.len();
193
194                let choice = if step < schedule.len() {
195                    let (choice, existing_choices) = schedule[step];
196
197                    assert_eq!(
198                        choices,
199                        existing_choices,
200                        "nondeterminism: number of pollable futures ({}) did not equal number in previous executions ({})",
201                        choices,
202                        existing_choices,
203                    );
204
205                    choice
206                } else if choices == 0 {
207                    if task.is_finished() {
208                        if advance(schedule) {
209                            continue 'schedules;
210                        } else {
211                            break 'schedules;
212                        }
213                    } else {
214                        panic!(
215                            "deadlock in {}={}",
216                            SCHEDULE_ENV,
217                            schedule::encode(&schedule)
218                        );
219                    }
220                } else {
221                    schedule.push((0, choices));
222                    0
223                };
224
225                runnables.swap_remove(choice)
226            };
227
228            step += 1;
229            let result = catch_unwind(|| runnable.run());
230
231            if let Err(panic) = result {
232                eprintln!("panic in {}={}", SCHEDULE_ENV, schedule::encode(&schedule));
233                resume_unwind(panic)
234            }
235        }
236    })
237}
238
#[cfg(test)]
mod tests {
    use std::{
        any::Any,
        fmt::Debug,
        panic::{panic_any, AssertUnwindSafe},
    };

    use futures::{
        channel::{mpsc, oneshot},
        future::{pending, select, Either},
    };

    use super::*;

    /// Runs `f`, requires that it panics, and returns the panic payload.
    fn assert_panics<T>(f: impl FnOnce() -> T) -> Box<dyn Any + Send>
    where
        T: Debug,
    {
        let outcome = catch_unwind(AssertUnwindSafe(f));
        outcome.expect_err("expected panic")
    }

    /// Checks that the exhaustive walk hits a `PanicMarker` panic, and that replaying the
    /// schedule it stopped at reproduces the same panic. Returns that encoded schedule.
    fn assert_finds_panicking_schedule<T>(mut f: impl FnMut() -> T) -> String
    where
        T: Future<Output = ()> + 'static + Send,
    {
        let mut recorded = Vec::new();

        let exhaustive_payload =
            assert_panics(|| walk_exhaustive(&mut recorded, &mut || spawn_task(f())));
        exhaustive_payload
            .downcast::<PanicMarker>()
            .expect("expected test panic");

        let encoded = schedule::encode(&recorded);

        let replay_payload = assert_panics(|| walk_schedule(&encoded, &mut || spawn_task(f())));
        replay_payload
            .downcast::<PanicMarker>()
            .expect("expected test panic");

        encoded
    }

    /// Panic payload that tells deliberate test panics apart from any other panic.
    struct PanicMarker;

    /// Panics with a [`PanicMarker`] payload.
    fn panic_target() {
        panic_any(PanicMarker);
    }

    // A panic directly in the root future is found.
    #[test]
    fn basic() {
        assert_finds_panicking_schedule(|| async { panic_target() });
    }

    // A panic inside a spawned (and detached) task is found.
    #[test]
    fn spawn_panic() {
        assert_finds_panicking_schedule(|| async {
            spawn(async { panic_target() });
        });
    }

    // The consumer divides by zero when it runs before the producer has sent anything.
    #[test]
    fn example() {
        let f = || async {
            let (sender, mut receiver) = mpsc::unbounded::<usize>();

            spawn(async move {
                sender.unbounded_send(1).unwrap();
                sender.unbounded_send(3).unwrap();
                sender.unbounded_send(2).unwrap();
            });

            spawn(async move {
                let mut total = 0;
                let mut seen = 0;
                while let Some(value) = receiver.try_next().unwrap() {
                    total += value;
                    seen += 1;
                }

                println!("average is {}", total / seen)
            });
        };

        let mut recorded = Vec::new();
        assert_panics(|| walk_exhaustive(&mut recorded, &mut || spawn_task(f())));
        assert_eq!(schedule::encode(&recorded), "01")
    }

    // Only the schedules in which the second channel wins the `select` panic.
    #[test]
    fn channels() {
        assert_finds_panicking_schedule(|| async {
            let (sender_a, receiver_a) = oneshot::channel();
            let (sender_b, receiver_b) = oneshot::channel();

            spawn(async {
                drop(sender_a.send(()));
            });

            spawn(async {
                drop(sender_b.send(()));
            });

            match select(receiver_a, receiver_b).await {
                Either::Left(_) => (),
                Either::Right(_) => panic_target(),
            }
        });
    }

    // A trivial future completes under its only schedule.
    #[test]
    fn walk_basic() {
        for_all_schedules(|| async {});
    }

    // Both receivers are eventually satisfied under every schedule.
    #[test]
    fn walk_channels() {
        for_all_schedules(|| async {
            let (sender_a, receiver_a) = oneshot::channel();
            let (sender_b, receiver_b) = oneshot::channel();

            spawn(async {
                sender_a.send(()).unwrap();
            });

            spawn(async {
                sender_b.send(()).unwrap();
            });

            receiver_a.await.unwrap();
            receiver_b.await.unwrap();
        });
    }

    // A future that never completes is reported as a deadlock.
    #[test]
    #[should_panic]
    fn walk_deadlock() {
        for_all_schedules(|| pending::<()>())
    }

    // The sender is still alive across the await, so the receiver can never resolve.
    #[test]
    #[should_panic]
    fn channel_deadlock() {
        for_all_schedules(|| async {
            let (sender, receiver) = oneshot::channel::<()>();

            receiver.await.unwrap();
            drop(sender)
        });
    }
}