1use crate::clock::Clock;
2use crate::futures::{Interval, SimSleep, SimYield};
3use crate::rng::Rng;
4use crate::scheduler::{Scheduler, SimWaker, TaskFuture, WakeQueue};
5use crate::trace::{TraceBuffer, TraceConfig, TraceEvent};
6use std::collections::{HashSet, VecDeque};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, Mutex};
10use std::task::{Poll, Waker};
11use std::time::{Duration, Instant};
12
13pub use crate::scheduler::TaskHandle;
14
/// Seeded simulation executor driven by a virtual clock.
///
/// Every component lives behind an `Arc`, so `clone` and `new_for_engine` can hand out
/// additional handles onto the same simulation state.
pub struct Reactor {
    /// Scheduler state: task table, pending spawns, ready queue, replay decisions,
    /// decision log and step counter.
    scheduler: Arc<Mutex<Scheduler>>,
    /// Queue handed to every `SimWaker`; its contents are drained into the ready queue at
    /// the start of each scheduling step.
    wake_queue: Arc<WakeQueue>,
    /// Ids of aborted tasks. Shared with every `TaskHandle`; the scheduler skips these ids.
    aborted: Arc<Mutex<HashSet<usize>>>,
    /// Virtual clock backing `sleep`, intervals and timer advancement. This is the one
    /// component `new_for_engine` does NOT share.
    clock: Arc<Mutex<Clock>>,
    /// Seeded RNG used to shuffle the ready queue; also exposed through `next_u64`.
    rng: Arc<Mutex<Rng>>,
    /// Trace buffer receiving `TaskPolled` and `TimerFired` events.
    trace: Arc<Mutex<TraceBuffer>>,
}
23
24impl Reactor {
25 pub fn new(seed: u64) -> Self {
26 Self::new_with_trace(seed, TraceConfig::default())
27 }
28
29 pub fn new_with_trace(seed: u64, trace_config: TraceConfig) -> Self {
30 Self {
31 scheduler: Arc::new(Mutex::new(Scheduler::new())),
32 wake_queue: WakeQueue::new(),
33 aborted: Arc::new(Mutex::new(HashSet::new())),
34 clock: Arc::new(Mutex::new(Clock::new())),
35 rng: Arc::new(Mutex::new(Rng::new(seed))),
36 trace: Arc::new(Mutex::new(TraceBuffer::new(trace_config))),
37 }
38 }
39
40 pub fn new_for_engine(&self) -> Self {
41 Self {
42 scheduler: Arc::clone(&self.scheduler),
43 wake_queue: Arc::clone(&self.wake_queue),
44 aborted: Arc::clone(&self.aborted),
45 clock: Arc::new(Mutex::new(Clock::new())),
46 rng: Arc::clone(&self.rng),
47 trace: Arc::clone(&self.trace),
48 }
49 }
50
51 pub fn clock_ref(&self) -> &Arc<Mutex<Clock>> {
52 &self.clock
53 }
54
55 pub fn spawn_local_task(
56 &self,
57 fut: Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
58 ) -> TaskHandle {
59 let mut sched = self.scheduler.lock().expect("scheduler");
60 let id = sched.next_id;
61 sched.next_id += 1;
62 sched.pending_new.push((id, fut));
63 TaskHandle {
64 task_id: id,
65 aborted: Arc::clone(&self.aborted),
66 }
67 }
68
69 pub fn yield_now(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>> {
70 Box::pin(SimYield { yielded: false })
71 }
72
73 pub fn sleep(
74 &self,
75 duration: Duration,
76 ) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>> {
77 Box::pin(SimSleep {
78 duration,
79 deadline: None,
80 clock: Arc::clone(&self.clock),
81 })
82 }
83
84 pub fn create_interval(&self, duration: Duration) -> Interval {
85 let next_deadline = self.clock.lock().expect("sim clock").now + duration;
86 Interval::new(duration, next_deadline, Arc::clone(&self.clock))
87 }
88
89 pub fn run_until_idle(&self) {
90 loop {
91 let next = {
92 let mut sched = self.scheduler.lock().expect("scheduler");
93 self.drain_pending_and_wakes(&mut sched);
94 self.shuffle_ready_queue(&mut sched);
95 let found = self.select_next_task(&mut sched);
96 if let Some((task_id, _)) = &found {
97 sched.decision_log.push(*task_id);
98 }
99 found
100 };
101
102 let (task_id, mut future) = match next {
103 Some(pair) => pair,
104 None => break,
105 };
106
107 let waker: Waker = Waker::from(Arc::new(SimWaker {
108 task_id,
109 queue: Arc::clone(&self.wake_queue),
110 }));
111 let mut cx = std::task::Context::from_waker(&waker);
112
113 match future.as_mut().poll(&mut cx) {
114 Poll::Ready(()) => {}
115 Poll::Pending => {
116 self.scheduler
117 .lock()
118 .expect("scheduler")
119 .tasks
120 .insert(task_id, future);
121 }
122 }
123
124 self.scheduler.lock().expect("scheduler").step_count += 1;
125 self.record_task_polled(task_id);
126 }
127 }
128
129 pub fn advance_time(&self, d: Duration) {
130 let wakers = self.clock.lock().expect("sim clock").advance(d);
131 let waker_count = wakers.len();
132 for w in wakers {
133 w.wake();
134 }
135 if waker_count > 0 {
136 let deadline_ms = self.clock.lock().expect("sim clock").now.as_millis() as u64;
137 self.trace
138 .lock()
139 .expect("trace")
140 .record(TraceEvent::TimerFired {
141 deadline_ms,
142 count: waker_count,
143 });
144 }
145 self.run_until_idle();
146 }
147
148 pub fn run_to_completion(&self) {
149 const MAX_ADVANCES: usize = 10_000;
150 for _ in 0..MAX_ADVANCES {
151 self.run_until_idle();
152 let wakers = self
153 .clock
154 .lock()
155 .expect("sim clock")
156 .advance_to_next_timer();
157 match wakers {
158 None => break,
159 Some(wakers) if wakers.is_empty() => break,
160 Some(wakers) => {
161 let waker_count = wakers.len();
162 let deadline_ms = self.clock.lock().expect("sim clock").now.as_millis() as u64;
163 self.trace
164 .lock()
165 .expect("trace")
166 .record(TraceEvent::TimerFired {
167 deadline_ms,
168 count: waker_count,
169 });
170 for w in wakers {
171 w.wake();
172 }
173 }
174 }
175 }
176 self.run_until_idle();
177 }
178
179 pub fn step_count(&self) -> u64 {
180 self.scheduler.lock().expect("scheduler").step_count
181 }
182
183 pub fn seed(&self) -> u64 {
184 self.rng.lock().expect("rng").seed()
185 }
186
187 pub fn virtual_elapsed(&self) -> Duration {
188 self.clock.lock().expect("sim clock").now
189 }
190
191 pub fn now(&self) -> Instant {
192 self.clock.lock().expect("sim clock").now_as_instant()
193 }
194
195 pub fn elapsed_since(&self, _start: Instant) -> u64 {
196 const MOCK_COMPUTE_NS: u64 = 1000;
197 let mut clock = self.clock.lock().expect("sim clock");
198 let wakers = clock.advance(Duration::from_nanos(MOCK_COMPUTE_NS));
199 drop(clock);
200 for w in wakers {
201 w.wake();
202 }
203 MOCK_COMPUTE_NS
204 }
205
206 pub fn next_u64(&self) -> u64 {
207 self.rng.lock().expect("rng").next_u64()
208 }
209
210 pub fn decision_log(&self) -> Vec<usize> {
211 self.scheduler
212 .lock()
213 .expect("scheduler")
214 .decision_log
215 .clone()
216 }
217
218 pub fn set_replay_decisions(&self, decisions: Vec<usize>) {
219 self.scheduler.lock().expect("scheduler").replay_decisions = VecDeque::from(decisions);
220 }
221
222 pub fn clear_decision_log(&self) {
223 self.scheduler
224 .lock()
225 .expect("scheduler")
226 .decision_log
227 .clear();
228 }
229
230 pub fn trace_snapshot(&self) -> TraceBuffer {
231 self.trace.lock().expect("trace").clone()
232 }
233
234 fn drain_pending_and_wakes(&self, sched: &mut Scheduler) {
235 let pending = std::mem::take(&mut sched.pending_new);
236 for (id, fut) in pending {
237 sched.tasks.insert(id, fut);
238 sched.ready.push_back(id);
239 }
240
241 let mut wq = self.wake_queue.0.lock().expect("wake queue");
242 sched.ready.extend(wq.drain(..));
243 }
244
245 fn shuffle_ready_queue(&self, sched: &mut Scheduler) {
246 if sched.ready.len() > 1 {
247 let mut rng = self.rng.lock().expect("rng");
248 let mut ready_vec: Vec<usize> = sched.ready.drain(..).collect();
249 rng.shuffle(&mut ready_vec);
250 sched.ready.extend(ready_vec);
251 }
252 }
253
254 fn select_next_task(&self, sched: &mut Scheduler) -> Option<(usize, TaskFuture)> {
255 let preferred_id = if !sched.replay_decisions.is_empty() {
256 sched.replay_decisions.pop_front()
257 } else {
258 None
259 };
260
261 let aborted = self.aborted.lock().expect("aborted set");
262 let mut found = None;
263
264 if let Some(pid) = preferred_id {
265 if let Some(pos) = sched.ready.iter().position(|&id| id == pid) {
266 let id = sched.ready.remove(pos).expect("ready task");
267 if !aborted.contains(&id) {
268 if let Some(fut) = sched.tasks.remove(&id) {
269 found = Some((id, fut));
270 }
271 }
272 }
273 }
274
275 if found.is_none() {
276 while let Some(task_id) = sched.ready.pop_front() {
277 if aborted.contains(&task_id) {
278 continue;
279 }
280 if let Some(fut) = sched.tasks.remove(&task_id) {
281 found = Some((task_id, fut));
282 break;
283 }
284 }
285 }
286
287 found
288 }
289
290 fn record_task_polled(&self, task_id: usize) {
291 let vt_ms = self.clock.lock().expect("sim clock").now.as_millis() as u64;
292 self.trace
293 .lock()
294 .expect("trace")
295 .record(TraceEvent::TaskPolled {
296 task_id,
297 virtual_time_ms: vt_ms,
298 });
299 }
300}
301
302impl Clone for Reactor {
303 fn clone(&self) -> Self {
304 Self {
305 scheduler: Arc::clone(&self.scheduler),
306 wake_queue: Arc::clone(&self.wake_queue),
307 aborted: Arc::clone(&self.aborted),
308 clock: Arc::clone(&self.clock),
309 rng: Arc::clone(&self.rng),
310 trace: Arc::clone(&self.trace),
311 }
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::Reactor;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// A freshly spawned task runs exactly once and accounts for exactly one step.
    #[test]
    fn spawned_tasks_run_to_idle() {
        let reactor = Reactor::new(7);
        let hits = Arc::new(AtomicUsize::new(0));

        let task_hits = Arc::clone(&hits);
        let task = async move {
            task_hits.fetch_add(1, Ordering::Relaxed);
        };
        reactor.spawn_local_task(Box::pin(task));
        reactor.run_until_idle();

        let observed = hits.load(Ordering::Relaxed);
        assert_eq!(observed, 1);
        assert_eq!(reactor.step_count(), 1);
    }

    /// A sleeping task stays parked until virtual time reaches its deadline.
    #[test]
    fn timers_advance_and_wake_sleepers() {
        let reactor = Reactor::new(9);
        let hits = Arc::new(AtomicUsize::new(0));

        let task_hits = Arc::clone(&hits);
        let handle = reactor.clone();
        let task = async move {
            handle.sleep(Duration::from_millis(5)).await;
            task_hits.fetch_add(1, Ordering::Relaxed);
        };
        reactor.spawn_local_task(Box::pin(task));

        // Before the deadline: the task has started but is still asleep.
        reactor.run_until_idle();
        assert_eq!(hits.load(Ordering::Relaxed), 0);

        // Reaching the deadline wakes and completes it.
        reactor.advance_time(Duration::from_millis(5));
        assert_eq!(hits.load(Ordering::Relaxed), 1);
    }
}