timer_queue/
lib.rs

//! A pure, minimal, and scalable structure for tracking expiration of timers
//!
//! ```
//! # use timer_queue::TimerQueue;
//! let mut q = TimerQueue::new();
//! q.insert(42, "second");
//! q.insert(17, "first");
//! assert!(q.next_timeout().unwrap() <= 17);
//! assert_eq!(q.poll(16), None);
//! assert_eq!(q.poll(17), Some("first"));
//! assert_eq!(q.poll(100), Some("second"));
//! ```
13
14#![no_std]
15
16use core::fmt;
17
18use slab::Slab;
19
/// Stores values to be yielded at specific times in the future
///
/// Time is expressed as a bare u64 representing an absolute point in time. The caller may use any
/// consistent unit, e.g. milliseconds, and any consistent definition of time zero. Larger units
/// limit resolution but make `poll`ing over the same real-time interval proportionately faster,
/// whereas smaller units improve resolution, limit total range, and reduce `poll` performance.
#[derive(Debug, Clone)]
pub struct TimerQueue<T> {
    /// Definitions of each active timer
    ///
    /// Timers are defined here, and referenced indirectly by index from `levels` and in the public
    /// API. This allows for safe construction of intrusive linked lists between timers, and helps
    /// reduce the amount of data that needs to be routinely shuffled around in `levels` as time
    /// passes.
    timers: Slab<TimerState<T>>,

    /// A hierarchical timer wheel
    ///
    /// This data structure breaks down points in time into digits. The base of those digits can be
    /// chosen arbitrarily; this implementation uses base `2^LOG_2_SLOTS`. A power of two makes it
    /// easy to manipulate individual digits using bit shifts and masking because each digit
    /// corresponds directly to `LOG_2_SLOTS` bits in the binary representation. For familiarity, we
    /// will illustrate a timer wheel built instead on base 10, but the behavior is identical.
    ///
    /// Consider this timer wheel where timers are set at times 32, 42, and 46, and `next_tick` is
    /// between 30 and 32 inclusive. Note that the number of slots in each level is equal to the
    /// base of the digits used, in this case 10.
    ///
    /// ```text
    ///           +--+--+--+--+--
    /// Level 0   |30|31|32|33| ...
    ///           +--+--+--+--+--
    ///            \      |       /
    ///             \     V      /
    ///              \  +--+    /
    ///               \ |32|   /
    ///                \+--+  /
    ///                 \    /
    ///         +--+--+--+--+--+--+--+--+--+--+
    /// Level 1 |00|10|20|30|40|50|60|70|80|90|
    ///         +--+--+--+--+--+--+--+--+--+--+
    ///                       |
    ///                       V
    ///                     +--+
    ///                     |46|
    ///                     +--+
    ///                      ^|
    ///                      |V
    ///                     +--+
    ///                     |42|
    ///                     +--+
    /// ```
    ///
    /// Timers are organized into buckets (or slots) at a resolution that decreases exponentially
    /// with distance from `next_tick`, the present. Higher-numbered levels cover larger intervals,
    /// until the highest-numbered level covers the complete representable range of times, from 0
    /// to `u64::MAX`. Every lower level covers the slot in the next higher level which `next_tick`
    /// lies within. Level 0 represents the maximum resolution, where each slot covers exactly one
    /// unit of time.
    ///
    /// The slot that a timer should be stored in is easily computed based on `next_tick` and the
    /// desired expiry time. For a base 10 structure, find the most significant digit in the base 10
    /// representations of `next_tick` and the desired expiry time that differs between the two. The
    /// position of that digit is the level, and the value of that digit is the position in the
    /// level. For example, if `next_tick` is 7342, and a timer is scheduled for time 7361, the
    /// timer would be stored at level 1, slot 6. Note that no subtraction is performed: the start
    /// of each level is always the greatest integer multiple of the level's span which is less than
    /// or equal to `next_tick`.
    ///
    /// Calls to `poll` move `next_tick` towards the passed-in time. When `next_tick` reaches a
    /// timer in level 0, it stops there and the timer is removed and returned from `poll`. Reaching
    /// the end of level 0 redefines level 0 to represent the next slot in level 1, at which point
    /// all timers stored in that slot are unpacked into appropriate slots of level 0, and traversal
    /// of level 0 begins again from the start. When level 1 is exhausted, the next slot in level 2
    /// is unpacked into levels 1 and 0, and so on for higher levels. Slots preceding `next_tick`
    /// are therefore empty at any level, and for levels above 0, the slot containing `next_tick` is
    /// also empty, having necessarily been unpacked into lower levels.
    ///
    /// Assuming the number of timers scheduled within a period of time is on average proportional
    /// to the size of that period, advancing the queue by a constant amount of time has amortized
    /// constant time complexity, because the frequency with which slots at a particular level are
    /// unpacked is inversely proportional to the expected number of timers stored in that
    /// slot.
    ///
    /// Inserting, removing, and updating timers are constant-time operations thanks to the above
    /// and the use of unordered doubly linked lists to represent the contents of a slot. We can
    /// also compute a lower bound for the next timeout in constant time by finding the earliest
    /// nonempty slot.
    levels: [Level; LEVELS],

