Skip to main content

sp1_core_machine/utils/
concurrency.rs

1use std::{
2    collections::{hash_map::Entry, HashMap},
3    future::Future,
4    pin::Pin,
5    sync::{Arc, Condvar, Mutex},
6    task::{Context, Poll, Waker},
7};
8
/// A turn-based synchronization primitive.
///
/// Threads call [`TurnBasedSync::wait_for_turn`] to block until the shared counter reaches
/// their turn and [`TurnBasedSync::advance_turn`] to hand over to the next turn.
pub struct TurnBasedSync {
    /// The turn that is currently allowed to proceed.
    pub current_turn: Mutex<usize>,
    /// Condition variable notified every time the turn is advanced.
    pub cv: Condvar,
}

impl Default for TurnBasedSync {
    /// Equivalent to [`TurnBasedSync::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl TurnBasedSync {
    /// Creates a new [TurnBasedSync] whose current turn is `0`.
    pub fn new() -> Self {
        Self { current_turn: Mutex::new(0), cv: Condvar::new() }
    }

    /// Blocks the calling thread until the current turn is equal to `my_turn`.
    ///
    /// The wait only ends on equality. NOTE(review): if the counter has already moved past
    /// `my_turn` and nothing ever resets it, this call never returns.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding the turn counter is poisoned.
    pub fn wait_for_turn(&self, my_turn: usize) {
        let guard = self.current_turn.lock().unwrap();
        // `wait_while` re-evaluates the predicate after every wakeup, so spurious wakeups and
        // notifications meant for other turns simply put the thread back to sleep.
        drop(self.cv.wait_while(guard, |turn| *turn != my_turn).unwrap());
    }

    /// Gets the current turn
    ///
    /// # WARNING
    /// Note that relying on this value can cause race conditions: the lock is released before
    /// the value is returned, so the turn may already have changed when the caller reads it.
    pub fn current_turn(&self) -> usize {
        *self.current_turn.lock().unwrap()
    }

    /// Advances the current turn by one and wakes every thread blocked in
    /// [`TurnBasedSync::wait_for_turn`] so each can re-check its own turn.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding the turn counter is poisoned.
    pub fn advance_turn(&self) {
        let mut turn = self.current_turn.lock().unwrap();
        *turn += 1;
        self.cv.notify_all();
    }
}
44
45pub struct AsyncTurn {
46    inner: Arc<Mutex<AsyncTurnInner>>,
47}
48
49impl AsyncTurn {
50    pub fn new() -> Self {
51        // Note: We could define some preconditions here and use unsafe + atomic counter here, but
52        // this works fine for now...
53        Self {
54            inner: Arc::new(Mutex::new(AsyncTurnInner { current_turn: 0, wakers: HashMap::new() })),
55        }
56    }
57
58    pub fn wait_for_turn(&self, my_turn: usize) -> AsyncTurnFuture {
59        AsyncTurnFuture { inner: self.inner.clone(), my_turn }
60    }
61}
62
/// The inner state of the [AsyncTurn] primitive.
///
/// Both fields are private; within this file the state is only reached through a shared
/// `Arc<Mutex<_>>`.
pub struct AsyncTurnInner {
    /// The turn that is currently allowed to proceed.
    current_turn: usize,
    /// Wakers of pending futures, keyed by the turn each one is waiting for.
    wakers: HashMap<usize, Waker>,
}
68
69impl Clone for AsyncTurn {
70    fn clone(&self) -> Self {
71        Self { inner: Arc::clone(&self.inner) }
72    }
73}
74
75#[must_use = "Futures do nothing unless `await`ed"]
76pub struct AsyncTurnFuture {
77    inner: Arc<Mutex<AsyncTurnInner>>,
78    my_turn: usize,
79}
80
81impl Future for AsyncTurnFuture {
82    type Output = AsyncTurnGuard;
83
84    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85        let this = self.get_mut();
86        let mut inner = this.inner.lock().expect("AsyncTurnFuture poisoned");
87
88        // Fast path: if the current turn is equal to the given turn, we can return immediately.
89        if inner.current_turn == this.my_turn {
90            return Poll::Ready(AsyncTurnGuard { inner: this.inner.clone() });
91        }
92
93        // Normal path: We need to wait for `this.my_turn` to be reached.
94        match inner.wakers.entry(this.my_turn) {
95            Entry::Vacant(v) => {
96                v.insert(cx.waker().clone());
97            }
98            Entry::Occupied(mut o) => {
99                let _ = o.insert(cx.waker().clone());
100            }
101        }
102
103        // Ensure our turn has not passed.
104        if inner.current_turn > this.my_turn {
105            #[cold]
106            #[inline(never)]
107            fn panic_turn_passed(turn: usize) -> ! {
108                panic!("AsyncTurnFuture: turn {turn} has already passed");
109            }
110
111            panic_turn_passed(this.my_turn);
112        } else {
113            Poll::Pending
114        }
115    }
116}
117
118pub struct AsyncTurnGuard {
119    inner: Arc<Mutex<AsyncTurnInner>>,
120}
121
122impl Drop for AsyncTurnGuard {
123    fn drop(&mut self) {
124        let mut lock = self.inner.lock().expect("AsyncTurnGuard poisoned");
125
126        // Advance the turn.
127        lock.current_turn += 1;
128
129        // Notify the waker.
130        if let Some(waker) = lock.wakers.get(&lock.current_turn) {
131            waker.wake_by_ref();
132        }
133    }
134}