Skip to main content

traceforge/future/
mod.rs

//! Shuttle's implementation of an async executor, roughly equivalent to [`futures::executor`].
//!
//! The [spawn] function spawns a new asynchronous task that the executor will run to completion.
//! The [block_on] function blocks the current thread until a future completes.
//!
//! This module was copied from Shuttle into the Must runtime so that Must can handle async calls.
//!
//! [`futures::executor`]: https://docs.rs/futures/0.3.30/futures/executor/index.html
9
10use crate::channel::{from_receiver, Builder, Receiver, Sender};
11use crate::loc::WakeMsg;
12use crate::msg::Message;
13use crate::runtime::execution::ExecutionState;
14use crate::runtime::task::TaskId;
15use crate::runtime::thread::{self, switch};
16use crate::thread::Thread;
17use crate::CommunicationModel::LocalOrder;
18use crate::TJoin;
19use std::error::Error;
20use std::fmt::{Display, Formatter};
21use std::future::Future;
22use std::pin::Pin;
23use std::result::Result;
24use std::task::{Context, Poll, Waker};
25
26// Not really unsafe, we're not doing any concurrency.
27// This is needed for `Waker::from`
28// unsafe impl Sync for Sender<()> {}
29
30// The value is irrelevant, we're using a Channel<()> as a waker.
31impl std::task::Wake for Sender<WakeMsg> {
32    fn wake(self: std::sync::Arc<Self>) {
33        self.send_msg(WakeMsg);
34    }
35}
36
37fn get_bidir_handles() -> (TwoWayCom, TwoWayCom) {
38    let (sender1, receiver1) = Builder::new().with_comm(LocalOrder).build();
39    let (sender2, receiver2) = Builder::new().with_comm(LocalOrder).build();
40    // *flip* them
41    (
42        TwoWayCom {
43            sender: sender1,
44            receiver: receiver2,
45        },
46        TwoWayCom {
47            sender: sender2,
48            receiver: receiver1,
49        },
50    )
51}
52
53/// Spawn a new async task that the executor will run to completion.
54pub fn spawn<T, F>(fut: F) -> JoinHandle<T>
55where
56    F: Future<Output = T> + Send + 'static,
57    T: Message + 'static,
58{
59    spawn_with_attributes::<T, F>(false, None, fut)
60}
61
62/// Spawn a new async task that the executor will run to completion.
63pub fn spawn_with_attributes<T, F>(is_daemon: bool, name: Option<String>, fut: F) -> JoinHandle<T>
64where
65    F: Future<Output = T> + Send + 'static,
66    T: Message + 'static,
67{
68    thread::switch();
69
70    let stack_size = ExecutionState::with(|s| s.must.borrow().config.stack_size);
71    let (fut_handles, join_handles) = get_bidir_handles();
72
73    let task_id = ExecutionState::spawn_thread(
74        move || {
75            let (sender, fut_recv) = Builder::<WakeMsg>::new().build();
76            let fut_waker = Waker::from(std::sync::Arc::new(sender.clone()));
77
78            // Poll once in advance:
79            // tokio's spawn semantics: the future will start running immediately.
80            let mut fut = Box::pin(fut);
81            let mut res = fut.as_mut().poll(&mut Context::from_waker(&fut_waker));
82            let mut join_waker: Option<Waker> = None;
83            let res = loop {
84                match res {
85                    // We're done, call the waker
86                    Poll::Ready(res) => {
87                        if let Some(waker) = join_waker {
88                            waker.wake();
89                        }
90                        break Some(res);
91                    }
92                    Poll::Pending => { /* keep going */ }
93                }
94
95                // Wait for either the joiner or the future, to poll or inform us, respectively
96                let (msg, ind) = crate::select_val_block(&fut_handles.receiver, &fut_recv);
97
98                // Joiner polled us, inform them it's pending
99                if ind == 0 {
100                    match msg.as_any().downcast::<PollerMsg>() {
101                        Ok(waker) => match *waker {
102                            PollerMsg::Waker(waker) => {
103                                assert!(ind == 0);
104                                join_waker = Some(waker.clone());
105                                fut_handles.sender.send_msg(PollerMsg::Pending);
106                            }
107                            PollerMsg::Cancel => break None,
108                            _ => unreachable!(),
109                        },
110                        _ => unreachable!(),
111                    }
112                } else {
113                    // Futured informed us to poll again
114                    assert!(ind == 1);
115                    assert!(msg.as_any().downcast::<WakeMsg>().is_ok());
116                    res = fut.as_mut().poll(&mut Context::from_waker(&fut_waker));
117                }
118            };
119
120            let val = match res {
121                // We consumed the future to completion
122                Some(result) => {
123                    // Wait for a final request
124                    match fut_handles.receiver.recv_msg_block() {
125                        PollerMsg::Waker(_) => {
126                            // Inform them it's ready, they can try to Join
127                            fut_handles.sender.send_msg(PollerMsg::Ready);
128                            crate::Val::new(result)
129                        }
130                        PollerMsg::Cancel => {
131                            // Explicitly drop the future here to cancel it
132                            drop(fut);
133                            crate::Val::new(())
134                        }
135                        _ => unreachable!(),
136                    }
137                }
138                // We were cancelled
139                None => {
140                    // Explicitly drop the future here as well
141                    // (it may have cancellation code to run?)
142                    drop(fut);
143                    crate::Val::new(())
144                }
145            };
146
147            // Final Message, useful for impl of Drop on JoinHandle
148            fut_handles.sender.send_msg(PollerMsg::Done);
149
150            // Properly End the thread
151            ExecutionState::with(|state| {
152                let pos = state.next_pos();
153                state
154                    .must
155                    .borrow_mut()
156                    .handle_tend(crate::End::new(pos, val));
157                crate::must::Must::unstuck_joiners(state, pos.thread);
158            });
159        },
160        stack_size,
161        None,
162    );
163
164    let (thread_id, name) = ExecutionState::with(|state| {
165        let pos = state.next_pos();
166        let tid = state.must.borrow().next_thread_id(&pos);
167        let name = match name {
168            None => format!("<future-{}>", tid.to_number()),
169            Some(x) => x,
170        };
171        //let name = format!("<future-{}>", tid.to_number());
172        state.must.borrow_mut().handle_tcreate(
173            tid,
174            task_id,
175            None, /* asyncs do not have symmetric versions for symm reduction */
176            pos,
177            Some(name.clone()),
178            is_daemon,
179        );
180        (tid, Some(name))
181    });
182
183    let thread = Thread {
184        id: thread_id,
185        name,
186    };
187
188    thread::switch();
189
190    JoinHandle {
191        task_id,
192        thread,
193        com: join_handles,
194        _p: std::marker::PhantomData,
195    }
196}
197
198pub(crate) fn spawn_receive<T>(recv: &Receiver<T>) -> JoinHandle<T>
199where
200    T: Message + Clone + 'static,
201{
202    thread::switch();
203
204    let stack_size = ExecutionState::with(|s| s.must.borrow().config.stack_size);
205    let (fut_handles, join_handles) = get_bidir_handles();
206
207    let recv = recv.clone();
208    let task_id = ExecutionState::spawn_thread(
209        move || {
210            let mut join_waker: Option<Waker> = None;
211            let res = loop {
212                // Wait for either the joiner to poll us, or the receive to succeed.
213                let (msg, ind) = crate::select_val_block(&fut_handles.receiver, &recv);
214
215                // TODO: Use `cast!` to avoid all the `unreachable!` mess.
216                // Joiner polled us
217                if ind == 0 {
218                    match msg.as_any().downcast::<PollerMsg>() {
219                        Ok(msg) => {
220                            match *msg {
221                                PollerMsg::Waker(waker) => {
222                                    // Save the waker and inform them it's Pending
223                                    join_waker = Some(waker.clone());
224                                    fut_handles.sender.send_msg(PollerMsg::Pending);
225                                }
226                                // We're cancelled, without having consumed anything
227                                PollerMsg::Cancel => break None,
228                                _ => unreachable!(),
229                            }
230                        }
231                        _ => unreachable!(),
232                    }
233                } else {
234                    // We did the receive, call the waker.
235                    assert!(ind == 1);
236                    match msg.as_any().downcast::<T>() {
237                        Ok(result) => {
238                            if let Some(waker) = join_waker {
239                                waker.wake();
240                            }
241                            // We consumed the message
242                            break Some(*result);
243                        }
244                        _ => unreachable!(),
245                    }
246                }
247            };
248
249            // Select is done, either wait for the request or cancel the receive
250            let val = match res {
251                // We consumed the message
252                Some(result) => {
253                    // Wait once more for the poller
254                    match fut_handles.receiver.recv_msg_block() {
255                        // Inform them it's ready, they can try to Join
256                        PollerMsg::Waker(_) => {
257                            fut_handles.sender.send_msg(PollerMsg::Ready);
258                            crate::Val::new(result)
259                        }
260                        // Cancelled, let's put the message back
261                        PollerMsg::Cancel => {
262                            from_receiver(recv).send_msg(result);
263                            crate::Val::new(())
264                        }
265                        _ => unreachable!(),
266                    }
267                }
268                // We got cancelled without consuming the message: nothing to do
269                None => crate::Val::new(()),
270            };
271
272            // Final Message, useful for impl of Drop on JoinHandle
273            fut_handles.sender.send_msg(PollerMsg::Done);
274
275            // Properly End the thread
276            ExecutionState::with(|state| {
277                let pos = state.next_pos();
278                state
279                    .must
280                    .borrow_mut()
281                    .handle_tend(crate::End::new(pos, val));
282                crate::must::Must::unstuck_joiners(state, pos.thread);
283            });
284        },
285        stack_size,
286        None,
287    );
288
289    let (thread_id, name) = ExecutionState::with(|state| {
290        let pos = state.next_pos();
291        let tid = state.must.borrow().next_thread_id(&pos);
292        let name = format!("<async_recv-{}>", tid.to_number());
293        state.must.borrow_mut().handle_tcreate(
294            tid,
295            task_id,
296            None, /* asyncs do not have symmetric versions for symm reduction */
297            pos,
298            Some(name.clone()),
299            false, /* asyncs are not daemon threads */
300        );
301        (tid, Some(name))
302    });
303
304    let thread = Thread {
305        id: thread_id,
306        name,
307    };
308
309    thread::switch();
310
311    JoinHandle {
312        task_id,
313        thread,
314        com: join_handles,
315        _p: std::marker::PhantomData,
316    }
317}
318
319/// An owned permission to join on an async task (await its termination).
320#[derive(Debug)]
321pub struct JoinHandle<T> {
322    task_id: TaskId,
323    thread: Thread,
324    com: TwoWayCom,
325    _p: std::marker::PhantomData<T>,
326}
327
/// Messages exchanged between a `JoinHandle` and the thread driving its task.
#[derive(Clone, Debug)]
pub enum PollerMsg {
    /// Joiner to task: "I am polling; wake me with this once you complete."
    Waker(Waker),
    /// Task to joiner: the task has not completed yet.
    Pending,
    /// Joiner to task: abandon the task.
    Cancel,
    /// Task to joiner: the last message the task sends before its thread ends.
    Done,
    /// Task to joiner: the result is available and can be joined.
    Ready,
}
336
337// PollerMsg must satisfy Message in order to be sent around with channels.
338// Message includes PartialEq, which the foreign type Waker does not implement.
339// This implementation acts as if all Wakers are equal.
340// The PartialEq on Message is (only) used for validating during replay,
341// which is needed for the nondeterminism detector.
342impl PartialEq for PollerMsg {
343    fn eq(&self, other: &Self) -> bool {
344        match (self, other) {
345            // Wakers are equal
346            (PollerMsg::Waker(_), PollerMsg::Waker(_)) => true,
347            (PollerMsg::Pending, PollerMsg::Pending) => true,
348            (PollerMsg::Cancel, PollerMsg::Cancel) => true,
349            (PollerMsg::Ready, PollerMsg::Ready) => true,
350            (PollerMsg::Done, PollerMsg::Done) => true,
351            _ => false,
352        }
353    }
354}
355
356// Helper for two-way communication between JoinHandle and Poller
357#[derive(Clone, Debug)]
358pub struct TwoWayCom {
359    pub sender: Sender<PollerMsg>,
360    pub receiver: Receiver<PollerMsg>,
361}
362
363impl<T> JoinHandle<T> {
364    /// Returns `true` if this task is finished, otherwise returns `false`.
365    ///
366    /// ## Panics
367    /// Panics if called outside of shuttle context, i.e. if there is no execution context.
368    pub fn is_finished(&self) -> bool {
369        ExecutionState::with(|state| {
370            let task = state.get(self.task_id);
371            task.finished()
372        })
373    }
374
375    /// Extracts a handle to the underlying thread.
376    pub fn thread(&self) -> &Thread {
377        &self.thread
378    }
379
380    // Useful for the Drop implementation
381    pub fn abort(&self) {
382        // If a Join Handle for a spawned task is never awaited, one could abort the task by calling `self.abort()`
383        // But this may mean that certain side effects (message sends or receives) of the task
384        // do not get to run.
385        // Must currently does not perform a backtrack on aborted tasks. So it is conservative to
386        // not abort tasks that are not awaited.
387
388        // FIXME: Is the above comment relevant?
389
390        // Tricky: it could be that the handle was dropped because there was nothing else to do,
391        // i.e. there is no ScheduledTask.
392        // In that case, we cannot send a message (it would panic).
393        // We only lose the case where:
394        // there is someone that was waiting to read from a receive that was
395        // consumed through the underlying future but was not used,
396        // and had we actually cancelled the future this would now be able to run.
397        // If there are no concurrent receives, this shouldn't happen.
398        // TODO: Detect and handle this scenario?
399        if ExecutionState::with(|state| state.is_running()) {
400            self.com.sender.send_msg(PollerMsg::Cancel);
401            // We wait for the Future to actually finish,
402            // whether it was actually cancelled or not.
403            // This is necessary so that we "synchronize" in a porf-sense
404            // and the Future's receives are no longer concurrent.
405            let ack = self.com.receiver.recv_msg_block();
406            assert!(matches!(ack, PollerMsg::Done));
407        }
408    }
409}
410
// TODO: need to work out all the error cases here
/// Task failed to execute to completion.
///
/// Note: nothing in this module constructs it yet; awaiting a `JoinHandle`
/// currently always yields `Ok`.
#[derive(Debug)]
pub enum JoinError {
    /// Task was aborted
    Cancelled,
}

