1use crate::channel::{from_receiver, Builder, Receiver, Sender};
11use crate::loc::WakeMsg;
12use crate::msg::Message;
13use crate::runtime::execution::ExecutionState;
14use crate::runtime::task::TaskId;
15use crate::runtime::thread::{self, switch};
16use crate::thread::Thread;
17use crate::CommunicationModel::LocalOrder;
18use crate::TJoin;
19use std::error::Error;
20use std::fmt::{Display, Formatter};
21use std::future::Future;
22use std::pin::Pin;
23use std::result::Result;
24use std::task::{Context, Poll, Waker};
25
26impl std::task::Wake for Sender<WakeMsg> {
32 fn wake(self: std::sync::Arc<Self>) {
33 self.send_msg(WakeMsg);
34 }
35}
36
37fn get_bidir_handles() -> (TwoWayCom, TwoWayCom) {
38 let (sender1, receiver1) = Builder::new().with_comm(LocalOrder).build();
39 let (sender2, receiver2) = Builder::new().with_comm(LocalOrder).build();
40 (
42 TwoWayCom {
43 sender: sender1,
44 receiver: receiver2,
45 },
46 TwoWayCom {
47 sender: sender2,
48 receiver: receiver1,
49 },
50 )
51}
52
53pub fn spawn<T, F>(fut: F) -> JoinHandle<T>
55where
56 F: Future<Output = T> + Send + 'static,
57 T: Message + 'static,
58{
59 spawn_with_attributes::<T, F>(false, None, fut)
60}
61
62pub fn spawn_with_attributes<T, F>(is_daemon: bool, name: Option<String>, fut: F) -> JoinHandle<T>
64where
65 F: Future<Output = T> + Send + 'static,
66 T: Message + 'static,
67{
68 thread::switch();
69
70 let stack_size = ExecutionState::with(|s| s.must.borrow().config.stack_size);
71 let (fut_handles, join_handles) = get_bidir_handles();
72
73 let task_id = ExecutionState::spawn_thread(
74 move || {
75 let (sender, fut_recv) = Builder::<WakeMsg>::new().build();
76 let fut_waker = Waker::from(std::sync::Arc::new(sender.clone()));
77
78 let mut fut = Box::pin(fut);
81 let mut res = fut.as_mut().poll(&mut Context::from_waker(&fut_waker));
82 let mut join_waker: Option<Waker> = None;
83 let res = loop {
84 match res {
85 Poll::Ready(res) => {
87 if let Some(waker) = join_waker {
88 waker.wake();
89 }
90 break Some(res);
91 }
92 Poll::Pending => { }
93 }
94
95 let (msg, ind) = crate::select_val_block(&fut_handles.receiver, &fut_recv);
97
98 if ind == 0 {
100 match msg.as_any().downcast::<PollerMsg>() {
101 Ok(waker) => match *waker {
102 PollerMsg::Waker(waker) => {
103 assert!(ind == 0);
104 join_waker = Some(waker.clone());
105 fut_handles.sender.send_msg(PollerMsg::Pending);
106 }
107 PollerMsg::Cancel => break None,
108 _ => unreachable!(),
109 },
110 _ => unreachable!(),
111 }
112 } else {
113 assert!(ind == 1);
115 assert!(msg.as_any().downcast::<WakeMsg>().is_ok());
116 res = fut.as_mut().poll(&mut Context::from_waker(&fut_waker));
117 }
118 };
119
120 let val = match res {
121 Some(result) => {
123 match fut_handles.receiver.recv_msg_block() {
125 PollerMsg::Waker(_) => {
126 fut_handles.sender.send_msg(PollerMsg::Ready);
128 crate::Val::new(result)
129 }
130 PollerMsg::Cancel => {
131 drop(fut);
133 crate::Val::new(())
134 }
135 _ => unreachable!(),
136 }
137 }
138 None => {
140 drop(fut);
143 crate::Val::new(())
144 }
145 };
146
147 fut_handles.sender.send_msg(PollerMsg::Done);
149
150 ExecutionState::with(|state| {
152 let pos = state.next_pos();
153 state
154 .must
155 .borrow_mut()
156 .handle_tend(crate::End::new(pos, val));
157 crate::must::Must::unstuck_joiners(state, pos.thread);
158 });
159 },
160 stack_size,
161 None,
162 );
163
164 let (thread_id, name) = ExecutionState::with(|state| {
165 let pos = state.next_pos();
166 let tid = state.must.borrow().next_thread_id(&pos);
167 let name = match name {
168 None => format!("<future-{}>", tid.to_number()),
169 Some(x) => x,
170 };
171 state.must.borrow_mut().handle_tcreate(
173 tid,
174 task_id,
175 None, pos,
177 Some(name.clone()),
178 is_daemon,
179 );
180 (tid, Some(name))
181 });
182
183 let thread = Thread {
184 id: thread_id,
185 name,
186 };
187
188 thread::switch();
189
190 JoinHandle {
191 task_id,
192 thread,
193 com: join_handles,
194 _p: std::marker::PhantomData,
195 }
196}
197
198pub(crate) fn spawn_receive<T>(recv: &Receiver<T>) -> JoinHandle<T>
199where
200 T: Message + Clone + 'static,
201{
202 thread::switch();
203
204 let stack_size = ExecutionState::with(|s| s.must.borrow().config.stack_size);
205 let (fut_handles, join_handles) = get_bidir_handles();
206
207 let recv = recv.clone();
208 let task_id = ExecutionState::spawn_thread(
209 move || {
210 let mut join_waker: Option<Waker> = None;
211 let res = loop {
212 let (msg, ind) = crate::select_val_block(&fut_handles.receiver, &recv);
214
215 if ind == 0 {
218 match msg.as_any().downcast::<PollerMsg>() {
219 Ok(msg) => {
220 match *msg {
221 PollerMsg::Waker(waker) => {
222 join_waker = Some(waker.clone());
224 fut_handles.sender.send_msg(PollerMsg::Pending);
225 }
226 PollerMsg::Cancel => break None,
228 _ => unreachable!(),
229 }
230 }
231 _ => unreachable!(),
232 }
233 } else {
234 assert!(ind == 1);
236 match msg.as_any().downcast::<T>() {
237 Ok(result) => {
238 if let Some(waker) = join_waker {
239 waker.wake();
240 }
241 break Some(*result);
243 }
244 _ => unreachable!(),
245 }
246 }
247 };
248
249 let val = match res {
251 Some(result) => {
253 match fut_handles.receiver.recv_msg_block() {
255 PollerMsg::Waker(_) => {
257 fut_handles.sender.send_msg(PollerMsg::Ready);
258 crate::Val::new(result)
259 }
260 PollerMsg::Cancel => {
262 from_receiver(recv).send_msg(result);
263 crate::Val::new(())
264 }
265 _ => unreachable!(),
266 }
267 }
268 None => crate::Val::new(()),
270 };
271
272 fut_handles.sender.send_msg(PollerMsg::Done);
274
275 ExecutionState::with(|state| {
277 let pos = state.next_pos();
278 state
279 .must
280 .borrow_mut()
281 .handle_tend(crate::End::new(pos, val));
282 crate::must::Must::unstuck_joiners(state, pos.thread);
283 });
284 },
285 stack_size,
286 None,
287 );
288
289 let (thread_id, name) = ExecutionState::with(|state| {
290 let pos = state.next_pos();
291 let tid = state.must.borrow().next_thread_id(&pos);
292 let name = format!("<async_recv-{}>", tid.to_number());
293 state.must.borrow_mut().handle_tcreate(
294 tid,
295 task_id,
296 None, pos,
298 Some(name.clone()),
299 false, );
301 (tid, Some(name))
302 });
303
304 let thread = Thread {
305 id: thread_id,
306 name,
307 };
308
309 thread::switch();
310
311 JoinHandle {
312 task_id,
313 thread,
314 com: join_handles,
315 _p: std::marker::PhantomData,
316 }
317}
318
319#[derive(Debug)]
321pub struct JoinHandle<T> {
322 task_id: TaskId,
323 thread: Thread,
324 com: TwoWayCom,
325 _p: std::marker::PhantomData<T>,
326}
327
/// Messages exchanged between a `JoinHandle` and the thread running its task.
#[derive(Clone, Debug)]
pub enum PollerMsg {
    /// Handle -> task: a poll request carrying the poller's waker.
    Waker(Waker),
    /// Task -> handle: the result is not available yet.
    Pending,
    /// Handle -> task: stop and discard the result.
    Cancel,
    /// Task -> handle: sent right before the task's thread ends.
    Done,
    /// Task -> handle: the result is available.
    Ready,
}

