sidevm_env/
tasks.rs

1use core::panic;
2use std::task::{Poll, Waker};
3
4use futures::{channel::oneshot, pin_mut};
5
6use super::*;
7
extern "Rust" {
    /// Returns the guest program's entry future, which the executor installs as task 0.
    ///
    /// The symbol is defined outside this module — presumably by the guest crate's `main`
    /// attribute macro (TODO confirm) — so the compiler cannot check that its signature matches.
    fn sidevm_main_future() -> TaskFuture;
}

/// A pinned, boxed, type-erased future that drives one guest task to completion.
type TaskFuture = Pin<Box<dyn Future<Output = ()>>>;

/// The task table, indexed by task id.
///
/// When a task completes, its future is dropped and its slot is reset to `None`, so that the
/// id can be handed out again to a later task.
type Tasks = Vec<Option<TaskFuture>>;
16
17thread_local! {
18    /// The id of the current polling task. Would be passed to each ocall.
19    static CURRENT_TASK: std::cell::Cell<i32>  = Default::default();
20    /// All async tasks in the sidevm guest.
21    static TASKS: RefCell<Tasks> = {
22        std::panic::set_hook(Box::new(|info| {
23            let _ = log::error!("{}", info);
24        }));
25        RefCell::new(vec![Some(unsafe { sidevm_main_future() })])
26    };
27    /// New spawned tasks are pushed to this queue. Since tasks are always spawned from inside a
28    /// running task which borrowing the TASKS, it can not be immediately pushed to the TASKS.
29    static SPAWNING_TASKS: RefCell<Vec<TaskFuture>> = RefCell::new(vec![]);
30    /// Wakers might being referenced by the sidevm host runtime.
31    ///
32    /// When a ocall polling some resource, we can not pass the waker to the host runtime,
33    /// because they are in different memory space and in different rust code compilation space.
34    /// So when we poll into the host runtime, we cache the waker in WAKERS, and pass the index,
35    /// which called waker_id, into the host runtime. And then before each guest polling, the
36    /// guest runtime ask the host runtime to see which waker is awaken or dropped in the host
37    /// runtime to deside to awake or drop the waker from this Vec.
38    static WAKERS: RefCell<Vec<Option<Waker>>> = RefCell::new(vec![]);
39}
40
41pub fn intern_waker(waker: task::Waker) -> i32 {
42    const MAX_N_WAKERS: usize = (i32::MAX / 2) as usize;
43    WAKERS.with(|wakers| {
44        let mut wakers = wakers.borrow_mut();
45        for (id, waker_ref) in wakers.iter_mut().enumerate() {
46            if waker_ref.is_none() {
47                *waker_ref = Some(waker);
48                return id as i32;
49            }
50        }
51        if wakers.len() < MAX_N_WAKERS {
52            wakers.push(Some(waker));
53            wakers.len() as i32 - 1
54        } else {
55            panic!("Too many wakers");
56        }
57    })
58}
59
60fn wake_waker(waker_id: i32) {
61    WAKERS.with(|wakers| {
62        let wakers = wakers.borrow();
63        if let Some(Some(waker)) = wakers.get(waker_id as usize) {
64            waker.wake_by_ref();
65        }
66    });
67}
68
69fn drop_waker(waker_id: i32) {
70    WAKERS.with(|wakers| {
71        let mut wakers = wakers.borrow_mut();
72        if let Some(waker) = wakers.get_mut(waker_id as usize) {
73            *waker = None;
74        }
75    });
76}
77
78pub struct JoinHandle<T>(oneshot::Receiver<T>);
79
80/// The task is dropped.
81#[derive(Clone, Copy, PartialEq, Eq, Debug)]
82pub struct Canceled;
83
84impl<T> Future for JoinHandle<T> {
85    type Output = Result<T, Canceled>;
86
87    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
88        let this = self.get_mut();
89        let inner = &mut this.0;
90        pin_mut!(inner);
91        match inner.poll(cx) {
92            Poll::Ready(x) => Poll::Ready(x.map_err(|_: oneshot::Canceled| Canceled)),
93            Poll::Pending => Poll::Pending,
94        }
95    }
96}
97
98pub fn spawn<T: 'static>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T> {
99    let (tx, rx) = oneshot::channel();
100    SPAWNING_TASKS.with(move |tasks| {
101        (*tasks).borrow_mut().push(Box::pin(async move {
102            let _ = tx.send(fut.await);
103        }))
104    });
105    JoinHandle(rx)
106}
107
108fn start_task(tasks: &mut Tasks, task: TaskFuture) {
109    const MAX_N_TASKS: usize = (i32::MAX / 2) as _;
110
111    for (task_id, task_ref) in tasks.iter_mut().enumerate().skip(1) {
112        if task_ref.is_none() {
113            *task_ref = Some(task);
114            ocall::mark_task_ready(task_id as _).expect("Mark task ready failed");
115            return;
116        }
117    }
118
119    if tasks.len() < MAX_N_TASKS {
120        let task_id = tasks.len();
121        tasks.push(Some(task));
122        ocall::mark_task_ready(task_id as _).expect("Mark task ready failed");
123        return;
124    }
125
126    panic!("Spawn task failed, Max number of tasks reached");
127}
128
129fn start_spawned_tasks(tasks: &mut Tasks) {
130    SPAWNING_TASKS.with(|spowned_tasks| {
131        for task in spowned_tasks.borrow_mut().drain(..) {
132            start_task(tasks, task);
133        }
134    })
135}
136
137pub(crate) fn current_task() -> i32 {
138    CURRENT_TASK.with(|id| id.get())
139}
140
141fn set_current_task(task_id: i32) {
142    CURRENT_TASK.with(|id| id.set(task_id))
143}
144
145fn poll_with_guest_context<F>(f: Pin<&mut F>) -> task::Poll<F::Output>
146where
147    F: Future + ?Sized,
148{
149    fn raw_waker(task_id: i32) -> task::RawWaker {
150        task::RawWaker::new(
151            task_id as _,
152            &task::RawWakerVTable::new(
153                |data| raw_waker(data as _),
154                |data| {
155                    let task_id = data as _;
156                    ocall::mark_task_ready(task_id).expect("Mark task ready failed");
157                },
158                |data| {
159                    let task_id = data as _;
160                    ocall::mark_task_ready(task_id).expect("Mark task ready failed");
161                },
162                |_| (),
163            ),
164        )
165    }
166    let waker = unsafe { task::Waker::from_raw(raw_waker(current_task())) };
167    let mut context = task::Context::from_waker(&waker);
168    f.poll(&mut context)
169}
170
171#[no_mangle]
172extern "C" fn sidevm_poll() -> i32 {
173    use task::Poll::*;
174
175    fn poll() -> task::Poll<()> {
176        {
177            for waker_id in ocall::awake_wakers().expect("Failed to get awake wakers") {
178                if waker_id >= 0 {
179                    wake_waker(waker_id);
180                } else {
181                    drop_waker(-1 - waker_id);
182                }
183            }
184
185            let task_id = match ocall::next_ready_task() {
186                Ok(id) => id as usize,
187                Err(OcallError::NotFound) => return task::Poll::Pending,
188                Err(err) => panic!("Error occured: {:?}", err),
189            };
190            let exited = TASKS.with(|tasks| -> Option<bool> {
191                let exited = {
192                    let mut tasks = tasks.borrow_mut();
193                    let task = tasks.get_mut(task_id)?.as_mut()?;
194                    set_current_task(task_id as _);
195                    match poll_with_guest_context(task.as_mut()) {
196                        Pending => (),
197                        Ready(()) => {
198                            tasks[task_id] = None;
199                        }
200                    }
201                    tasks[0].is_none()
202                };
203                if !exited {
204                    start_spawned_tasks(&mut *tasks.borrow_mut());
205                }
206                Some(exited)
207            });
208            if let Some(true) = exited {
209                return task::Poll::Ready(());
210            }
211        }
212        task::Poll::Pending
213    }
214    match poll() {
215        Ready(()) => 1,
216        Pending => 0,
217    }
218}