Skip to main content

strontium_core/
clock.rs

1use std::cmp::Reverse;
2use std::collections::BinaryHeap;
3use std::task::Waker;
4use std::time::{Duration, Instant};
5
/// A pending timer: the simulated deadline at which it becomes due and the
/// waker handed back to the caller once the clock reaches that deadline.
///
/// Equality and ordering look only at `deadline`; the waker is payload.
pub(crate) struct TimerEntry {
    pub deadline: Duration,
    pub waker: Waker,
}

impl PartialEq for TimerEntry {
    fn eq(&self, other: &Self) -> bool {
        self.deadline == other.deadline
    }
}

impl Eq for TimerEntry {}

impl PartialOrd for TimerEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Delegate to `Ord` so the two orderings can never disagree.
        Some(self.cmp(other))
    }
}

impl Ord for TimerEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.deadline.cmp(&other.deadline)
    }
}

/// A manually driven clock that tracks simulated time and the timers waiting on it.
///
/// Simulated time only moves when [`Clock::advance`] or
/// [`Clock::advance_to_next_timer`] is called. Timers whose deadline has been
/// reached are removed, and their wakers are returned to the caller, which is
/// responsible for waking them.
pub struct Clock {
    /// Simulated time elapsed since this clock was created.
    ///
    /// Assigning this field directly does not fire any timers.
    pub now: Duration,
    /// Real instant captured at construction; [`Clock::now_as_instant`] offsets it by `now`.
    base: Instant,
    /// Min-heap (via `Reverse`) of pending timers ordered by `(deadline, sequence)`.
    ///
    /// `BinaryHeap` gives no ordering guarantee for equal elements, so the
    /// registration sequence number breaks ties: timers sharing a deadline fire
    /// in the order they were registered.
    timers: BinaryHeap<Reverse<(TimerEntry, u64)>>,
    /// Sequence number that the next registered timer will receive.
    next_seq: u64,
}

impl Default for Clock {
    fn default() -> Self {
        Self::new()
    }
}

impl Clock {
    /// Creates a clock at simulated time zero with no pending timers,
    /// anchored to the current real `Instant`.
    pub fn new() -> Self {
        Self {
            now: Duration::ZERO,
            base: Instant::now(),
            timers: BinaryHeap::new(),
            next_seq: 0,
        }
    }

    /// Registers `waker` to be returned once simulated time reaches `deadline`.
    ///
    /// A deadline at or before the current time is not fired here. It fires on
    /// the next [`Clock::advance`] (even a zero step) or
    /// [`Clock::advance_to_next_timer`].
    pub fn register_timer(&mut self, deadline: Duration, waker: Waker) {
        let seq = self.next_seq;
        self.next_seq += 1;
        self.timers.push(Reverse((TimerEntry { deadline, waker }, seq)));
    }

    /// Moves simulated time forward by `d` and returns the wakers of every timer
    /// whose deadline is now at or before the new `now`.
    ///
    /// Wakers are returned in deadline order, and in registration order among
    /// equal deadlines. They are not woken here; that is left to the caller.
    ///
    /// # Panics
    /// Panics if `now + d` overflows `Duration`.
    pub fn advance(&mut self, d: Duration) -> Vec<Waker> {
        self.now += d;
        let now = self.now;
        let mut wakers = Vec::new();
        while matches!(self.timers.peek(), Some(Reverse((entry, _))) if entry.deadline <= now) {
            let Reverse((entry, _)) = self.timers.pop().expect("peeked timer entry must exist");
            wakers.push(entry.waker);
        }
        wakers
    }

    /// Jumps simulated time to the earliest pending deadline and returns the
    /// wakers that became due.
    ///
    /// Returns `None` when no timers are pending. If the earliest deadline is
    /// already in the past, time does not move, but the overdue timers still fire.
    pub fn advance_to_next_timer(&mut self) -> Option<Vec<Waker>> {
        let next_deadline = self
            .timers
            .peek()
            .map(|Reverse((entry, _))| entry.deadline)?;
        // Saturate so an overdue deadline yields a zero step instead of underflowing.
        Some(self.advance(next_deadline.saturating_sub(self.now)))
    }

    /// Maps the current simulated time onto the real `Instant` timeline,
    /// anchored at the instant this clock was created.
    ///
    /// # Panics
    /// Panics if `base + now` cannot be represented as an `Instant`.
    pub fn now_as_instant(&self) -> Instant {
        self.base + self.now
    }
}
80
#[cfg(test)]
mod tests {
    use super::Clock;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::task::{Wake, Waker};
    use std::time::Duration;

    /// Test waker that counts how many times it has been woken.
    #[derive(Default)]
    struct CounterWake(AtomicUsize);

    impl CounterWake {
        /// Number of wake-ups observed so far.
        fn count(&self) -> usize {
            self.0.load(Ordering::SeqCst)
        }
    }

    impl Wake for CounterWake {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn advance_to_next_timer_wakes_registered_task() {
        let mut clock = Clock::new();
        let wake = Arc::new(CounterWake::default());
        let waker = Waker::from(Arc::clone(&wake));

        clock.register_timer(Duration::from_millis(5), waker);
        let fired = clock.advance_to_next_timer().expect("timer");
        assert_eq!(clock.now, Duration::from_millis(5));
        assert_eq!(fired.len(), 1);

        // Getting *a* waker back is not enough: it must be the registered one.
        fired.into_iter().for_each(Waker::wake);
        assert_eq!(wake.count(), 1);
    }

    #[test]
    fn advance_to_next_timer_on_empty_clock_is_none() {
        let mut clock = Clock::new();
        assert!(clock.advance_to_next_timer().is_none());
        assert_eq!(clock.now, Duration::ZERO);
    }

    #[test]
    fn advance_only_fires_due_timers() {
        let mut clock = Clock::new();
        let early = Arc::new(CounterWake::default());
        let late = Arc::new(CounterWake::default());
        clock.register_timer(Duration::from_millis(10), Waker::from(Arc::clone(&late)));
        clock.register_timer(Duration::from_millis(3), Waker::from(Arc::clone(&early)));

        let fired = clock.advance(Duration::from_millis(5));
        assert_eq!(fired.len(), 1);
        fired.into_iter().for_each(Waker::wake);
        assert_eq!(early.count(), 1);
        assert_eq!(late.count(), 0);

        // A deadline exactly equal to the new `now` is due.
        let fired = clock.advance(Duration::from_millis(5));
        assert_eq!(clock.now, Duration::from_millis(10));
        assert_eq!(fired.len(), 1);
        fired.into_iter().for_each(Waker::wake);
        assert_eq!(late.count(), 1);
    }
}