1#![deny(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use std::{
5 env,
6 future::Future,
7 panic::{catch_unwind, resume_unwind},
8 pin::Pin,
9 sync::atomic::{AtomicBool, Ordering},
10 task,
11};
12
13use async_task::Task;
14use context::Context;
15
16mod context;
17mod schedule;
18
/// Name of the environment variable that, when set, makes [`for_all_schedules`] replay only the
/// schedule it encodes instead of exploring every schedule. The name also prefixes the
/// `SCHEDULE=...` reproduction strings reported on deadlock or panic.
const SCHEDULE_ENV: &str = "SCHEDULE";
20
21pub struct JoinHandle<T> {
28 task: Option<Task<T>>,
29 abort: AtomicBool,
30}
31
/// Error returned when awaiting a [`JoinHandle`] whose task was aborted before it finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct JoinError();

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task was cancelled")
    }
}

// Lets `JoinError` be used with `?` into boxed / `dyn Error` error types.
impl std::error::Error for JoinError {}
37
38impl JoinError {
39 pub fn is_cancelled(&self) -> bool {
41 true
42 }
43}
44
45impl<T> JoinHandle<T> {
46 fn new(task: Task<T>) -> Self {
47 JoinHandle {
48 task: Some(task),
49 abort: AtomicBool::new(false),
50 }
51 }
52}
53
54impl<T> JoinHandle<T> {
55 pub fn abort(&self) {
60 self.abort.store(true, Ordering::Relaxed)
61 }
62}
63
64impl<T> Drop for JoinHandle<T> {
65 fn drop(&mut self) {
66 if let Some(task) = self.task.take() {
67 task.detach()
68 }
69 }
70}
71
72impl<T> Future for JoinHandle<T> {
73 type Output = Result<T, JoinError>;
74
75 #[inline]
76 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
77 let JoinHandle { task, abort } = &mut *self;
78
79 match task {
80 Some(task) if task.is_finished() || !*abort.get_mut() => {
81 Pin::new(task).poll(cx).map(Ok)
82 }
83 _ => {
84 task.take();
85 task::Poll::Ready(Err(JoinError()))
86 }
87 }
88 }
89}
90
91pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
96where
97 T: Future + Send + 'static,
98 T::Output: Send + 'static,
99{
100 JoinHandle::new(spawn_task(future))
101}
102
103fn spawn_task<T>(future: T) -> Task<T::Output>
104where
105 T: Future + Send + 'static,
106 T::Output: Send + 'static,
107{
108 let (runnable, task) = async_task::spawn(future, Context::schedule);
109 runnable.schedule();
110 task
111}
112
113#[inline]
127pub fn for_all_schedules<T>(mut f: impl FnMut() -> T)
128where
129 T: Future<Output = ()> + 'static + Send,
130{
131 fn walk(spawn: &mut dyn FnMut() -> Task<()>) {
132 match env::var(SCHEDULE_ENV) {
133 Ok(schedule) => walk_schedule(&schedule, spawn),
134 Err(env::VarError::NotPresent) => walk_exhaustive(&mut Vec::new(), spawn),
135 Err(env::VarError::NotUnicode(_)) => {
136 panic!(
137 "found a schedule in {}, but it was not valid unicode",
138 SCHEDULE_ENV
139 )
140 }
141 }
142 }
143
144 walk(&mut || spawn_task(f()))
146}
147
148fn walk_schedule(schedule: &str, spawn: &mut dyn FnMut() -> Task<()>) {
149 let mut schedule = schedule::Decoder::new(schedule);
150 Context::init(|context| {
151 let task = spawn();
152 loop {
153 let runnable = {
154 let mut runnables = context.runnables();
155 let choices = runnables.len();
156
157 if choices == 0 {
158 assert!(task.is_finished(), "deadlock");
159 break;
160 } else {
161 runnables.swap_remove(schedule.read(choices))
162 }
163 };
164
165 runnable.run();
166 }
167 })
168}
169
170fn walk_exhaustive(schedule: &mut Vec<(usize, usize)>, spawn: &mut dyn FnMut() -> Task<()>) {
171 fn advance(schedule: &mut Vec<(usize, usize)>) -> bool {
172 loop {
173 if let Some((choice, len)) = schedule.pop() {
174 let new_choice = choice + 1;
175 if new_choice < len {
176 schedule.push((new_choice, len));
177 return true;
178 }
179 } else {
180 return false;
181 }
182 }
183 }
184
185 Context::init(|context| 'schedules: loop {
186 let mut step = 0;
187 let task = spawn();
188
189 loop {
190 let runnable = {
191 let mut runnables = context.runnables();
192 let choices = runnables.len();
193
194 let choice = if step < schedule.len() {
195 let (choice, existing_choices) = schedule[step];
196
197 assert_eq!(
198 choices,
199 existing_choices,
200 "nondeterminism: number of pollable futures ({}) did not equal number in previous executions ({})",
201 choices,
202 existing_choices,
203 );
204
205 choice
206 } else if choices == 0 {
207 if task.is_finished() {
208 if advance(schedule) {
209 continue 'schedules;
210 } else {
211 break 'schedules;
212 }
213 } else {
214 panic!(
215 "deadlock in {}={}",
216 SCHEDULE_ENV,
217 schedule::encode(&schedule)
218 );
219 }
220 } else {
221 schedule.push((0, choices));
222 0
223 };
224
225 runnables.swap_remove(choice)
226 };
227
228 step += 1;
229 let result = catch_unwind(|| runnable.run());
230
231 if let Err(panic) = result {
232 eprintln!("panic in {}={}", SCHEDULE_ENV, schedule::encode(&schedule));
233 resume_unwind(panic)
234 }
235 }
236 })
237}
238
#[cfg(test)]
mod tests {
    use std::{
        any::Any,
        fmt::Debug,
        panic::{panic_any, AssertUnwindSafe},
    };

