Skip to main content

tempest_rt/
task.rs

1//! Async task primitives: spawning concurrent tasks and yielding to the scheduler.
2
3use std::{future::poll_fn, pin::Pin, task::Poll};
4
5use nonmax::NonMaxU32;
6use slab::Slab;
7
8use crate::{
9    context::{MAX_TASK_ID, TaskId, current_tasks, current_wake_sets},
10    sync::oneshot,
11};
12
13pub(crate) type Tasks = Slab<Pin<Box<dyn Future<Output = ()>>>>;
14
/// Error produced by awaiting a [`JoinHandle`] when the channel carrying the task's result is
/// closed without a value — which happens if the spawned task is dropped before it finishes.
///
/// Dropping a `JoinHandle` never produces this error; a dropped handle simply discards the
/// task's output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task was cancelled before producing a result")
    }
}

/// Lets `Cancelled` flow through `?` into `Box<dyn Error>` and similar error-handling code.
impl std::error::Error for Cancelled {}
18
19/// Handle to a spawned task. Awaiting it returns the task's output, or [`Cancelled`] if the
20/// handle was dropped before the task completed.
21pub struct JoinHandle<T> {
22    rx: Option<oneshot::Receiver<T>>,
23}
24
25impl<T> Future for JoinHandle<T> {
26    type Output = Result<T, Cancelled>;
27
28    fn poll(
29        mut self: std::pin::Pin<&mut Self>,
30        cx: &mut std::task::Context<'_>,
31    ) -> std::task::Poll<Self::Output> {
32        let rx = self
33            .rx
34            .as_mut()
35            .expect("JoinHandle polled after completion");
36        match rx.poll_recv(cx) {
37            Poll::Ready(result) => {
38                self.rx = None;
39                Poll::Ready(result.map_err(|_| Cancelled))
40            }
41            Poll::Pending => Poll::Pending,
42        }
43    }
44}
45
46/// Spawns `fut` as a concurrent task, returning a [`JoinHandle`] to collect its result.
47pub fn spawn<T: 'static>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T> {
48    let (tx, rx) = oneshot::channel();
49    let handle = JoinHandle { rx: Some(rx) };
50
51    let wrapper = async move {
52        // we can ignore this error, since tasks do not have to be joined
53        _ = tx.send(fut.await);
54    };
55
56    // SAFETY: we do not hold on to the references outside of this function
57    let (tasks, wake_sets) = unsafe { (current_tasks(), current_wake_sets()) };
58    let index = tasks.insert(Box::pin(wrapper));
59    assert!(index <= MAX_TASK_ID as usize);
60    // SAFETY: index < MAX_TASK_ID < u32::Max, so NonMaxU32 invariant holds
61    let task_id = TaskId::Task(unsafe { NonMaxU32::new_unchecked(index as u32) });
62
63    wake_sets.staging.insert(task_id);
64
65    handle
66}
67
/// Yields control back to the runtime for one tick, allowing other tasks and I/O completions
/// to be processed before this task resumes.
///
/// On its first poll the future wakes its own waker, so the task is rescheduled, and returns
/// `Pending`. On every later poll it returns `Ready(())`.
pub async fn yield_now() {
    let mut polled_once = false;
    poll_fn(move |cx| {
        if polled_once {
            return Poll::Ready(());
        }
        polled_once = true;
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await;
}
83
#[cfg(test)]
mod tests {
    use tempest_io::VirtualIo;

    use super::*;
    use crate::block_on;

    /// Awaiting the handle of a task that finishes immediately yields its output.
    #[test]
    fn test_spawn_completes() {
        block_on(VirtualIo::default(), async {
            let answer = spawn(async { 42 }).await;
            assert_eq!(answer, Ok(42));
        });
    }

    /// Dropping a handle without awaiting it must not disturb the runtime; the task itself
    /// still runs to completion and only its result is discarded.
    #[test]
    fn test_spawn_cancelled() {
        block_on(VirtualIo::default(), async {
            drop(spawn(async { 42 }));
        });
    }

    /// Two tasks spawned back to back can both be joined, each yielding its own output.
    #[test]
    fn test_spawn_runs_concurrently() {
        block_on(VirtualIo::default(), async {
            let (first, second) = (spawn(async { 1 }), spawn(async { 2 }));
            assert_eq!(first.await, Ok(1));
            assert_eq!(second.await, Ok(2));
        });
    }

    /// `yield_now` completes when driven by the runtime.
    #[test]
    fn test_yield_now() {
        block_on(VirtualIo::default(), yield_now());
    }
}