impl Display for JoinError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            JoinError::Cancelled => "task was cancelled",
        };
        f.write_str(text)
    }
}

impl Error for JoinError {}
428
429impl<T> Drop for JoinHandle<T> {
430    fn drop(&mut self) {
431        // Skip during panic unwinding. raw_cancel() triggers a Cancel-panic
432        // inside generators to unwind their stacks; calling Traceforge API
433        // (send_msg, recv_msg_block) during that unwinding would cause a
434        // nested panic and abort.
435        if std::thread::panicking() {
436            return;
437        }
438        self.abort();
439    }
440}
441
442impl<T: Message + 'static> Future for JoinHandle<T> {
443    type Output = Result<T, JoinError>;
444
445    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
446        // Ask poller
447        self.com
448            .sender
449            .send_msg(PollerMsg::Waker(cx.waker().clone()));
450        match self.com.receiver.recv_msg_block() {
451            PollerMsg::Ready => {
452                loop {
453                    switch();
454                    let val = ExecutionState::with(|s| {
455                        let target_task_id = s.get(self.task_id).id();
456                        let target_id = s.must.borrow().to_thread_id(target_task_id);
457                        let pos = s.next_pos();
458                        s.must.borrow_mut().handle_tjoin(TJoin::new(pos, target_id))
459                    });
460
461                    // Wait for the thread to *actually* finish
462                    if let Some(val) = val {
463                        if val.is_pending() {
464                            ExecutionState::with(|s| s.current_mut().stuck());
465                        } else {
466                            return Poll::Ready(Ok(*val.as_any().downcast().unwrap()));
467                        }
468                    }
469
470                    ExecutionState::with(|s| s.prev_pos());
471                }
472            }
473            PollerMsg::Pending => Poll::Pending,
474            _ => unreachable!(),
475        }
476    }
477}
478
479/// Run a future to completion on the current thread.
480pub fn block_on<F: Future>(future: F) -> F::Output {
481    let mut future = Box::pin(future);
482    let (sender, receiver) = Builder::<WakeMsg>::new().build();
483    let waker = Waker::from(std::sync::Arc::new(sender.clone()));
484    let cx = &mut Context::from_waker(&waker);
485
486    thread::switch();
487
488    loop {
489        match future.as_mut().poll(cx) {
490            Poll::Ready(result) => {
491                break result;
492            }
493            Poll::Pending => {
494                receiver.recv_msg_block();
495            }
496        }
497
498        thread::switch();
499    }
500}
501
#[cfg(test)]
mod test {
    use super::block_on;
    use crate::{recv_msg_block, send_msg, thread, verify, Config};

    /// Spawns a future that echoes one `i32` back to the spawning thread and
    /// then returns 3; checks both the echoed value and the joined result.
    #[test]
    fn test_thread() {
        let config = Config::builder().build();
        verify(config, || {
            let parent = thread::current().id();

            let handle = crate::future::spawn(async move {
                let received: i32 = recv_msg_block();
                // Echo the value back to the parent, then return 3.
                send_msg(parent, received);
                3
            });

            let child = handle.thread().id();
            println!("Future's thread id is {}", handle.thread().id());

            send_msg(child, 4);
            let echoed: i32 = recv_msg_block();
            assert_eq!(echoed, 4);

            let joined = block_on(handle);
            println!("Retrieved {:?} from future", &joined);
            assert_eq!(joined.unwrap(), 3);
        });
    }
}