    /// Earliest point at which a timer may be pending
    ///
    /// Each `LOG_2_SLOTS` bits of this are a cursor into the associated level, in order of
    /// ascending significance. It never exceeds the expiry of any scheduled timer, because
    /// `insert` and `reset` clamp expiries to it.
    next_tick: u64,
}
116
117impl<T> TimerQueue<T> {
118    /// Create an empty queue starting at time `0`
119    pub const fn new() -> Self {
120        Self {
121            timers: Slab::new(),
122            levels: [Level::new(); LEVELS],
123            next_tick: 0,
124        }
125    }
126
127    /// Create a queue for which at least `n` calls to `insert` will not require a reallocation
128    pub fn with_capacity(n: usize) -> Self {
129        Self {
130            timers: Slab::with_capacity(n),
131            levels: [Level::new(); LEVELS],
132            next_tick: 0,
133        }
134    }
135
136    /// Returns a timer that has expired by `now`, if any
137    ///
138    /// `now` must be at least the largest previously passed value
139    pub fn poll(&mut self, now: u64) -> Option<T> {
140        debug_assert!(now >= self.next_tick, "time advances monotonically");
141        loop {
142            // Advance towards the next timeout
143            self.advance_towards(now);
144            // Check for timeouts in the immediate future
145            if let Some(value) = self.scan_bottom(now) {
146                return Some(value);
147            }
148            // If we can't advance any further, bail out
149            if self.next_tick >= now {
150                return None;
151            }
152        }
153    }
154
155    /// Find a timer expired by `now` in level 0
156    fn scan_bottom(&mut self, now: u64) -> Option<T> {
157        let index = self.levels[0].first_index()?;
158        if slot_start(self.next_tick, 0, index) > now {
159            return None;
160        }
161        let timer = self.levels[0].slots[index];
162        let state = self.timers.remove(timer.0);
163        debug_assert_eq!(state.prev, None, "head of list has no predecessor");
164        debug_assert!(state.expiry <= now);
165        if let Some(next) = state.next {
166            debug_assert_eq!(
167                self.timers[next.0].prev,
168                Some(timer),
169                "successor links to head"
170            );
171            self.timers[next.0].prev = None;
172        }
173        self.levels[0].set(index, state.next);
174        self.next_tick = state.expiry;
175        self.maybe_shrink();
176        Some(state.value)
177    }
178
179    /// Advance to the start of the first nonempty slot or `now`, whichever is sooner
180    fn advance_towards(&mut self, now: u64) {
181        for level in 0..LEVELS {
182            if let Some(slot) = self.levels[level].first_index() {
183                if slot_start(self.next_tick, level, slot) > now {
184                    break;
185                }
186                self.advance_to(level, slot);
187                return;
188            }
189        }
190        self.next_tick = now;
191    }
192
193    /// Advance to a specific slot, which must be the first nonempty slot
194    fn advance_to(&mut self, level: usize, slot: usize) {
195        debug_assert!(
196            self.levels[..level].iter().all(|level| level.is_empty()),
197            "lower levels are empty"
198        );
199        debug_assert!(
200            self.levels[level].first_index().map_or(true, |x| x >= slot),
201            "lower slots in this level are empty"
202        );
203
204        // Advance into the slot
205        self.next_tick = slot_start(self.next_tick, level, slot);
206
207        if level == 0 {
208            // No lower levels exist to unpack timers into
209            return;
210        }
211
212        // Unpack all timers in this slot into lower levels
213        while let Some(timer) = self.levels[level].take(slot) {
214            let next = self.timers[timer.0].next;
215            self.levels[level].set(slot, next);
216            if let Some(next) = next {
217                self.timers[next.0].prev = None;
218            }
219            self.list_unlink(timer);
220            self.schedule(timer);
221        }
222    }
223
224    /// Link `timer` from the slot associated with its expiry
225    fn schedule(&mut self, timer: Timer) {
226        debug_assert_eq!(
227            self.timers[timer.0].next, None,
228            "timer isn't already scheduled"
229        );
230        debug_assert_eq!(
231            self.timers[timer.0].prev, None,
232            "timer isn't already scheduled"
233        );
234        let (level, slot) = timer_index(self.next_tick, self.timers[timer.0].expiry);
235        // Insert `timer` at the head of the list in the target slot
236        let head = self.levels[level].get(slot);
237        self.timers[timer.0].next = head;
238        if let Some(head) = head {
239            self.timers[head.0].prev = Some(timer);
240        }
241        self.levels[level].set(slot, Some(timer));
242    }
243
244    /// Lower bound on when the next timer will expire, if any
245    pub fn next_timeout(&self) -> Option<u64> {
246        for level in 0..LEVELS {
247            let start = ((self.next_tick >> (level * LOG_2_SLOTS)) & (SLOTS - 1) as u64) as usize;
248            for slot in start..SLOTS {
249                if self.levels[level].get(slot).is_some() {
250                    return Some(slot_start(self.next_tick, level, slot));
251                }
252            }
253        }
254        None
255    }
256
257    /// Register a timer that will yield `value` at `timeout`
258    pub fn insert(&mut self, timeout: u64, value: T) -> Timer {
259        let timer = Timer(self.timers.insert(TimerState {
260            expiry: timeout.max(self.next_tick),
261            prev: None,
262            next: None,
263            value,
264        }));
265        self.schedule(timer);
266        timer
267    }
268
269    /// Adjust `timer` to expire at `timeout`
270    pub fn reset(&mut self, timer: Timer, timeout: u64) {
271        self.unlink(timer);
272        self.timers[timer.0].expiry = timeout.max(self.next_tick);
273        self.schedule(timer);
274    }
275
276    /// Cancel `timer`
277    pub fn remove(&mut self, timer: Timer) -> T {
278        self.unlink(timer);
279        let state = self.timers.remove(timer.0);
280        self.maybe_shrink();
281        state.value
282    }
283
284    /// Release timer state memory if it's mostly unused
285    fn maybe_shrink(&mut self) {
286        if self.timers.capacity() / 16 > self.timers.len() {
287            self.timers.shrink_to_fit();
288        }
289    }
290
291    /// Iterate over the expiration and value of all scheduled timers
292    pub fn iter(&self) -> impl ExactSizeIterator<Item = (u64, &T)> {
293        self.timers.iter().map(|(_, x)| (x.expiry, &x.value))
294    }
295
296    /// Iterate over the expiration and value of all scheduled timers
297    pub fn iter_mut(&mut self) -> impl ExactSizeIterator<Item = (u64, &mut T)> {
298        self.timers
299            .iter_mut()
300            .map(|(_, x)| (x.expiry, &mut x.value))
301    }
302
303    /// Borrow the value associated with `timer`
304    pub fn get(&self, timer: Timer) -> &T {
305        &self.timers[timer.0].value
306    }
307
308    /// Uniquely borrow the value associated with `timer`
309    pub fn get_mut(&mut self, timer: Timer) -> &mut T {
310        &mut self.timers[timer.0].value
311    }
312
313    /// Number of scheduled timers
314    pub fn len(&self) -> usize {
315        self.timers.len()
316    }
317
318    /// Whether no timers are scheduled
319    pub fn is_empty(&self) -> bool {
320        self.timers.is_empty()
321    }
322
323    /// Remove all references to `timer`
324    fn unlink(&mut self, timer: Timer) {
325        let (level, slot) = timer_index(self.next_tick, self.timers[timer.0].expiry);
326        // If necessary, remove a reference to `timer` from its slot by replacing it with its
327        // successor
328        let slot_head = self.levels[level].get(slot).unwrap();
329        if slot_head == timer {
330            self.levels[level].set(slot, self.timers[slot_head.0].next);
331            debug_assert_eq!(
332                self.timers[timer.0].prev, None,
333                "head of list has no predecessor"
334            );
335        }
336        // Remove references to `timer` from other timers
337        self.list_unlink(timer);
338    }
339
340    /// Remove `timer` from its list
341    fn list_unlink(&mut self, timer: Timer) {
342        let prev = self.timers[timer.0].prev.take();
343        let next = self.timers[timer.0].next.take();
344        if let Some(prev) = prev {
345            // Remove reference from predecessor
346            self.timers[prev.0].next = next;
347        }
348        if let Some(next) = next {
349            // Remove reference from successor
350            self.timers[next.0].prev = prev;
351        }
352    }
353}
354
355/// Compute the first tick that lies within a slot
356fn slot_start(base: u64, level: usize, slot: usize) -> u64 {
357    let shift = (level * LOG_2_SLOTS) as u64;
358    // Shifting twice avoids an overflow when level = 10.
359    (base & ((!0 << shift) << LOG_2_SLOTS as u64)) | ((slot as u64) << shift)
360}
361
362/// Compute the level and slot for a certain expiry
363fn timer_index(base: u64, expiry: u64) -> (usize, usize) {
364    // The level is the position of the first bit set in `expiry` but not in `base`, divided by the
365    // number of bits spanned by each level.
366    let differing_bits = base ^ expiry;
367    let level = (63 - (differing_bits | 1).leading_zeros()) as usize / LOG_2_SLOTS;
368    debug_assert!(level < LEVELS, "every possible expiry is in range");
369
370    // The slot in that level is the difference between the expiry time and the time at which the
371    // level's span begins, after both times are shifted down to the level's granularity. Each
372    // level's spans starts at `base`, rounded down to a multiple of the size of its span.
373    let slot_base = (base >> (level * LOG_2_SLOTS)) & (!0 << LOG_2_SLOTS);
374    let slot = (expiry >> (level * LOG_2_SLOTS)) - slot_base;
375    debug_assert!(slot < SLOTS as u64);
376
377    (level, slot as usize)
378}
379
380impl<T> Default for TimerQueue<T> {
381    fn default() -> Self {
382        Self::new()
383    }
384}
385
/// Per-timer bookkeeping, stored in `TimerQueue::timers` and addressed by [`Timer`] handles
///
/// Timers that share a slot form an intrusive doubly linked list through `prev` and `next`. The
/// slot records only the head of that list, which has no `prev`.
#[derive(Debug, Clone)]
struct TimerState<T> {
    /// Lowest argument to `poll` for which this timer may be returned
    ///
    /// `insert` and `reset` clamp it so it is never earlier than `next_tick` at the time it is set.
    expiry: u64,
    /// Value returned to the caller on expiry
    value: T,
    /// Predecessor within a slot's list; `None` for the list head
    prev: Option<Timer>,
    /// Successor within a slot's list; `None` for the list tail
    next: Option<Timer>,
}
397
398/// A set of contiguous timer lists, ordered by expiry
399///
400/// Level `n` spans `2^(LOG_2_SLOTS * (n+1))` ticks, and each of its slots corresponds to a span of
401/// `2^(LOG_2_SLOTS * n)`.
402#[derive(Copy, Clone)]
403struct Level {
404    slots: [Timer; SLOTS],
405    /// Bit n indicates whether slot n is occupied, counting from LSB up
406    occupied: u64,
407}
408
409impl Level {
410    const fn new() -> Self {
411        Self {
412            slots: [Timer(usize::MAX); SLOTS],
413            occupied: 0,
414        }
415    }
416
417    fn first_index(&self) -> Option<usize> {
418        let x = self.occupied.trailing_zeros() as usize;
419        if x == self.slots.len() {
420            return None;
421        }
422        Some(x)
423    }
424
425    fn get(&self, slot: usize) -> Option<Timer> {
426        if self.occupied & (1 << slot) == 0 {
427            return None;
428        }
429        Some(self.slots[slot])
430    }
431
432    fn take(&mut self, slot: usize) -> Option<Timer> {
433        let x = self.get(slot)?;
434        self.set(slot, None);
435        Some(x)
436    }
437
438    fn set(&mut self, slot: usize, timer: Option<Timer>) {
439        match timer {
440            None => {
441                self.slots[slot] = Timer(usize::MAX);
442                self.occupied &= !(1 << slot);
443            }
444            Some(x) => {
445                self.slots[slot] = x;
446                self.occupied |= 1 << slot;
447            }
448        }
449    }
450
451    fn is_empty(&self) -> bool {
452        self.occupied == 0
453    }
454}
455
456impl fmt::Debug for Level {
457    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458        let mut m = f.debug_map();
459        let numbered_nonempty_slots = self
460            .slots
461            .iter()
462            .enumerate()
463            .filter(|(i, _)| self.occupied & (1 << i) != 0);
464        for (i, Timer(t)) in numbered_nonempty_slots {
465            m.entry(&i, &t);
466        }
467        m.finish()
468    }
469}
470
/// Bits of a tick handled by each level; each level has `2^LOG_2_SLOTS` slots
const LOG_2_SLOTS: usize = 6;
/// Levels needed to cover all 64 bits of a tick, `LOG_2_SLOTS` bits at a time
///
/// With `LOG_2_SLOTS = 6`, the topmost level covers only the 4 leftover high bits, so just 16 of
/// its slots are reachable.
const LEVELS: usize = 64 / LOG_2_SLOTS + 1;
/// Slots per level
///
/// `Level::occupied` has one bit per slot, so this must not exceed 64.
const SLOTS: usize = 1_usize << LOG_2_SLOTS;
474
// Index in `TimerQueue::timers`. Future work: add a niche here, e.g. by storing a
// `NonZeroUsize`, so that `Option<Timer>` is no larger than `Timer`.
/// Handle to a specific timer, obtained from [`TimerQueue::insert`]
///
/// Pass it to [`TimerQueue::reset`], [`TimerQueue::remove`], [`TimerQueue::get`] or
/// [`TimerQueue::get_mut`]. A handle is only meaningful while its timer is still scheduled.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Timer(usize);
479
#[cfg(test)]
mod tests {
    extern crate std;

