wapo_env/
tasks.rs

1use core::panic;
2use std::collections::VecDeque;
3use std::task::{Poll, Waker};
4
5use futures::{channel::oneshot, pin_mut};
6
7use super::*;
8
extern "Rust" {
    /// Returns the future that drives the guest's main task; it is stored in slot 0 of `TASKS`.
    ///
    /// The symbol is resolved at link time. It is presumably defined by the guest application
    /// (e.g. generated by an entry-point macro) — TODO confirm.
    fn wapo_main_future() -> Pin<Box<dyn Future<Output = ()>>>;
}
12
/// A pinned, boxed guest task. Tasks produce no output of their own.
type TaskFuture = Pin<Box<dyn Future<Output = ()>>>;
/// All guest tasks, indexed by task id. When a task exits, its future is dropped and its slot
/// is set to `None` so that its id can be reused by a later task.
type Tasks = Vec<Option<TaskFuture>>;
17
thread_local! {
    /// The id of the task currently being polled.
    ///
    /// Written by `wapo_poll` through `set_current_task` and read through `current_task`.
    /// The original note says this id is passed to each ocall.
    static CURRENT_TASK: std::cell::Cell<i32> = Default::default();
    /// All async tasks in the wapo guest, indexed by task id.
    ///
    /// Slot 0 holds the main future returned by `wapo_main_future`. `start_task` never reuses
    /// that slot, and `wapo_poll` reports the guest as exited once it becomes `None`.
    ///
    /// The initializer runs lazily on the first access from this thread. As a side effect it
    /// installs a global panic hook that logs the panic info through `log::error!`.
    static TASKS: RefCell<Tasks> = {
        std::panic::set_hook(Box::new(|info| {
            log::error!("{}", info);
        }));
        // SAFETY(review): relies on the linked guest providing `wapo_main_future` with exactly
        // the declared signature; the definition is not visible here — confirm.
        RefCell::new(vec![Some(unsafe { wapo_main_future() })])
    };
    /// Newly spawned tasks are queued here.
    ///
    /// Tasks are normally spawned from inside a running task, while `wapo_poll` holds a mutable
    /// borrow of `TASKS`, so they cannot be pushed to `TASKS` immediately.
    /// `start_spawned_tasks` moves them over after the current poll.
    static SPAWNING_TASKS: RefCell<Vec<TaskFuture>> = RefCell::new(vec![]);
    /// Wakers that may be referenced by the wapo host runtime.
    ///
    /// When an ocall polls some resource, the waker cannot be passed to the host runtime
    /// directly. The host lives in a different memory space and a different Rust compilation
    /// unit. Instead, the waker is cached in `WAKERS` and its index (the `waker_id`) is passed
    /// to the host. Before each guest poll, the guest asks the host which wakers were woken or
    /// dropped on the host side. It then wakes or removes the matching entries in this `Vec`.
    /// A `None` entry marks a vacant slot.
    static WAKERS: RefCell<Vec<Option<Waker>>> = RefCell::new(vec![]);

    /// Released waker ids that are cached for reuse (at most `MAX_CACHE_WAKER_IDS` entries).
    static CACHED_WAKER_IDS: RefCell<VecDeque<i32>> = RefCell::new(VecDeque::new());
}
44
45const MAX_CACHE_WAKER_IDS: usize = 1024;
46
47fn maybe_cache_waker_id(waker_id: i32) {
48    CACHED_WAKER_IDS.with(|ids| {
49        let mut ids = ids.borrow_mut();
50        if ids.len() < MAX_CACHE_WAKER_IDS {
51            ids.push_back(waker_id);
52        }
53    });
54}
55
56fn get_free_waker_id() -> Option<i32> {
57    CACHED_WAKER_IDS.with(|ids| {
58        let mut ids = ids.borrow_mut();
59        if ids.len() == 0 {
60            WAKERS.with(|wakers| {
61                for (id, waker_ref) in wakers.borrow().iter().enumerate() {
62                    if waker_ref.is_none() {
63                        ids.push_back(id as i32);
64                        if ids.len() >= MAX_CACHE_WAKER_IDS {
65                            break;
66                        }
67                    }
68                }
69            });
70        }
71        ids.pop_front()
72    })
73}
74
75pub fn intern_waker(waker: task::Waker) -> i32 {
76    const MAX_N_WAKERS: usize = (i32::MAX / 2) as usize;
77    let free_slot = get_free_waker_id();
78    WAKERS.with(|wakers| {
79        let mut wakers = wakers.borrow_mut();
80        if let Some(id) = free_slot {
81            wakers[id as usize] = Some(waker);
82            return id;
83        }
84        if wakers.len() < MAX_N_WAKERS {
85            wakers.push(Some(waker));
86            wakers.len() as i32 - 1
87        } else {
88            panic!("Too many wakers");
89        }
90    })
91}
92
93fn wake_waker(waker_id: i32) {
94    WAKERS.with(|wakers| {
95        let wakers = wakers.borrow();
96        if let Some(Some(waker)) = wakers.get(waker_id as usize) {
97            waker.wake_by_ref();
98        }
99    });
100}
101
102fn drop_waker(waker_id: i32) {
103    WAKERS.with(|wakers| {
104        let mut wakers = wakers.borrow_mut();
105        if let Some(waker) = wakers.get_mut(waker_id as usize) {
106            *waker = None;
107            maybe_cache_waker_id(waker_id);
108        }
109    });
110}
111
112pub struct JoinHandle<T>(oneshot::Receiver<T>);
113
114/// The task is dropped.
115#[derive(Clone, Copy, PartialEq, Eq, Debug)]
116pub struct Canceled;
117
118impl<T> Future for JoinHandle<T> {
119    type Output = Result<T, Canceled>;
120
121    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
122        let this = self.get_mut();
123        let inner = &mut this.0;
124        pin_mut!(inner);
125        match inner.poll(cx) {
126            Poll::Ready(x) => Poll::Ready(x.map_err(|_: oneshot::Canceled| Canceled)),
127            Poll::Pending => Poll::Pending,
128        }
129    }
130}
131
132pub fn spawn<T: 'static>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T> {
133    let (tx, rx) = oneshot::channel();
134    SPAWNING_TASKS.with(move |tasks| {
135        (*tasks).borrow_mut().push(Box::pin(async move {
136            let _ = tx.send(fut.await);
137        }))
138    });
139    JoinHandle(rx)
140}
141
142fn start_task(tasks: &mut Tasks, task: TaskFuture) {
143    const MAX_N_TASKS: usize = (i32::MAX / 2) as _;
144
145    for (task_id, task_ref) in tasks.iter_mut().enumerate().skip(1) {
146        if task_ref.is_none() {
147            *task_ref = Some(task);
148            ocall::mark_task_ready(task_id as _).expect("Mark task ready failed");
149            return;
150        }
151    }
152
153    if tasks.len() < MAX_N_TASKS {
154        let task_id = tasks.len();
155        tasks.push(Some(task));
156        ocall::mark_task_ready(task_id as _).expect("Mark task ready failed");
157        return;
158    }
159
160    panic!("Spawn task failed, Max number of tasks reached");
161}
162
163fn start_spawned_tasks(tasks: &mut Tasks) {
164    SPAWNING_TASKS.with(|spowned_tasks| {
165        for task in spowned_tasks.borrow_mut().drain(..) {
166            start_task(tasks, task);
167        }
168    })
169}
170
171pub(crate) fn current_task() -> i32 {
172    CURRENT_TASK.with(|id| id.get())
173}
174
175fn set_current_task(task_id: i32) {
176    CURRENT_TASK.with(|id| id.set(task_id))
177}
178
179fn poll_with_guest_context<F>(f: Pin<&mut F>) -> task::Poll<F::Output>
180where
181    F: Future + ?Sized,
182{
183    fn raw_waker(task_id: i32) -> task::RawWaker {
184        task::RawWaker::new(
185            task_id as _,
186            &task::RawWakerVTable::new(
187                |data| raw_waker(data as _),
188                |data| {
189                    let task_id = data as _;
190                    ocall::mark_task_ready(task_id).expect("Mark task ready failed");
191                },
192                |data| {
193                    let task_id = data as _;
194                    ocall::mark_task_ready(task_id).expect("Mark task ready failed");
195                },
196                |_| (),
197            ),
198        )
199    }
200    let waker = unsafe { task::Waker::from_raw(raw_waker(current_task())) };
201    let mut context = task::Context::from_waker(&waker);
202    f.poll(&mut context)
203}
204
205pub fn wapo_poll() -> i32 {
206    use task::Poll::*;
207
208    fn poll() -> task::Poll<()> {
209        {
210            for waker_id in ocall::awake_wakers().expect("Failed to get awake wakers") {
211                if waker_id >= 0 {
212                    wake_waker(waker_id);
213                } else {
214                    drop_waker(-1 - waker_id);
215                }
216            }
217
218            let task_id = match ocall::next_ready_task() {
219                Ok(id) => id as usize,
220                Err(OcallError::NotFound) => return task::Poll::Pending,
221                Err(err) => panic!("Error occured: {:?}", err),
222            };
223            let exited = TASKS.with(|tasks| -> Option<bool> {
224                let exited = {
225                    let mut tasks = tasks.borrow_mut();
226                    let task = tasks.get_mut(task_id)?.as_mut()?;
227                    set_current_task(task_id as _);
228                    match poll_with_guest_context(task.as_mut()) {
229                        Pending => (),
230                        Ready(()) => {
231                            tasks[task_id] = None;
232                        }
233                    }
234                    tasks[0].is_none()
235                };
236                if !exited {
237                    start_spawned_tasks(&mut tasks.borrow_mut());
238                }
239                Some(exited)
240            });
241            if let Some(true) = exited {
242                return task::Poll::Ready(());
243            }
244        }
245        task::Poll::Pending
246    }
247    match poll() {
248        Ready(()) => 1,
249        Pending => 0,
250    }
251}