1use core::panic;
2use std::task::{Poll, Waker};
3
4use futures::{channel::oneshot, pin_mut};
5
6use super::*;
7
extern "Rust" {
    /// Builds the guest's main future; it becomes task 0 when `TASKS` is first
    /// initialized. Declared here and resolved by symbol name at link time —
    /// the definition lives outside this module (presumably emitted by the
    /// guest program's entry-point macro; confirm).
    fn sidevm_main_future() -> Pin<Box<dyn Future<Output = ()>>>;
}
11
/// A task as stored by the executor: a pinned, heap-allocated, type-erased
/// future that yields `()`.
type TaskFuture = Pin<Box<dyn Future<Output = ()>>>;
/// The task table, indexed by task id; a `None` entry is an unoccupied slot.
type Tasks = Vec<Option<TaskFuture>>;
16
17thread_local! {
18 static CURRENT_TASK: std::cell::Cell<i32> = Default::default();
20 static TASKS: RefCell<Tasks> = {
22 std::panic::set_hook(Box::new(|info| {
23 let _ = log::error!("{}", info);
24 }));
25 RefCell::new(vec![Some(unsafe { sidevm_main_future() })])
26 };
27 static SPAWNING_TASKS: RefCell<Vec<TaskFuture>> = RefCell::new(vec![]);
30 static WAKERS: RefCell<Vec<Option<Waker>>> = RefCell::new(vec![]);
39}
40
41pub fn intern_waker(waker: task::Waker) -> i32 {
42 const MAX_N_WAKERS: usize = (i32::MAX / 2) as usize;
43 WAKERS.with(|wakers| {
44 let mut wakers = wakers.borrow_mut();
45 for (id, waker_ref) in wakers.iter_mut().enumerate() {
46 if waker_ref.is_none() {
47 *waker_ref = Some(waker);
48 return id as i32;
49 }
50 }
51 if wakers.len() < MAX_N_WAKERS {
52 wakers.push(Some(waker));
53 wakers.len() as i32 - 1
54 } else {
55 panic!("Too many wakers");
56 }
57 })
58}
59
60fn wake_waker(waker_id: i32) {
61 WAKERS.with(|wakers| {
62 let wakers = wakers.borrow();
63 if let Some(Some(waker)) = wakers.get(waker_id as usize) {
64 waker.wake_by_ref();
65 }
66 });
67}
68
69fn drop_waker(waker_id: i32) {
70 WAKERS.with(|wakers| {
71 let mut wakers = wakers.borrow_mut();
72 if let Some(waker) = wakers.get_mut(waker_id as usize) {
73 *waker = None;
74 }
75 });
76}
77
/// Handle to a future started with [`spawn`].
///
/// Awaiting it yields `Ok(output)` once the spawned future completes. It yields
/// `Err(Canceled)` if the task's oneshot sender is dropped without sending,
/// e.g. when the task is dropped before it finishes.
pub struct JoinHandle<T>(oneshot::Receiver<T>);
79
/// Error produced by awaiting a [`JoinHandle`] whose task was dropped before it
/// delivered a value.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Canceled;

impl std::fmt::Display for Canceled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task canceled")
    }
}

// Lets `Canceled` flow through `?` into `Box<dyn Error>` / error-wrapping types.
impl std::error::Error for Canceled {}
83
84impl<T> Future for JoinHandle<T> {
85 type Output = Result<T, Canceled>;
86
87 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
88 let this = self.get_mut();
89 let inner = &mut this.0;
90 pin_mut!(inner);
91 match inner.poll(cx) {
92 Poll::Ready(x) => Poll::Ready(x.map_err(|_: oneshot::Canceled| Canceled)),
93 Poll::Pending => Poll::Pending,
94 }
95 }
96}
97
98pub fn spawn<T: 'static>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T> {
99 let (tx, rx) = oneshot::channel();
100 SPAWNING_TASKS.with(move |tasks| {
101 (*tasks).borrow_mut().push(Box::pin(async move {
102 let _ = tx.send(fut.await);
103 }))
104 });
105 JoinHandle(rx)
106}
107
108fn start_task(tasks: &mut Tasks, task: TaskFuture) {
109 const MAX_N_TASKS: usize = (i32::MAX / 2) as _;
110
111 for (task_id, task_ref) in tasks.iter_mut().enumerate().skip(1) {
112 if task_ref.is_none() {
113 *task_ref = Some(task);
114 ocall::mark_task_ready(task_id as _).expect("Mark task ready failed");
115 return;
116 }
117 }
118
119 if tasks.len() < MAX_N_TASKS {
120 let task_id = tasks.len();
121 tasks.push(Some(task));
122 ocall::mark_task_ready(task_id as _).expect("Mark task ready failed");
123 return;
124 }
125
126 panic!("Spawn task failed, Max number of tasks reached");
127}
128
129fn start_spawned_tasks(tasks: &mut Tasks) {
130 SPAWNING_TASKS.with(|spowned_tasks| {
131 for task in spowned_tasks.borrow_mut().drain(..) {
132 start_task(tasks, task);
133 }
134 })
135}
136
137pub(crate) fn current_task() -> i32 {
138 CURRENT_TASK.with(|id| id.get())
139}
140
141fn set_current_task(task_id: i32) {
142 CURRENT_TASK.with(|id| id.set(task_id))
143}
144
145fn poll_with_guest_context<F>(f: Pin<&mut F>) -> task::Poll<F::Output>
146where
147 F: Future + ?Sized,
148{
149 fn raw_waker(task_id: i32) -> task::RawWaker {
150 task::RawWaker::new(
151 task_id as _,
152 &task::RawWakerVTable::new(
153 |data| raw_waker(data as _),
154 |data| {
155 let task_id = data as _;
156 ocall::mark_task_ready(task_id).expect("Mark task ready failed");
157 },
158 |data| {
159 let task_id = data as _;
160 ocall::mark_task_ready(task_id).expect("Mark task ready failed");
161 },
162 |_| (),
163 ),
164 )
165 }
166 let waker = unsafe { task::Waker::from_raw(raw_waker(current_task())) };
167 let mut context = task::Context::from_waker(&waker);
168 f.poll(&mut context)
169}
170
171#[no_mangle]
172extern "C" fn sidevm_poll() -> i32 {
173 use task::Poll::*;
174
175 fn poll() -> task::Poll<()> {
176 {
177 for waker_id in ocall::awake_wakers().expect("Failed to get awake wakers") {
178 if waker_id >= 0 {
179 wake_waker(waker_id);
180 } else {
181 drop_waker(-1 - waker_id);
182 }
183 }
184
185 let task_id = match ocall::next_ready_task() {
186 Ok(id) => id as usize,
187 Err(OcallError::NotFound) => return task::Poll::Pending,
188 Err(err) => panic!("Error occured: {:?}", err),
189 };
190 let exited = TASKS.with(|tasks| -> Option<bool> {
191 let exited = {
192 let mut tasks = tasks.borrow_mut();
193 let task = tasks.get_mut(task_id)?.as_mut()?;
194 set_current_task(task_id as _);
195 match poll_with_guest_context(task.as_mut()) {
196 Pending => (),
197 Ready(()) => {
198 tasks[task_id] = None;
199 }
200 }
201 tasks[0].is_none()
202 };
203 if !exited {
204 start_spawned_tasks(&mut *tasks.borrow_mut());
205 }
206 Some(exited)
207 });
208 if let Some(true) = exited {
209 return task::Poll::Ready(());
210 }
211 }
212 task::Poll::Pending
213 }
214 match poll() {
215 Ready(()) => 1,
216 Pending => 0,
217 }
218}