/// Two messages are equal when they are the same variant; the waker carried
/// by `Waker` does not take part in the comparison.
impl PartialEq for PollerMsg {
    fn eq(&self, other: &Self) -> bool {
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}
355
356#[derive(Clone, Debug)]
358pub struct TwoWayCom {
359 pub sender: Sender<PollerMsg>,
360 pub receiver: Receiver<PollerMsg>,
361}
362
363impl<T> JoinHandle<T> {
364 pub fn is_finished(&self) -> bool {
369 ExecutionState::with(|state| {
370 let task = state.get(self.task_id);
371 task.finished()
372 })
373 }
374
375 pub fn thread(&self) -> &Thread {
377 &self.thread
378 }
379
380 pub fn abort(&self) {
382 if ExecutionState::with(|state| state.is_running()) {
400 self.com.sender.send_msg(PollerMsg::Cancel);
401 let ack = self.com.receiver.recv_msg_block();
406 assert!(matches!(ack, PollerMsg::Done));
407 }
408 }
409}
410
/// Error type reported when awaiting a `JoinHandle` fails.
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled.
    Cancelled,
}

impl Display for JoinError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            JoinError::Cancelled => "task was cancelled",
        };
        f.write_str(description)
    }
}

impl Error for JoinError {}
428
429impl<T> Drop for JoinHandle<T> {
430 fn drop(&mut self) {
431 if std::thread::panicking() {
436 return;
437 }
438 self.abort();
439 }
440}
441
442impl<T: Message + 'static> Future for JoinHandle<T> {
443 type Output = Result<T, JoinError>;
444
445 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
446 self.com
448 .sender
449 .send_msg(PollerMsg::Waker(cx.waker().clone()));
450 match self.com.receiver.recv_msg_block() {
451 PollerMsg::Ready => {
452 loop {
453 switch();
454 let val = ExecutionState::with(|s| {
455 let target_task_id = s.get(self.task_id).id();
456 let target_id = s.must.borrow().to_thread_id(target_task_id);
457 let pos = s.next_pos();
458 s.must.borrow_mut().handle_tjoin(TJoin::new(pos, target_id))
459 });
460
461 if let Some(val) = val {
463 if val.is_pending() {
464 ExecutionState::with(|s| s.current_mut().stuck());
465 } else {
466 return Poll::Ready(Ok(*val.as_any().downcast().unwrap()));
467 }
468 }
469
470 ExecutionState::with(|s| s.prev_pos());
471 }
472 }
473 PollerMsg::Pending => Poll::Pending,
474 _ => unreachable!(),
475 }
476 }
477}
478
479pub fn block_on<F: Future>(future: F) -> F::Output {
481 let mut future = Box::pin(future);
482 let (sender, receiver) = Builder::<WakeMsg>::new().build();
483 let waker = Waker::from(std::sync::Arc::new(sender.clone()));
484 let cx = &mut Context::from_waker(&waker);
485
486 thread::switch();
487
488 loop {
489 match future.as_mut().poll(cx) {
490 Poll::Ready(result) => {
491 break result;
492 }
493 Poll::Pending => {
494 receiver.recv_msg_block();
495 }
496 }
497
498 thread::switch();
499 }
500}
501
#[cfg(test)]
mod test {
    use super::block_on;
    use crate::{recv_msg_block, send_msg, thread, verify, Config};

    /// A spawned future runs on its own thread: it can be addressed by thread
    /// id, echoes a message back to its parent and finally yields its output
    /// through `block_on`.
    #[test]
    fn test_thread() {
        let config = Config::builder().build();
        verify(config, || {
            let parent_id = thread::current().id();

            // The future echoes one `i32` back to the parent, then resolves to 3.
            let fut = crate::future::spawn(async move {
                let i: i32 = recv_msg_block();
                send_msg(parent_id, i);
                3
            });

            let fut_tid = fut.thread().id();
            println!("Future's thread id is {}", fut.thread().id());

            send_msg(fut_tid, 4);
            let echoed: i32 = recv_msg_block();
            assert_eq!(echoed, 4);

            let res = block_on(fut);
            println!("Retrieved {:?} from future", &res);
            assert_eq!(res.unwrap(), 3);
        });
    }
}