    use std::{collections::HashMap, vec::Vec};

    use super::*;
    use proptest::prelude::*;

    /// Slots actually reachable in the topmost level, which covers only the bits of a `u64` left
    /// over once every lower level has consumed `LOG_2_SLOTS` of them
    const TOP_LEVEL_SLOTS: usize = 1 << (64 - (LEVELS - 1) * LOG_2_SLOTS);

    #[test]
    fn max_timeout() {
        let mut queue = TimerQueue::new();
        queue.insert(u64::MAX, ());
        assert!(queue.poll(u64::MAX - 1).is_none());
        assert!(queue.poll(u64::MAX).is_some());
    }

    #[test]
    fn top_level_slots() {
        // The last reachable slot of the top level is the one holding `u64::MAX`
        assert_eq!(
            timer_index(0, u64::MAX),
            (LEVELS - 1, TOP_LEVEL_SLOTS - 1)
        );
    }

    #[test]
    fn slot_starts() {
        for i in 0..SLOTS {
            assert_eq!(slot_start(0, 0, i), i as u64);
            assert_eq!(slot_start(SLOTS as u64, 0, i), SLOTS as u64 + i as u64);
            assert_eq!(slot_start(SLOTS as u64 + 1, 0, i), SLOTS as u64 + i as u64);
            for j in 1..LEVELS {
                assert_eq!(
                    slot_start(0, j, i),
                    (SLOTS as u64).pow(j as u32).wrapping_mul(i as u64)
                );
            }
        }
    }

