1#![no_std]
15
16use core::fmt;
17
18use slab::Slab;
19
20#[cfg(feature = "serde")]
21use serde::ser::SerializeSeq;
22
/// A hierarchical timing wheel holding values of type `T`, each scheduled to expire at a
/// caller-chosen `u64` tick.
///
/// Insert timers with [`TimerQueue::insert`] and drain expired ones with
/// [`TimerQueue::poll`]. The meaning of a tick (milliseconds, nanoseconds, ...) is up to the
/// caller; the queue only compares and bit-slices the values.
#[derive(Debug, Clone)]
pub struct TimerQueue<T> {
    /// Storage for every scheduled timer; a [`Timer`] handle is a key into this slab.
    timers: Slab<TimerState<T>>,

    /// The wheel levels. A slot in level `i` spans `SLOTS.pow(i)` ticks, and each occupied slot
    /// holds the head of a doubly linked list of timers.
    levels: [Level; LEVELS],

    /// Base time of the wheel: every stored expiry is `>=` this value, and each timer's
    /// (level, slot) position is computed relative to it by `timer_index`.
    next_tick: u64,
}
119
120impl<T> TimerQueue<T> {
121 pub const fn new() -> Self {
123 Self {
124 timers: Slab::new(),
125 levels: [Level::new(); LEVELS],
126 next_tick: 0,
127 }
128 }
129
130 pub fn with_capacity(n: usize) -> Self {
132 Self {
133 timers: Slab::with_capacity(n),
134 levels: [Level::new(); LEVELS],
135 next_tick: 0,
136 }
137 }
138
139 pub fn poll(&mut self, now: u64) -> Option<T> {
143 debug_assert!(now >= self.next_tick, "time advances monotonically");
144 loop {
145 self.advance_towards(now);
147 if let Some(value) = self.scan_bottom(now) {
149 return Some(value);
150 }
151 if self.next_tick >= now {
153 return None;
154 }
155 }
156 }
157
158 fn scan_bottom(&mut self, now: u64) -> Option<T> {
160 let index = self.levels[0].first_index()?;
161 if slot_start(self.next_tick, 0, index) > now {
162 return None;
163 }
164 let timer = self.levels[0].slots[index];
165 let state = self.timers.remove(timer.0);
166 debug_assert_eq!(state.prev, None, "head of list has no predecessor");
167 debug_assert!(state.expiry <= now);
168 if let Some(next) = state.next {
169 debug_assert_eq!(
170 self.timers[next.0].prev,
171 Some(timer),
172 "successor links to head"
173 );
174 self.timers[next.0].prev = None;
175 }
176 self.levels[0].set(index, state.next);
177 self.next_tick = state.expiry;
178 self.maybe_shrink();
179 Some(state.value)
180 }
181
182 fn advance_towards(&mut self, now: u64) {
184 for level in 0..LEVELS {
185 if let Some(slot) = self.levels[level].first_index() {
186 if slot_start(self.next_tick, level, slot) > now {
187 break;
188 }
189 self.advance_to(level, slot);
190 return;
191 }
192 }
193 self.next_tick = now;
194 }
195
196 fn advance_to(&mut self, level: usize, slot: usize) {
198 debug_assert!(
199 self.levels[..level].iter().all(|level| level.is_empty()),
200 "lower levels are empty"
201 );
202 debug_assert!(
203 self.levels[level].first_index().map_or(true, |x| x >= slot),
204 "lower slots in this level are empty"
205 );
206
207 self.next_tick = slot_start(self.next_tick, level, slot);
209
210 if level == 0 {
211 return;
213 }
214
215 while let Some(timer) = self.levels[level].take(slot) {
217 let next = self.timers[timer.0].next;
218 self.levels[level].set(slot, next);
219 if let Some(next) = next {
220 self.timers[next.0].prev = None;
221 }
222 self.list_unlink(timer);
223 self.schedule(timer);
224 }
225 }
226
227 fn schedule(&mut self, timer: Timer) {
229 debug_assert_eq!(
230 self.timers[timer.0].next, None,
231 "timer isn't already scheduled"
232 );
233 debug_assert_eq!(
234 self.timers[timer.0].prev, None,
235 "timer isn't already scheduled"
236 );
237 let (level, slot) = timer_index(self.next_tick, self.timers[timer.0].expiry);
238 let head = self.levels[level].get(slot);
240 self.timers[timer.0].next = head;
241 if let Some(head) = head {
242 self.timers[head.0].prev = Some(timer);
243 }
244 self.levels[level].set(slot, Some(timer));
245 }
246
247 pub fn next_timeout(&self) -> Option<u64> {
249 for level in 0..LEVELS {
250 let start = ((self.next_tick >> (level * LOG_2_SLOTS)) & (SLOTS - 1) as u64) as usize;
251 for slot in start..SLOTS {
252 if self.levels[level].get(slot).is_some() {
253 return Some(slot_start(self.next_tick, level, slot));
254 }
255 }
256 }
257 None
258 }
259
260 pub fn insert(&mut self, timeout: u64, value: T) -> Timer {
262 let timer = Timer(self.timers.insert(TimerState {
263 expiry: timeout.max(self.next_tick),
264 prev: None,
265 next: None,
266 value,
267 }));
268 self.schedule(timer);
269 timer
270 }
271
272 pub fn reset(&mut self, timer: Timer, timeout: u64) {
274 self.unlink(timer);
275 self.timers[timer.0].expiry = timeout.max(self.next_tick);
276 self.schedule(timer);
277 }
278
279 pub fn remove(&mut self, timer: Timer) -> T {
281 self.unlink(timer);
282 let state = self.timers.remove(timer.0);
283 self.maybe_shrink();
284 state.value
285 }
286
287 fn maybe_shrink(&mut self) {
289 if self.timers.capacity() / 16 > self.timers.len() {
290 self.timers.shrink_to_fit();
291 }
292 }
293
294 pub fn iter(&self) -> impl ExactSizeIterator<Item = (u64, &T)> {
296 self.timers.iter().map(|(_, x)| (x.expiry, &x.value))
297 }
298
299 pub fn iter_mut(&mut self) -> impl ExactSizeIterator<Item = (u64, &mut T)> {
301 self.timers
302 .iter_mut()
303 .map(|(_, x)| (x.expiry, &mut x.value))
304 }
305
306 pub fn get(&self, timer: Timer) -> &T {
308 &self.timers[timer.0].value
309 }
310
311 pub fn get_mut(&mut self, timer: Timer) -> &mut T {
313 &mut self.timers[timer.0].value
314 }
315
316 pub fn len(&self) -> usize {
318 self.timers.len()
319 }
320
321 pub fn is_empty(&self) -> bool {
323 self.timers.is_empty()
324 }
325
326 fn unlink(&mut self, timer: Timer) {
328 let (level, slot) = timer_index(self.next_tick, self.timers[timer.0].expiry);
329 let slot_head = self.levels[level].get(slot).unwrap();
332 if slot_head == timer {
333 self.levels[level].set(slot, self.timers[slot_head.0].next);
334 debug_assert_eq!(
335 self.timers[timer.0].prev, None,
336 "head of list has no predecessor"
337 );
338 }
339 self.list_unlink(timer);
341 }
342
343 fn list_unlink(&mut self, timer: Timer) {
345 let prev = self.timers[timer.0].prev.take();
346 let next = self.timers[timer.0].next.take();
347 if let Some(prev) = prev {
348 self.timers[prev.0].next = next;
350 }
351 if let Some(next) = next {
352 self.timers[next.0].prev = prev;
354 }
355 }
356}
357
358fn slot_start(base: u64, level: usize, slot: usize) -> u64 {
360 let shift = (level * LOG_2_SLOTS) as u64;
361 (base & ((!0 << shift) << LOG_2_SLOTS as u64)) | ((slot as u64) << shift)
363}
364
365fn timer_index(base: u64, expiry: u64) -> (usize, usize) {
367 let differing_bits = base ^ expiry;
370 let level = (63 - (differing_bits | 1).leading_zeros()) as usize / LOG_2_SLOTS;
371 debug_assert!(level < LEVELS, "every possible expiry is in range");
372
373 let slot_base = (base >> (level * LOG_2_SLOTS)) & (!0 << LOG_2_SLOTS);
377 let slot = (expiry >> (level * LOG_2_SLOTS)) - slot_base;
378 debug_assert!(slot < SLOTS as u64);
379
380 (level, slot as usize)
381}
382
383impl<T> Default for TimerQueue<T> {
384 fn default() -> Self {
385 Self::new()
386 }
387}
388
/// A stored timer: its expiry, payload, and links within its slot's doubly linked list.
#[derive(Debug, Clone)]
struct TimerState<T> {
    /// Tick at which the timer is due; clamped to at least the queue's base time when set.
    expiry: u64,
    /// The caller's payload, handed back by `poll` or `remove`.
    value: T,
    /// Previous timer in the same slot's list; `None` for the list head.
    prev: Option<Timer>,
    /// Next timer in the same slot's list; `None` for the list tail.
    next: Option<Timer>,
}
400
/// One level of the wheel: the list head for each of `SLOTS` slots, plus a bitmask of which
/// slots are non-empty.
#[derive(Copy, Clone)]
struct Level {
    /// Head timer of each slot's list. Only meaningful where the matching `occupied` bit is set;
    /// vacant entries hold the placeholder `Timer(usize::MAX)`.
    slots: [Timer; SLOTS],
    /// Bit `i` is set iff slot `i` holds at least one timer.
    occupied: u64,
}
411
412impl Level {
413 const fn new() -> Self {
414 Self {
415 slots: [Timer(usize::MAX); SLOTS],
416 occupied: 0,
417 }
418 }
419
420 fn first_index(&self) -> Option<usize> {
421 let x = self.occupied.trailing_zeros() as usize;
422 if x == self.slots.len() {
423 return None;
424 }
425 Some(x)
426 }
427
428 fn get(&self, slot: usize) -> Option<Timer> {
429 if self.occupied & (1 << slot) == 0 {
430 return None;
431 }
432 Some(self.slots[slot])
433 }
434
435 fn take(&mut self, slot: usize) -> Option<Timer> {
436 let x = self.get(slot)?;
437 self.set(slot, None);
438 Some(x)
439 }
440
441 fn set(&mut self, slot: usize, timer: Option<Timer>) {
442 match timer {
443 None => {
444 self.slots[slot] = Timer(usize::MAX);
445 self.occupied &= !(1 << slot);
446 }
447 Some(x) => {
448 self.slots[slot] = x;
449 self.occupied |= 1 << slot;
450 }
451 }
452 }
453
454 fn is_empty(&self) -> bool {
455 self.occupied == 0
456 }
457}
458
459impl fmt::Debug for Level {
460 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
461 let mut m = f.debug_map();
462 let numbered_nonempty_slots = self
463 .slots
464 .iter()
465 .enumerate()
466 .filter(|(i, _)| self.occupied & (1 << i) != 0);
467 for (i, Timer(t)) in numbered_nonempty_slots {
468 m.entry(&i, &t);
469 }
470 m.finish()
471 }
472}
473
/// Bits of the expiry resolved by each level (log2 of the slots per level).
const LOG_2_SLOTS: usize = 6;
/// Slots per level. Must not exceed 64, because occupancy is tracked in a `u64` bitmask.
const SLOTS: usize = 1 << LOG_2_SLOTS;
/// Enough levels to cover all 64 bits of a `u64` expiry; the top level is only partly used.
const LEVELS: usize = 64 / LOG_2_SLOTS + 1;
477
/// Handle to a timer stored in a [`TimerQueue`], returned by [`TimerQueue::insert`].
///
/// A handle is only meaningful for the queue that issued it, and only until its timer is
/// removed or returned by [`TimerQueue::poll`]. After that its storage slot may be reused by a
/// later insertion.
///
/// Implements `Hash` so handles can key hash maps and sets.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Timer(usize);
482
#[cfg(feature = "serde")]
impl<T: serde::Serialize> serde::Serialize for TimerQueue<T> {
    /// Writes the queue as a sequence of `(expiry, value)` pairs, one per pending timer, in
    /// the order produced by `TimerQueue::iter`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut entries = serializer.serialize_seq(Some(self.len()))?;
        self.iter()
            .try_for_each(|pair: (u64, &T)| entries.serialize_element(&pair))?;
        entries.end()
    }
}
500
#[cfg(feature = "serde")]
impl<'de, T> serde::Deserialize<'de> for TimerQueue<T>
where
    T: serde::Deserialize<'de>,
{
    /// Rebuilds a queue from a sequence of `(expiry, value)` pairs by re-inserting each one
    /// into a fresh queue.
    ///
    /// `Timer` handles obtained from the serialized queue must not be assumed valid for the
    /// result.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use core::fmt::Formatter;
        use core::marker::PhantomData;

        struct TimerQueueVisitor<T>(PhantomData<T>);

        impl<'de, T> serde::de::Visitor<'de> for TimerQueueVisitor<T>
        where
            T: serde::Deserialize<'de>,
        {
            type Value = TimerQueue<T>;

            fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
                write!(
                    formatter,
                    "a sequence of (u64, {}) tuples",
                    core::any::type_name::<T>()
                )
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                // The size hint comes from the input (e.g. a length prefix) and is untrusted.
                // Cap the up-front allocation, as serde's own collection impls do, so a forged
                // length cannot force a huge allocation. The slab still grows on demand if more
                // elements actually arrive.
                const MAX_PREALLOC_BYTES: usize = 1024 * 1024;
                let max_prealloc =
                    MAX_PREALLOC_BYTES / core::mem::size_of::<TimerState<T>>().max(1);
                let capacity = seq
                    .size_hint()
                    .map_or(0, |hint| hint.min(max_prealloc));

                let mut timer_queue = TimerQueue::<T>::with_capacity(capacity);
                while let Some((time, value)) = seq.next_element::<(u64, T)>()? {
                    timer_queue.insert(time, value);
                }
                Ok(timer_queue)
            }
        }

        deserializer.deserialize_seq(TimerQueueVisitor(PhantomData))
    }
}
548
#[cfg(test)]
mod tests {
    extern crate alloc;
    extern crate std;

