1use core::panic;
2use std::collections::VecDeque;
3use std::task::{Poll, Waker};
4
5use futures::{channel::oneshot, pin_mut};
6
7use super::*;
8
extern "Rust" {
    /// Produces the program's main future; it is called from the `TASKS` thread-local
    /// initializer and stored in task slot 0.
    ///
    /// NOTE(review): no definition is visible here — presumably supplied by the crate that
    /// links this runtime (e.g. generated by an entry-point macro); confirm. Calls are
    /// `unsafe` because the compiler cannot check that the linked symbol matches this
    /// signature.
    fn wapo_main_future() -> Pin<Box<dyn Future<Output = ()>>>;
}
12
/// A type-erased, heap-pinned task body; every task (main or spawned) has this shape.
type TaskFuture = Pin<Box<dyn Future<Output = ()> + 'static>>;
/// Task table indexed by task id; `None` marks a finished (free) slot.
type Tasks = Vec<Option<TaskFuture>>;
17
thread_local! {
    // Id of the task currently being polled: written by `set_current_task` right before
    // each poll and read back by `current_task`. Starts at 0 (`Default::default()`).
    static CURRENT_TASK: std::cell::Cell<i32> = Default::default();
    // Task table indexed by task id. Slot 0 holds the main future returned by
    // `wapo_main_future`; `wapo_poll` reports exit once that slot becomes `None`.
    // Other `None` slots are finished tasks whose slots `start_task` may reuse.
    //
    // Initialized lazily on first access; the initializer also installs a process-wide
    // panic hook that routes panic messages to `log::error!`.
    static TASKS: RefCell<Tasks> = {
        std::panic::set_hook(Box::new(|info| {
            log::error!("{}", info);
        }));
        // NOTE(review): soundness of this call depends on the linked definition of the
        // extern `wapo_main_future` matching its declared signature — confirm.
        RefCell::new(vec![Some(unsafe { wapo_main_future() })])
    };
    // Futures queued by `spawn`; `start_spawned_tasks` moves them into `TASKS` after a
    // task has been polled (unless the main task has exited).
    static SPAWNING_TASKS: RefCell<Vec<TaskFuture>> = RefCell::new(vec![]);
    // Waker table indexed by the id `intern_waker` returns; `None` marks a free slot.
    static WAKERS: RefCell<Vec<Option<Waker>>> = RefCell::new(vec![]);

    // Queue of free `WAKERS` slot ids consumed by `get_free_waker_id`; filled by
    // `maybe_cache_waker_id` (capped at `MAX_CACHE_WAKER_IDS`) or by a rescan of `WAKERS`.
    static CACHED_WAKER_IDS: RefCell<VecDeque<i32>> = RefCell::new(VecDeque::new());
}
44
/// Upper bound on how many free waker ids `CACHED_WAKER_IDS` holds at once (1024).
const MAX_CACHE_WAKER_IDS: usize = 1 << 10;
46
47fn maybe_cache_waker_id(waker_id: i32) {
48 CACHED_WAKER_IDS.with(|ids| {
49 let mut ids = ids.borrow_mut();
50 if ids.len() < MAX_CACHE_WAKER_IDS {
51 ids.push_back(waker_id);
52 }
53 });
54}
55
56fn get_free_waker_id() -> Option<i32> {
57 CACHED_WAKER_IDS.with(|ids| {
58 let mut ids = ids.borrow_mut();
59 if ids.len() == 0 {
60 WAKERS.with(|wakers| {
61 for (id, waker_ref) in wakers.borrow().iter().enumerate() {
62 if waker_ref.is_none() {
63 ids.push_back(id as i32);
64 if ids.len() >= MAX_CACHE_WAKER_IDS {
65 break;
66 }
67 }
68 }
69 });
70 }
71 ids.pop_front()
72 })
73}
74
75pub fn intern_waker(waker: task::Waker) -> i32 {
76 const MAX_N_WAKERS: usize = (i32::MAX / 2) as usize;
77 let free_slot = get_free_waker_id();
78 WAKERS.with(|wakers| {
79 let mut wakers = wakers.borrow_mut();
80 if let Some(id) = free_slot {
81 wakers[id as usize] = Some(waker);
82 return id;
83 }
84 if wakers.len() < MAX_N_WAKERS {
85 wakers.push(Some(waker));
86 wakers.len() as i32 - 1
87 } else {
88 panic!("Too many wakers");
89 }
90 })
91}
92
93fn wake_waker(waker_id: i32) {
94 WAKERS.with(|wakers| {
95 let wakers = wakers.borrow();
96 if let Some(Some(waker)) = wakers.get(waker_id as usize) {
97 waker.wake_by_ref();
98 }
99 });
100}
101
102fn drop_waker(waker_id: i32) {
103 WAKERS.with(|wakers| {
104 let mut wakers = wakers.borrow_mut();
105 if let Some(waker) = wakers.get_mut(waker_id as usize) {
106 *waker = None;
107 maybe_cache_waker_id(waker_id);
108 }
109 });
110}
111
112pub struct JoinHandle<T>(oneshot::Receiver<T>);
113
114#[derive(Clone, Copy, PartialEq, Eq, Debug)]
116pub struct Canceled;
117
118impl<T> Future for JoinHandle<T> {
119 type Output = Result<T, Canceled>;
120
121 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
122 let this = self.get_mut();
123 let inner = &mut this.0;
124 pin_mut!(inner);
125 match inner.poll(cx) {
126 Poll::Ready(x) => Poll::Ready(x.map_err(|_: oneshot::Canceled| Canceled)),
127 Poll::Pending => Poll::Pending,
128 }
129 }
130}
131
132pub fn spawn<T: 'static>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T> {
133 let (tx, rx) = oneshot::channel();
134 SPAWNING_TASKS.with(move |tasks| {
135 (*tasks).borrow_mut().push(Box::pin(async move {
136 let _ = tx.send(fut.await);
137 }))
138 });
139 JoinHandle(rx)
140}
141
142fn start_task(tasks: &mut Tasks, task: TaskFuture) {
143 const MAX_N_TASKS: usize = (i32::MAX / 2) as _;
144
145 for (task_id, task_ref) in tasks.iter_mut().enumerate().skip(1) {
146 if task_ref.is_none() {
147 *task_ref = Some(task);
148 ocall::mark_task_ready(task_id as _).expect("Mark task ready failed");
149 return;
150 }
151 }
152
153 if tasks.len() < MAX_N_TASKS {
154 let task_id = tasks.len();
155 tasks.push(Some(task));
156 ocall::mark_task_ready(task_id as _).expect("Mark task ready failed");
157 return;
158 }
159
160 panic!("Spawn task failed, Max number of tasks reached");
161}
162
163fn start_spawned_tasks(tasks: &mut Tasks) {
164 SPAWNING_TASKS.with(|spowned_tasks| {
165 for task in spowned_tasks.borrow_mut().drain(..) {
166 start_task(tasks, task);
167 }
168 })
169}
170
171pub(crate) fn current_task() -> i32 {
172 CURRENT_TASK.with(|id| id.get())
173}
174
175fn set_current_task(task_id: i32) {
176 CURRENT_TASK.with(|id| id.set(task_id))
177}
178
179fn poll_with_guest_context<F>(f: Pin<&mut F>) -> task::Poll<F::Output>
180where
181 F: Future + ?Sized,
182{
183 fn raw_waker(task_id: i32) -> task::RawWaker {
184 task::RawWaker::new(
185 task_id as _,
186 &task::RawWakerVTable::new(
187 |data| raw_waker(data as _),
188 |data| {
189 let task_id = data as _;
190 ocall::mark_task_ready(task_id).expect("Mark task ready failed");
191 },
192 |data| {
193 let task_id = data as _;
194 ocall::mark_task_ready(task_id).expect("Mark task ready failed");
195 },
196 |_| (),
197 ),
198 )
199 }
200 let waker = unsafe { task::Waker::from_raw(raw_waker(current_task())) };
201 let mut context = task::Context::from_waker(&waker);
202 f.poll(&mut context)
203}
204
205pub fn wapo_poll() -> i32 {
206 use task::Poll::*;
207
208 fn poll() -> task::Poll<()> {
209 {
210 for waker_id in ocall::awake_wakers().expect("Failed to get awake wakers") {
211 if waker_id >= 0 {
212 wake_waker(waker_id);
213 } else {
214 drop_waker(-1 - waker_id);
215 }
216 }
217
218 let task_id = match ocall::next_ready_task() {
219 Ok(id) => id as usize,
220 Err(OcallError::NotFound) => return task::Poll::Pending,
221 Err(err) => panic!("Error occured: {:?}", err),
222 };
223 let exited = TASKS.with(|tasks| -> Option<bool> {
224 let exited = {
225 let mut tasks = tasks.borrow_mut();
226 let task = tasks.get_mut(task_id)?.as_mut()?;
227 set_current_task(task_id as _);
228 match poll_with_guest_context(task.as_mut()) {
229 Pending => (),
230 Ready(()) => {
231 tasks[task_id] = None;
232 }
233 }
234 tasks[0].is_none()
235 };
236 if !exited {
237 start_spawned_tasks(&mut tasks.borrow_mut());
238 }
239 Some(exited)
240 });
241 if let Some(true) = exited {
242 return task::Poll::Ready(());
243 }
244 }
245 task::Poll::Pending
246 }
247 match poll() {
248 Ready(()) => 1,
249 Pending => 0,
250 }
251}