    #[test]
    fn indexes() {
        assert_eq!(timer_index(0, 0), (0, 0));
        assert_eq!(timer_index(0, SLOTS as u64 - 1), (0, SLOTS - 1));
        assert_eq!(
            timer_index(SLOTS as u64 - 1, SLOTS as u64 - 1),
            (0, SLOTS - 1)
        );
        assert_eq!(timer_index(0, SLOTS as u64), (1, 1));
        for i in 0..LEVELS {
            assert_eq!(timer_index(0, (SLOTS as u64).pow(i as u32)), (i, 1));
            if i < LEVELS - 1 {
                assert_eq!(
                    timer_index(0, (SLOTS as u64).pow(i as u32 + 1) - 1),
                    (i, SLOTS - 1)
                );
                assert_eq!(
                    timer_index(SLOTS as u64 - 1, (SLOTS as u64).pow(i as u32 + 1) - 1),
                    (i, SLOTS - 1)
                );
            }
        }
    }

    #[test]
    fn next_timeout() {
        let mut queue = TimerQueue::new();
        assert_eq!(queue.next_timeout(), None);
        let k = queue.insert(0, ());
        assert_eq!(queue.next_timeout(), Some(0));
        queue.remove(k);
        assert_eq!(queue.next_timeout(), None);
        queue.insert(1234, ());
        assert!(queue.next_timeout().unwrap() > 12);
        queue.insert(12, ());
        assert_eq!(queue.next_timeout(), Some(12));
    }