    use std::{collections::HashMap, vec::Vec};

    use super::*;
    use proptest::prelude::*;

    // A timer at the largest representable time lands in the partially used top level.
    // It must not fire one tick early, but must fire exactly at `u64::MAX`.
    #[test]
    fn max_timeout() {
        let mut queue = TimerQueue::new();
        queue.insert(u64::MAX, ());
        assert!(queue.poll(u64::MAX - 1).is_none());
        assert!(queue.poll(u64::MAX).is_some());
    }

    // `slot_start` keeps the bits of the base above the level's slot field. At level 0 this
    // means the enclosing 64-tick window; from base 0 it is simply slot * SLOTS^level (with
    // wrap-around at the top level).
    #[test]
    fn slot_starts() {
        for i in 0..SLOTS {
            assert_eq!(slot_start(0, 0, i), i as u64);
            assert_eq!(slot_start(SLOTS as u64, 0, i), SLOTS as u64 + i as u64);
            assert_eq!(slot_start(SLOTS as u64 + 1, 0, i), SLOTS as u64 + i as u64);
            for j in 1..LEVELS {
                assert_eq!(
                    slot_start(0, j, i),
                    (SLOTS as u64).pow(j as u32).wrapping_mul(i as u64)
                );
            }
        }
    }

    // `timer_index` picks the level from the highest differing bit between base and expiry.
    // This checks the first and last slot of each level.
    #[test]
    fn indexes() {
        assert_eq!(timer_index(0, 0), (0, 0));
        assert_eq!(timer_index(0, SLOTS as u64 - 1), (0, SLOTS - 1));
        assert_eq!(
            timer_index(SLOTS as u64 - 1, SLOTS as u64 - 1),
            (0, SLOTS - 1)
        );
        assert_eq!(timer_index(0, SLOTS as u64), (1, 1));
        for i in 0..LEVELS {
            assert_eq!(timer_index(0, (SLOTS as u64).pow(i as u32)), (i, 1));
            if i < LEVELS - 1 {
                assert_eq!(
                    timer_index(0, (SLOTS as u64).pow(i as u32 + 1) - 1),
                    (i, SLOTS - 1)
                );
                assert_eq!(
                    timer_index(SLOTS as u64 - 1, (SLOTS as u64).pow(i as u32 + 1) - 1),
                    (i, SLOTS - 1)
                );
            }
        }
    }

