Skip to main content

strontium_core/
executor.rs

1use crate::clock::Clock;
2use crate::futures::{Interval, SimSleep, SimYield};
3use crate::rng::Rng;
4use crate::scheduler::{Scheduler, SimWaker, TaskFuture, WakeQueue};
5use crate::trace::{TraceBuffer, TraceConfig, TraceEvent};
6use std::collections::{HashSet, VecDeque};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, Mutex};
10use std::task::{Poll, Waker};
11use std::time::{Duration, Instant};
12
13pub use crate::scheduler::TaskHandle;
14
/// Simulation executor state: tasks are polled in an order shuffled by a
/// seeded RNG and are driven against a virtual clock instead of wall time.
///
/// Every field is an `Arc`, so cloning a `Reactor` produces another handle to
/// the same underlying state.
pub struct Reactor {
    /// Task table, ready queue, not-yet-admitted spawns, replay decisions,
    /// decision log and poll counter.
    scheduler: Arc<Mutex<Scheduler>>,
    /// Shared with every `SimWaker` handed to a polled task; its contents are
    /// moved onto the scheduler's ready queue at each scheduling step.
    wake_queue: Arc<WakeQueue>,
    /// Ids of tasks that must not be polled again; shared with every
    /// `TaskHandle` returned by `spawn_local_task`.
    aborted: Arc<Mutex<HashSet<usize>>>,
    /// Virtual clock backing `sleep`, `create_interval` and the time-advance
    /// methods.
    clock: Arc<Mutex<Clock>>,
    /// Seeded RNG used to shuffle the ready queue and exposed via `next_u64`.
    rng: Arc<Mutex<Rng>>,
    /// Buffer receiving `TaskPolled` / `TimerFired` trace events.
    trace: Arc<Mutex<TraceBuffer>>,
}
23
24impl Reactor {
25    pub fn new(seed: u64) -> Self {
26        Self::new_with_trace(seed, TraceConfig::default())
27    }
28
29    pub fn new_with_trace(seed: u64, trace_config: TraceConfig) -> Self {
30        Self {
31            scheduler: Arc::new(Mutex::new(Scheduler::new())),
32            wake_queue: WakeQueue::new(),
33            aborted: Arc::new(Mutex::new(HashSet::new())),
34            clock: Arc::new(Mutex::new(Clock::new())),
35            rng: Arc::new(Mutex::new(Rng::new(seed))),
36            trace: Arc::new(Mutex::new(TraceBuffer::new(trace_config))),
37        }
38    }
39
40    pub fn new_for_engine(&self) -> Self {
41        Self {
42            scheduler: Arc::clone(&self.scheduler),
43            wake_queue: Arc::clone(&self.wake_queue),
44            aborted: Arc::clone(&self.aborted),
45            clock: Arc::new(Mutex::new(Clock::new())),
46            rng: Arc::clone(&self.rng),
47            trace: Arc::clone(&self.trace),
48        }
49    }
50
51    pub fn clock_ref(&self) -> &Arc<Mutex<Clock>> {
52        &self.clock
53    }
54
55    pub fn spawn_local_task(
56        &self,
57        fut: Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
58    ) -> TaskHandle {
59        let mut sched = self.scheduler.lock().expect("scheduler");
60        let id = sched.next_id;
61        sched.next_id += 1;
62        sched.pending_new.push((id, fut));
63        TaskHandle {
64            task_id: id,
65            aborted: Arc::clone(&self.aborted),
66        }
67    }
68
69    pub fn yield_now(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>> {
70        Box::pin(SimYield { yielded: false })
71    }
72
73    pub fn sleep(
74        &self,
75        duration: Duration,
76    ) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>> {
77        Box::pin(SimSleep {
78            duration,
79            deadline: None,
80            clock: Arc::clone(&self.clock),
81        })
82    }
83
84    pub fn create_interval(&self, duration: Duration) -> Interval {
85        let next_deadline = self.clock.lock().expect("sim clock").now + duration;
86        Interval::new(duration, next_deadline, Arc::clone(&self.clock))
87    }
88
89    pub fn run_until_idle(&self) {
90        loop {
91            let next = {
92                let mut sched = self.scheduler.lock().expect("scheduler");
93                self.drain_pending_and_wakes(&mut sched);
94                self.shuffle_ready_queue(&mut sched);
95                let found = self.select_next_task(&mut sched);
96                if let Some((task_id, _)) = &found {
97                    sched.decision_log.push(*task_id);
98                }
99                found
100            };
101
102            let (task_id, mut future) = match next {
103                Some(pair) => pair,
104                None => break,
105            };
106
107            let waker: Waker = Waker::from(Arc::new(SimWaker {
108                task_id,
109                queue: Arc::clone(&self.wake_queue),
110            }));
111            let mut cx = std::task::Context::from_waker(&waker);
112
113            match future.as_mut().poll(&mut cx) {
114                Poll::Ready(()) => {}
115                Poll::Pending => {
116                    self.scheduler
117                        .lock()
118                        .expect("scheduler")
119                        .tasks
120                        .insert(task_id, future);
121                }
122            }
123
124            self.scheduler.lock().expect("scheduler").step_count += 1;
125            self.record_task_polled(task_id);
126        }
127    }
128
129    pub fn advance_time(&self, d: Duration) {
130        let wakers = self.clock.lock().expect("sim clock").advance(d);
131        let waker_count = wakers.len();
132        for w in wakers {
133            w.wake();
134        }
135        if waker_count > 0 {
136            let deadline_ms = self.clock.lock().expect("sim clock").now.as_millis() as u64;
137            self.trace
138                .lock()
139                .expect("trace")
140                .record(TraceEvent::TimerFired {
141                    deadline_ms,
142                    count: waker_count,
143                });
144        }
145        self.run_until_idle();
146    }
147
148    pub fn run_to_completion(&self) {
149        const MAX_ADVANCES: usize = 10_000;
150        for _ in 0..MAX_ADVANCES {
151            self.run_until_idle();
152            let wakers = self
153                .clock
154                .lock()
155                .expect("sim clock")
156                .advance_to_next_timer();
157            match wakers {
158                None => break,
159                Some(wakers) if wakers.is_empty() => break,
160                Some(wakers) => {
161                    let waker_count = wakers.len();
162                    let deadline_ms = self.clock.lock().expect("sim clock").now.as_millis() as u64;
163                    self.trace
164                        .lock()
165                        .expect("trace")
166                        .record(TraceEvent::TimerFired {
167                            deadline_ms,
168                            count: waker_count,
169                        });
170                    for w in wakers {
171                        w.wake();
172                    }
173                }
174            }
175        }
176        self.run_until_idle();
177    }
178
179    pub fn step_count(&self) -> u64 {
180        self.scheduler.lock().expect("scheduler").step_count
181    }
182
183    pub fn seed(&self) -> u64 {
184        self.rng.lock().expect("rng").seed()
185    }
186
187    pub fn virtual_elapsed(&self) -> Duration {
188        self.clock.lock().expect("sim clock").now
189    }
190
191    pub fn now(&self) -> Instant {
192        self.clock.lock().expect("sim clock").now_as_instant()
193    }
194
195    pub fn elapsed_since(&self, _start: Instant) -> u64 {
196        const MOCK_COMPUTE_NS: u64 = 1000;
197        let mut clock = self.clock.lock().expect("sim clock");
198        let wakers = clock.advance(Duration::from_nanos(MOCK_COMPUTE_NS));
199        drop(clock);
200        for w in wakers {
201            w.wake();
202        }
203        MOCK_COMPUTE_NS
204    }
205
206    pub fn next_u64(&self) -> u64 {
207        self.rng.lock().expect("rng").next_u64()
208    }
209
210    pub fn decision_log(&self) -> Vec<usize> {
211        self.scheduler
212            .lock()
213            .expect("scheduler")
214            .decision_log
215            .clone()
216    }
217
218    pub fn set_replay_decisions(&self, decisions: Vec<usize>) {
219        self.scheduler.lock().expect("scheduler").replay_decisions = VecDeque::from(decisions);
220    }
221
222    pub fn clear_decision_log(&self) {
223        self.scheduler
224            .lock()
225            .expect("scheduler")
226            .decision_log
227            .clear();
228    }
229
230    pub fn trace_snapshot(&self) -> TraceBuffer {
231        self.trace.lock().expect("trace").clone()
232    }
233
234    fn drain_pending_and_wakes(&self, sched: &mut Scheduler) {
235        let pending = std::mem::take(&mut sched.pending_new);
236        for (id, fut) in pending {
237            sched.tasks.insert(id, fut);
238            sched.ready.push_back(id);
239        }
240
241        let mut wq = self.wake_queue.0.lock().expect("wake queue");
242        sched.ready.extend(wq.drain(..));
243    }
244
245    fn shuffle_ready_queue(&self, sched: &mut Scheduler) {
246        if sched.ready.len() > 1 {
247            let mut rng = self.rng.lock().expect("rng");
248            let mut ready_vec: Vec<usize> = sched.ready.drain(..).collect();
249            rng.shuffle(&mut ready_vec);
250            sched.ready.extend(ready_vec);
251        }
252    }
253
254    fn select_next_task(&self, sched: &mut Scheduler) -> Option<(usize, TaskFuture)> {
255        let preferred_id = if !sched.replay_decisions.is_empty() {
256            sched.replay_decisions.pop_front()
257        } else {
258            None
259        };
260
261        let aborted = self.aborted.lock().expect("aborted set");
262        let mut found = None;
263
264        if let Some(pid) = preferred_id {
265            if let Some(pos) = sched.ready.iter().position(|&id| id == pid) {
266                let id = sched.ready.remove(pos).expect("ready task");
267                if !aborted.contains(&id) {
268                    if let Some(fut) = sched.tasks.remove(&id) {
269                        found = Some((id, fut));
270                    }
271                }
272            }
273        }
274
275        if found.is_none() {
276            while let Some(task_id) = sched.ready.pop_front() {
277                if aborted.contains(&task_id) {
278                    continue;
279                }
280                if let Some(fut) = sched.tasks.remove(&task_id) {
281                    found = Some((task_id, fut));
282                    break;
283                }
284            }
285        }
286
287        found
288    }
289
290    fn record_task_polled(&self, task_id: usize) {
291        let vt_ms = self.clock.lock().expect("sim clock").now.as_millis() as u64;
292        self.trace
293            .lock()
294            .expect("trace")
295            .record(TraceEvent::TaskPolled {
296                task_id,
297                virtual_time_ms: vt_ms,
298            });
299    }
300}
301
302impl Clone for Reactor {
303    fn clone(&self) -> Self {
304        Self {
305            scheduler: Arc::clone(&self.scheduler),
306            wake_queue: Arc::clone(&self.wake_queue),
307            aborted: Arc::clone(&self.aborted),
308            clock: Arc::clone(&self.clock),
309            rng: Arc::clone(&self.rng),
310            trace: Arc::clone(&self.trace),
311        }
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::Reactor;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use std::time::Duration;

    /// A task that finishes on its first poll is run by `run_until_idle`, and
    /// exactly one scheduler step is counted for it.
    #[test]
    fn spawned_tasks_run_to_idle() {
        let reactor = Reactor::new(7);
        let hits = Arc::new(AtomicUsize::new(0));

        let task_hits = Arc::clone(&hits);
        reactor.spawn_local_task(Box::pin(async move {
            task_hits.fetch_add(1, Ordering::Relaxed);
        }));
        reactor.run_until_idle();

        assert_eq!(hits.load(Ordering::Relaxed), 1);
        assert_eq!(reactor.step_count(), 1);
    }

    /// A sleeping task stays parked through `run_until_idle` and resumes only
    /// once `advance_time` moves virtual time forward by the sleep duration.
    #[test]
    fn timers_advance_and_wake_sleepers() {
        let reactor = Reactor::new(9);
        let hits = Arc::new(AtomicUsize::new(0));

        let sleeper = {
            let reactor = reactor.clone();
            let hits = Arc::clone(&hits);
            async move {
                reactor.sleep(Duration::from_millis(5)).await;
                hits.fetch_add(1, Ordering::Relaxed);
            }
        };
        reactor.spawn_local_task(Box::pin(sleeper));

        // Without advancing virtual time the sleeper cannot complete.
        reactor.run_until_idle();
        assert_eq!(hits.load(Ordering::Relaxed), 0);

        reactor.advance_time(Duration::from_millis(5));
        assert_eq!(hits.load(Ordering::Relaxed), 1);
    }
}