    #[test]
    /// `next_timeout` stays a valid lower bound as the wheel advances and unpacks coarse slots
    fn next_timeout_tracks_progress() {
        let mut queue = TimerQueue::new();
        queue.insert(10, 'a');
        queue.insert(5000, 'b');
        assert_eq!(queue.next_timeout(), Some(10));
        assert_eq!(queue.poll(10), Some('a'));
        let bound = queue.next_timeout().unwrap();
        assert!(bound > 10 && bound <= 5000);
        assert_eq!(queue.poll(4999), None);
        assert_eq!(queue.next_timeout(), Some(5000));
        assert_eq!(queue.poll(5000), Some('b'));
        assert_eq!(queue.next_timeout(), None);
    }

    #[test]
    fn poll_boundary() {
        let mut queue = TimerQueue::new();
        queue.insert(SLOTS as u64 - 1, 'a');
        queue.insert(SLOTS as u64, 'b');
        assert_eq!(queue.poll(SLOTS as u64 - 2), None);
        assert_eq!(queue.poll(SLOTS as u64 - 1), Some('a'));
        assert_eq!(queue.poll(SLOTS as u64 - 1), None);
        assert_eq!(queue.poll(SLOTS as u64), Some('b'));
    }

    #[test]
    /// Validate that `reset` properly updates intrusive list links
    fn reset_list_middle() {
        let mut queue = TimerQueue::new();
        let slot = SLOTS as u64 / 2;
        let a = queue.insert(slot, ());
        let b = queue.insert(slot, ());
        let c = queue.insert(slot, ());

        queue.reset(b, slot + 1);

        assert_eq!(queue.levels[0].get(slot as usize + 1), Some(b));
        assert_eq!(queue.timers[b.0].prev, None);
        assert_eq!(queue.timers[b.0].next, None);

        assert_eq!(queue.levels[0].get(slot as usize), Some(c));
        assert_eq!(queue.timers[c.0].prev, None);
        assert_eq!(queue.timers[c.0].next, Some(a));
        assert_eq!(queue.timers[a.0].prev, Some(c));
        assert_eq!(queue.timers[a.0].next, None);
    }