    // `next_timeout` is exact for level-0 timers and a lower bound (slot start) for coarser
    // ones. 1234 lives in level 1, whose slot starts at 1216.
    #[test]
    fn next_timeout() {
        let mut queue = TimerQueue::new();
        assert_eq!(queue.next_timeout(), None);
        let k = queue.insert(0, ());
        assert_eq!(queue.next_timeout(), Some(0));
        queue.remove(k);
        assert_eq!(queue.next_timeout(), None);
        queue.insert(1234, ());
        assert!(queue.next_timeout().unwrap() > 12);
        queue.insert(12, ());
        assert_eq!(queue.next_timeout(), Some(12));
    }

    // Timers on either side of the level-0/level-1 boundary fire exactly at their expiry.
    #[test]
    fn poll_boundary() {
        let mut queue = TimerQueue::new();
        queue.insert(SLOTS as u64 - 1, 'a');
        queue.insert(SLOTS as u64, 'b');
        assert_eq!(queue.poll(SLOTS as u64 - 2), None);
        assert_eq!(queue.poll(SLOTS as u64 - 1), Some('a'));
        assert_eq!(queue.poll(SLOTS as u64 - 1), None);
        assert_eq!(queue.poll(SLOTS as u64), Some('b'));
    }

    // Resetting a timer from the middle of a slot's list (c -> b -> a, since `schedule`
    // pushes to the front) relinks its neighbours and moves it to its new slot alone.
    #[test]
    fn reset_list_middle() {
        let mut queue = TimerQueue::new();
        let slot = SLOTS as u64 / 2;
        let a = queue.insert(slot, ());
        let b = queue.insert(slot, ());
        let c = queue.insert(slot, ());

        queue.reset(b, slot + 1);

        assert_eq!(queue.levels[0].get(slot as usize + 1), Some(b));
        assert_eq!(queue.timers[b.0].prev, None);
        assert_eq!(queue.timers[b.0].next, None);

        assert_eq!(queue.levels[0].get(slot as usize), Some(c));
        assert_eq!(queue.timers[c.0].prev, None);
        assert_eq!(queue.timers[c.0].next, Some(a));
        assert_eq!(queue.timers[a.0].prev, Some(c));
        assert_eq!(queue.timers[a.0].next, None);
    }