    use futures::{
        channel::{mpsc, oneshot},
        future::{pending, select, Either},
    };

    use super::*;

    /// Runs `f`, asserting that it panics, and returns the panic payload.
    fn assert_panics<T>(f: impl FnOnce() -> T) -> Box<dyn Any + Send>
    where
        T: Debug,
    {
        // `AssertUnwindSafe` lets callers pass closures that capture `&mut` state (such as the
        // schedule vector), which is inspected after the panic.
        catch_unwind(AssertUnwindSafe(f)).expect_err("expected panic")
    }

    /// Asserts that the exhaustive walk finds a schedule in which the test panics with
    /// `PanicMarker`, and that replaying the encoded schedule reproduces that panic.
    /// Returns the encoded schedule.
    fn assert_finds_panicking_schedule<T>(mut f: impl FnMut() -> T) -> String
    where
        T: Future<Output = ()> + 'static + Send,
    {
        let mut schedule = Vec::new();

        // After the panic, `schedule` holds the steps of the failing execution.
        assert_panics(|| walk_exhaustive(&mut schedule, &mut || spawn_task(f())))
            .downcast::<PanicMarker>()
            .expect("expected test panic");

        let encoded_schedule = schedule::encode(&schedule);

        // Replaying the recorded schedule must hit the same panic.
        assert_panics(|| walk_schedule(&encoded_schedule, &mut || spawn_task(f())))
            .downcast::<PanicMarker>()
            .expect("expected test panic");

        encoded_schedule
    }

    /// Distinct panic payload, so tests can tell their intended panic apart from any other.
    struct PanicMarker;

    fn panic_target() {
        panic_any(PanicMarker);
    }

    // The root task itself panics.
    #[test]
    fn basic() {
        assert_finds_panicking_schedule(|| async { panic_target() });
    }

    // A task spawned by the root panics.
    #[test]
    fn spawn_panic() {
        assert_finds_panicking_schedule(|| async {
            spawn(async { panic_target() });
        });
    }

    // Sender/receiver race: some interleaving must panic, and the recorded schedule for that
    // failure is expected to encode as "01".
    #[test]
    fn example() {
        let f = || async {
            let (sender, mut receiver) = mpsc::unbounded::<usize>();

            spawn(async move {
                sender.unbounded_send(1).unwrap();
                sender.unbounded_send(3).unwrap();
                sender.unbounded_send(2).unwrap();
            });

            spawn(async move {
                let mut sum = 0;
                let mut count = 0;
                while let Some(num) = receiver.try_next().unwrap() {
                    sum += num;
                    count += 1;
                }

                println!("average is {}", sum / count)
            });
        };

        let mut schedule = Vec::new();
        assert_panics(|| walk_exhaustive(&mut schedule, &mut || spawn_task(f())));
        assert_eq!(schedule::encode(&schedule), "01")
    }

    // Panics only in schedules where `select` resolves with `receiver_b` first.
    #[test]
    fn channels() {
        assert_finds_panicking_schedule(|| async {
            let (sender_a, receiver_a) = oneshot::channel();
            let (sender_b, receiver_b) = oneshot::channel();

            spawn(async {
                drop(sender_a.send(()));
            });

            spawn(async {
                drop(sender_b.send(()));
            });

            match select(receiver_a, receiver_b).await {
                Either::Left(_) => (),
                Either::Right(_) => panic_target(),
            }
        });
    }

    // A trivial future completes under every schedule.
    #[test]
    fn walk_basic() {
        for_all_schedules(|| async { () });
    }

    // Awaiting both channels succeeds regardless of the order the senders run in.
    #[test]
    fn walk_channels() {
        for_all_schedules(|| async {
            let (sender_a, receiver_a) = oneshot::channel();
            let (sender_b, receiver_b) = oneshot::channel();

            spawn(async {
                sender_a.send(()).unwrap();
            });

            spawn(async {
                sender_b.send(()).unwrap();
            });

            receiver_a.await.unwrap();
            receiver_b.await.unwrap();
        });
    }

    // A root future that never completes must be reported as a deadlock.
    #[test]
    #[should_panic]
    fn walk_deadlock() {
        for_all_schedules(|| pending::<()>())
    }

    // The root awaits a receiver whose sender it still holds, so it can never be woken.
    #[test]
    #[should_panic]
    fn channel_deadlock() {
        for_all_schedules(|| async {
            let (sender, receiver) = oneshot::channel::<()>();

            receiver.await.unwrap();
            drop(sender)
        });
    }
}