    proptest! {
        #[test]
        fn poll(ts in times()) {
            let mut queue = TimerQueue::new();
            let mut time_values = HashMap::<u64, Vec<usize>>::new();
            for (i, t) in ts.into_iter().enumerate() {
                queue.insert(t, i);
                time_values.entry(t).or_default().push(i);
            }
            let mut time_values = time_values.into_iter().collect::<Vec<(u64, Vec<usize>)>>();
            time_values.sort_unstable_by_key(|&(t, _)| t);
            for &(t, ref is) in &time_values {
                assert!(queue.next_timeout().unwrap() <= t);
                if t > 0 {
                    assert_eq!(queue.poll(t-1), None);
                }
                let mut values = Vec::new();
                while let Some(i) = queue.poll(t) {
                    values.push(i);
                }
                assert_eq!(values.len(), is.len());
                for i in is {
                    assert!(values.contains(i));
                }
            }
        }

        #[test]
        fn reset(ts_a in times(), ts_b in times()) {
            let mut queue = TimerQueue::new();
            let timers = ts_a.map(|t| queue.insert(t, ()));
            for (timer, t) in timers.into_iter().zip(ts_b) {
                queue.reset(timer, t);
            }
            let mut n = 0;
            while let Some(()) = queue.poll(u64::MAX) {
                n += 1;
            }
            assert_eq!(n, timers.len());
        }

        #[test]
        fn index_start_consistency(a in time(), b in time()) {
            let base = a.min(b);
            let t = a.max(b);
            let (level, slot) = timer_index(base, t);
            let start = slot_start(base, level, slot);
            assert!(start <= t);
            if let Some(end) = start.checked_add((SLOTS as u64).pow(level as u32)) {
                assert!(end > t);
            } else {
                // Slot contains u64::MAX
                assert!(start >= slot_start(0, LEVELS - 1, TOP_LEVEL_SLOTS - 1));
                if level == LEVELS - 1 {
                    assert_eq!(slot, TOP_LEVEL_SLOTS - 1);
                } else {
                    assert_eq!(slot, SLOTS - 1);
                }
            }
        }
    }

    /// Generates a time whose level/slot is more or less uniformly distributed
    fn time() -> impl Strategy<Value = u64> {
        ((0..LEVELS as u32), (0..SLOTS as u64)).prop_perturb(|(level, mut slot), mut rng| {
            if level == LEVELS as u32 - 1 {
                // Only the lowest slots of the top level are reachable
                slot %= TOP_LEVEL_SLOTS as u64;
            }
            let slot_size = (SLOTS as u64).pow(level);
            let slot_start = slot * slot_size;
            let slot_end = (slot + 1).saturating_mul(slot_size);
            rng.gen_range(slot_start..slot_end)
        })
    }

    #[rustfmt::skip]
    fn times() -> impl Strategy<Value = [u64; 16]> {
        [time(), time(), time(), time(), time(), time(), time(), time(),
         time(), time(), time(), time(), time(), time(), time(), time()]
    }
}