    proptest! {
        // Every inserted value comes back exactly at its expiry time (not one tick earlier),
        // and `next_timeout` never overshoots the next pending expiry.
        #[test]
        fn poll(ts in times()) {
            let mut queue = TimerQueue::new();
            let mut time_values = HashMap::<u64, Vec<usize>>::new();
            for (i, t) in ts.into_iter().enumerate() {
                queue.insert(t, i);
                time_values.entry(t).or_default().push(i);
            }
            let mut time_values = time_values.into_iter().collect::<Vec<(u64, Vec<usize>)>>();
            time_values.sort_unstable_by_key(|&(t, _)| t);
            for &(t, ref is) in &time_values {
                assert!(queue.next_timeout().unwrap() <= t);
                if t > 0 {
                    assert_eq!(queue.poll(t-1), None);
                }
                let mut values = Vec::new();
                while let Some(i) = queue.poll(t) {
                    values.push(i);
                }
                assert_eq!(values.len(), is.len());
                for i in is {
                    assert!(values.contains(i));
                }
            }
        }

        // Resetting every timer to an arbitrary new expiry loses none of them.
        #[test]
        fn reset(ts_a in times(), ts_b in times()) {
            let mut queue = TimerQueue::new();
            let timers = ts_a.map(|t| queue.insert(t, ()));
            for (timer, t) in timers.into_iter().zip(ts_b) {
                queue.reset(timer, t);
            }
            let mut n = 0;
            while let Some(()) = queue.poll(u64::MAX) {
                n += 1;
            }
            assert_eq!(n, timers.len());
        }

        // The slot chosen by `timer_index` contains the expiry. Its range only overflows u64
        // for the final slot of a level; the top level has just 64 - 60 = 4 bits, so its last
        // slot is 15.
        #[test]
        fn index_start_consistency(a in time(), b in time()) {
            let base = a.min(b);
            let t = a.max(b);
            let (level, slot) = timer_index(base, t);
            let start = slot_start(base, level, slot);
            assert!(start <= t);
            if let Some(end) = start.checked_add((SLOTS as u64).pow(level as u32)) {
                assert!(end > t);
            } else {
                assert!(start >= slot_start(0, LEVELS - 1, 15));
                if level == LEVELS - 1 {
                    assert_eq!(slot, 15);
                } else {
                    assert_eq!(slot, SLOTS - 1);
                }
            }
        }
    }

