sp1_core_machine/utils/
concurrency.rs1use std::{
2 collections::{hash_map::Entry, HashMap},
3 future::Future,
4 pin::Pin,
5 sync::{Arc, Condvar, Mutex},
6 task::{Context, Poll, Waker},
7};
8
/// Blocking turn-taking primitive: each thread waits until a shared counter
/// equals its assigned turn, does its work, then advances the counter.
///
/// The counter starts at `0`; within this type it only changes through
/// [`TurnBasedSync::advance_turn`], which increments it by one.
#[derive(Debug, Default)]
pub struct TurnBasedSync {
    /// The turn that is currently allowed to proceed.
    pub current_turn: Mutex<usize>,
    /// Notified (`notify_all`) every time `current_turn` is advanced.
    pub cv: Condvar,
}

impl TurnBasedSync {
    /// Creates a synchronizer whose first turn is `0`.
    ///
    /// Equivalent to [`TurnBasedSync::default`].
    pub fn new() -> Self {
        TurnBasedSync { current_turn: Mutex::new(0), cv: Condvar::new() }
    }

    /// Blocks the calling thread until the current turn equals `my_turn`.
    ///
    /// Returns immediately if it is already `my_turn`. Because turns only move
    /// forward through [`TurnBasedSync::advance_turn`], waiting for a turn that has
    /// already passed blocks indefinitely (unless something resets the public
    /// `current_turn` field).
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn wait_for_turn(&self, my_turn: usize) {
        // `wait_while` re-checks the predicate after every wakeup, so spurious
        // wakeups and notifications meant for other turns are handled.
        let _turn = self
            .cv
            .wait_while(self.current_turn.lock().unwrap(), |turn| *turn != my_turn)
            .unwrap();
    }

    /// Returns a snapshot of the current turn.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn current_turn(&self) -> usize {
        *self.current_turn.lock().unwrap()
    }

    /// Moves to the next turn and wakes every waiting thread.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn advance_turn(&self) {
        let mut turn = self.current_turn.lock().unwrap();
        *turn += 1;
        // Wake everyone: each waiter re-checks whether the new turn is its own.
        self.cv.notify_all();
    }
}
44
45pub struct AsyncTurn {
46 inner: Arc<Mutex<AsyncTurnInner>>,
47}
48
49impl AsyncTurn {
50 pub fn new() -> Self {
51 Self {
54 inner: Arc::new(Mutex::new(AsyncTurnInner { current_turn: 0, wakers: HashMap::new() })),
55 }
56 }
57
58 pub fn wait_for_turn(&self, my_turn: usize) -> AsyncTurnFuture {
59 AsyncTurnFuture { inner: self.inner.clone(), my_turn }
60 }
61}
62
63pub struct AsyncTurnInner {
65 current_turn: usize,
66 wakers: HashMap<usize, Waker>,
67}
68
69impl Clone for AsyncTurn {
70 fn clone(&self) -> Self {
71 Self { inner: Arc::clone(&self.inner) }
72 }
73}
74
75#[must_use = "Futures do nothing unless `await`ed"]
76pub struct AsyncTurnFuture {
77 inner: Arc<Mutex<AsyncTurnInner>>,
78 my_turn: usize,
79}
80
81impl Future for AsyncTurnFuture {
82 type Output = AsyncTurnGuard;
83
84 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85 let this = self.get_mut();
86 let mut inner = this.inner.lock().expect("AsyncTurnFuture poisoned");
87
88 if inner.current_turn == this.my_turn {
90 return Poll::Ready(AsyncTurnGuard { inner: this.inner.clone() });
91 }
92
93 match inner.wakers.entry(this.my_turn) {
95 Entry::Vacant(v) => {
96 v.insert(cx.waker().clone());
97 }
98 Entry::Occupied(mut o) => {
99 let _ = o.insert(cx.waker().clone());
100 }
101 }
102
103 if inner.current_turn > this.my_turn {
105 #[cold]
106 #[inline(never)]
107 fn panic_turn_passed(turn: usize) -> ! {
108 panic!("AsyncTurnFuture: turn {turn} has already passed");
109 }
110
111 panic_turn_passed(this.my_turn);
112 } else {
113 Poll::Pending
114 }
115 }
116}
117
118pub struct AsyncTurnGuard {
119 inner: Arc<Mutex<AsyncTurnInner>>,
120}
121
122impl Drop for AsyncTurnGuard {
123 fn drop(&mut self) {
124 let mut lock = self.inner.lock().expect("AsyncTurnGuard poisoned");
125
126 lock.current_turn += 1;
128
129 if let Some(waker) = lock.wakers.get(&lock.current_turn) {
131 waker.wake_by_ref();
132 }
133 }
134}