witty_actors/
scheduler.rs1use std::cmp::Reverse;
21use std::collections::binary_heap::PeekMut;
22use std::collections::BinaryHeap;
23use std::future::Future;
24use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
25use std::sync::{Arc, Weak};
26use std::time::{Duration, Instant};
27
28use tokio::sync::oneshot;
29use tokio::task::JoinHandle;
30
/// A deferred action run by the scheduler task once its deadline is reached.
type Callback = Box<dyn FnOnce() + Sync + Send + 'static>;

/// A callback waiting in the scheduler's queue.
struct TimeoutEvent {
    /// Simulated instant at which the callback becomes due.
    deadline: Instant,
    /// Id handed out in increasing order by the scheduler; breaks ties between events
    /// that share a deadline.
    event_id: u64,
    callback: Callback,
}

impl PartialEq for TimeoutEvent {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `a == b` holds exactly when
        // `a.cmp(b)` is `Equal`, as the `Ord`/`Eq` contract requires. Comparing only
        // `event_id` would make two events with the same id but different deadlines
        // equal while `cmp` orders them.
        self.cmp(other).is_eq()
    }
}

impl Eq for TimeoutEvent {}

impl PartialOrd for TimeoutEvent {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for TimeoutEvent {
    /// Orders by deadline first, then by `event_id`.
    ///
    /// `std::cmp::Ordering` is spelled out because `Ordering` in this module refers to
    /// the atomic memory ordering.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.deadline
            .cmp(&other.deadline)
            .then_with(|| self.event_id.cmp(&other.event_id))
    }
}
60
/// Messages sent by [`SchedulerClient`]s to the scheduler task spawned in
/// `start_scheduler`.
enum SchedulerMessage {
    /// Process due events; dispatched to `Scheduler::process_time`.
    ProcessTime,
    /// Enqueue `callback` to run once `timeout` has elapsed, relative to the scheduler's
    /// simulated now; dispatched to `Scheduler::process_schedule`.
    Schedule {
        callback: Callback,
        timeout: Duration,
    },
}
68
/// Handle used to talk to the scheduler task.
///
/// Cloning only bumps the reference count of the shared `Arc`, so all clones share the
/// same guard counter, acceleration flag and message channel.
#[derive(Clone)]
pub struct SchedulerClient {
    inner: Arc<SchedulerClientInner>,
}
73
/// State shared by every clone of a [`SchedulerClient`].
struct SchedulerClientInner {
    /// Number of live `NoAdvanceTimeGuard`s; while non-zero, the scheduler does not
    /// fast-forward simulated time.
    no_advance_time_guard_count: AtomicUsize,
    /// Set to `true` by `SchedulerClient::accelerate_time`.
    accelerate_time: AtomicBool,
    /// Sending half of the channel read by the scheduler task.
    tx: flume::Sender<SchedulerMessage>,
}
79
80impl SchedulerClient {
81 fn time_is_accelerated(&self) -> bool {
83 self.inner.accelerate_time.load(Ordering::Relaxed)
84 }
85
86 fn is_advance_time_forbidden(&self) -> bool {
88 self.inner
89 .no_advance_time_guard_count
90 .load(Ordering::SeqCst)
91 > 0
92 }
93
94 pub fn schedule_event<F: FnOnce() + Send + Sync + 'static>(
101 &self,
102 callback: F,
103 timeout: Duration,
104 ) {
105 let _ = self.inner.tx.send(SchedulerMessage::Schedule {
106 callback: Box::new(callback),
107 timeout,
108 });
109 }
110
111 pub(crate) fn inc_no_advance_time(&self) {
113 self.inner
114 .no_advance_time_guard_count
115 .fetch_add(1, Ordering::SeqCst);
116 }
117
118 pub(crate) fn dec_no_advance_time(&self) {
122 let previous_count = self
123 .inner
124 .no_advance_time_guard_count
125 .fetch_sub(1, Ordering::SeqCst);
126 if previous_count == 1 {
127 self.process_time();
128 }
129 }
130
131 pub fn accelerate_time(&self) {
135 self.inner.accelerate_time.store(true, Ordering::Relaxed);
136 self.process_time();
137 }
138
139 pub async fn sleep(&self, duration: Duration) {
140 let (oneshot_tx, oneshot_rx) = oneshot::channel();
141 self.schedule_event(
142 move || {
143 let _ = oneshot_tx.send(());
144 },
145 duration,
146 );
147 let _ = oneshot_rx.await;
148 }
149
150 pub async fn timeout<O>(
151 &self,
152 duration: Duration,
153 fut: impl Future<Output = O>,
154 ) -> Result<O, ()> {
155 tokio::select! {
156 _ = self.sleep(duration) => {
157 Err(())
158 },
159 future_output = fut => {
160 Ok(future_output)
161 }
162 }
163 }
164
165 pub(crate) fn process_time(&self) {
169 let _ = self.inner.tx.send(SchedulerMessage::ProcessTime);
170 }
171
172 pub fn no_advance_time_guard(&self) -> NoAdvanceTimeGuard {
175 NoAdvanceTimeGuard::new(self.clone())
176 }
177}
178
179pub struct NoAdvanceTimeGuard {
180 scheduler_client: SchedulerClient,
181}
182
183impl NoAdvanceTimeGuard {
184 fn new(scheduler_client: SchedulerClient) -> Self {
185 scheduler_client.inc_no_advance_time();
186 NoAdvanceTimeGuard { scheduler_client }
187 }
188}
189
190impl Drop for NoAdvanceTimeGuard {
191 fn drop(&mut self) {
192 self.scheduler_client.dec_no_advance_time();
193 }
194}
195
196pub fn start_scheduler() -> SchedulerClient {
197 let (tx, rx) = flume::unbounded::<SchedulerMessage>();
198 let scheduler_client = SchedulerClient {
199 inner: Arc::new(SchedulerClientInner {
200 no_advance_time_guard_count: AtomicUsize::default(),
201 accelerate_time: Default::default(),
202 tx,
203 }),
204 };
205 let mut scheduler = Scheduler::new(&scheduler_client);
206 tokio::spawn(async move {
207 while let Ok(scheduler_message) = rx.recv_async().await {
208 match scheduler_message {
209 SchedulerMessage::ProcessTime => scheduler.process_time(),
210 SchedulerMessage::Schedule { callback, timeout } => {
211 scheduler.process_schedule(callback, timeout);
212 }
213 }
214 }
215 });
216 scheduler_client
217}
218
/// State owned by the scheduler task spawned in `start_scheduler`.
struct Scheduler {
    /// Next id handed out to a `TimeoutEvent`.
    event_id_generator: u64,
    /// Amount by which simulated time runs ahead of `Instant::now()`. It only grows,
    /// and only when time is accelerated.
    simulated_time_shift: Duration,
    /// Pending events, wrapped in `Reverse` so the heap yields the earliest deadline first.
    future_events: BinaryHeap<Reverse<TimeoutEvent>>,
    /// Handle of the task that will send the next `ProcessTime` wake-up, if one is armed.
    next_timeout: Option<JoinHandle<()>>,
    /// Weak reference, so this struct alone does not keep the client state (and its
    /// channel sender) alive.
    weak_scheduler_client: Weak<SchedulerClientInner>,
}
231
232impl Scheduler {
233 fn process_time(&mut self) {
240 let now = self.simulated_now();
241 while let Some(next_event_peek) = self.future_events.peek_mut() {
243 if next_event_peek.0.deadline > now {
244 break;
246 }
247 let next_event = PeekMut::pop(next_event_peek);
248 (next_event.0.callback)();
249 }
250
251 self.advance_time_if_necessary();
254 self.schedule_next_timeout();
255 }
256
257 fn process_schedule(&mut self, callback: Callback, timeout: Duration) {
259 let new_evt_deadline = self.simulated_now() + timeout;
260 let timeout_event = self.timeout_event(new_evt_deadline, callback);
261 self.future_events.push(Reverse(timeout_event));
262 self.process_time();
263 }
264
265 fn scheduler_client(&self) -> Option<SchedulerClient> {
266 let scheduler_client = SchedulerClient {
267 inner: self.weak_scheduler_client.upgrade()?,
268 };
269 Some(scheduler_client)
270 }
271
272 fn schedule_next_timeout(&mut self) {
274 let Some(scheduler_client) = self.scheduler_client() else { return };
275 let simulated_now = self.simulated_now();
276 let Some(next_deadline) = self.next_event_deadline() else { return; };
277 let timeout: Duration = if next_deadline <= simulated_now {
278 Duration::default()
284 } else {
285 next_deadline - simulated_now
286 };
287 if let Some(previous_join_handle) = self.next_timeout.take() {
288 previous_join_handle.abort();
291 }
292 let new_join_handle: JoinHandle<()> = tokio::task::spawn(async move {
293 if timeout.is_zero() {
294 tokio::task::yield_now().await;
295 } else {
296 tokio::time::sleep(timeout).await;
297 }
298 scheduler_client.process_time();
299 });
300 self.next_timeout = Some(new_join_handle);
301 }
302}
303
304impl Scheduler {
305 pub fn new(scheduler_client: &SchedulerClient) -> Self {
306 Scheduler {
307 event_id_generator: 0u64,
308 simulated_time_shift: Duration::default(),
309 future_events: Default::default(),
310 next_timeout: None,
311 weak_scheduler_client: Arc::downgrade(&scheduler_client.inner),
312 }
313 }
314
315 fn advance_time_if_necessary(&mut self) {
323 let Some(scheduler_client) = self.scheduler_client() else { return; };
324 if !scheduler_client.time_is_accelerated() {
325 return;
326 }
327 if scheduler_client.is_advance_time_forbidden() {
328 return;
329 }
330 let Some(advance_to_instant) = self.next_event_deadline() else { return; };
331 let now = self.simulated_now();
332 if let Some(time_shift) = advance_to_instant.checked_duration_since(now) {
333 self.simulated_time_shift += time_shift;
334 }
335 }
336
337 fn next_event_deadline(&self) -> Option<Instant> {
338 self.future_events.peek().map(|rev| rev.0.deadline)
339 }
340
341 fn simulated_now(&self) -> Instant {
342 Instant::now() + self.simulated_time_shift
343 }
344
345 fn timeout_event(&mut self, deadline: Instant, callback: Callback) -> TimeoutEvent {
346 let event_id = self.event_id_generator;
347 self.event_id_generator += 1;
348 TimeoutEvent {
349 deadline,
350 event_id,
351 callback,
352 }
353 }
354}
355
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use async_trait::async_trait;

    use crate::{Actor, ActorContext, ActorExitStatus, Handler, Universe};

    /// Actor that counts the `Tick`s it handles and re-schedules a `Tick` every second.
    struct ClockActor {
        count: Arc<AtomicUsize>,
    }

    impl ClockActor {
        /// Builds a fresh actor together with a shared handle on its tick counter.
        fn with_counter() -> (ClockActor, Arc<AtomicUsize>) {
            let count = Arc::new(AtomicUsize::new(0));
            let actor = ClockActor {
                count: Arc::clone(&count),
            };
            (actor, count)
        }
    }

    #[derive(Debug)]
    struct Tick;

    #[async_trait]
    impl Actor for ClockActor {
        type ObservableState = ();
        fn observable_state(&self) -> Self::ObservableState {}

        // Handle a first tick during initialization; it schedules the following ones.
        async fn initialize(&mut self, ctx: &ActorContext<Self>) -> Result<(), ActorExitStatus> {
            self.handle(Tick, ctx).await
        }
    }

    #[async_trait]
    impl Handler<Tick> for ClockActor {
        type Reply = ();

        async fn handle(
            &mut self,
            _tick: Tick,
            ctx: &ActorContext<Self>,
        ) -> Result<(), ActorExitStatus> {
            self.count.fetch_add(1, Ordering::SeqCst);
            ctx.schedule_self_msg(Duration::from_secs(1), Tick).await;
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_scheduler_advance_time_fast_forward_initialize() {
        let (clock_actor, count) = ClockActor::with_counter();
        let universe = Universe::with_accelerated_time();
        universe.spawn_builder().spawn(clock_actor);
        assert_eq!(count.load(Ordering::SeqCst), 0);
        universe.sleep(Duration::from_millis(15)).await;
        // Only the tick from `initialize` has run; the next one is a simulated second away.
        assert_eq!(count.load(Ordering::SeqCst), 1);
        universe.assert_quit().await;
    }

    #[tokio::test]
    async fn test_scheduler_advance_time_fast_forward_scheduled_message() {
        let start = Instant::now();
        let (clock_actor, count) = ClockActor::with_counter();
        let universe = Universe::with_accelerated_time();
        universe.spawn_builder().spawn(clock_actor);
        assert_eq!(count.load(Ordering::SeqCst), 0);
        universe.sleep(Duration::from_secs(10)).await;
        assert_eq!(count.load(Ordering::SeqCst), 10);
        // Ten simulated seconds must take only a few real milliseconds.
        let real_elapsed = start.elapsed();
        assert!(real_elapsed.as_millis() < 50);
        universe.assert_quit().await;
    }
}