    // A serialize/deserialize round trip yields a queue that polls out the same values in
    // the same order as the original.
    #[test]
    #[cfg(feature = "serde")]
    fn serialization() {
        const VALUES: [(u64, usize); 17] = [
            (23, 5132),
            (87, 6),
            (45, 7839),
            (122, 345),
            (67, 12333),
            (34, 8),
            (90, 234),
            (151, 82290),
            (56, 32),
            (78, 567),
            (19, 345),
            (22, 78),
            (33, 890),
            (44, 123),
            (51235, 6),
            (66, 89),
            (727, 890),
        ];

        let mut queue = TimerQueue::<usize>::new();
        for (t, v) in VALUES {
            queue.insert(t, v);
        }
        let serialized: Vec<u8> = bincode::serialize(&queue).expect("Serialization failed");
        let mut deserialized: TimerQueue<usize> =
            bincode::deserialize(&serialized).expect("Deserialization failed");

        loop {
            let r1 = queue.poll(u64::MAX);
            let r2 = deserialized.poll(u64::MAX);
            assert!(r1 == r2);
            if r1.is_none() {
                break;
            }
        }
    }

    // Strategy: pick a level uniformly, then a slot, then a time inside that slot's range
    // (measured from base 0), so every level of the wheel gets exercised. The top level only
    // has 16 valid slots.
    fn time() -> impl Strategy<Value = u64> {
        ((0..LEVELS as u32), (0..SLOTS as u64)).prop_perturb(|(level, mut slot), mut rng| {
            if level == LEVELS as u32 - 1 {
                slot %= 16;
            }
            let slot_size = (SLOTS as u64).pow(level);
            let slot_start = slot * slot_size;
            let slot_end = (slot + 1).saturating_mul(slot_size);
            rng.gen_range(slot_start..slot_end)
        })
    }

    // Strategy: sixteen independent times drawn from `time()`.
    #[rustfmt::skip]
    fn times() -> impl Strategy<Value = [u64; 16]> {
        [time(), time(), time(), time(), time(), time(), time(), time(),
         time(), time(), time(), time(), time(), time(), time(), time()]
    }
}