1#![no_std]
15
16use core::fmt;
17
18use slab::Slab;
19
/// A hierarchical timer wheel mapping `u64` expiry ticks to values of type `T`.
///
/// Time is expressed in abstract ticks supplied by the caller through
/// [`TimerQueue::poll`]; the queue never reads a clock itself.
#[derive(Debug, Clone)]
pub struct TimerQueue<T> {
    /// Storage for every pending timer; a [`Timer`] handle is a key into this slab.
    timers: Slab<TimerState<T>>,

    /// One wheel per level. A slot at level `n` spans `SLOTS.pow(n)` ticks. Each occupied
    /// slot holds the head of a doubly-linked list of timers threaded through `timers`.
    levels: [Level; LEVELS],

    /// Tick the wheel has been advanced to. Every pending timer's expiry is at or after
    /// this value (insert/reset clamp to it), and slot positions are computed relative to it.
    next_tick: u64,
}
116
117impl<T> TimerQueue<T> {
118 pub const fn new() -> Self {
120 Self {
121 timers: Slab::new(),
122 levels: [Level::new(); LEVELS],
123 next_tick: 0,
124 }
125 }
126
127 pub fn with_capacity(n: usize) -> Self {
129 Self {
130 timers: Slab::with_capacity(n),
131 levels: [Level::new(); LEVELS],
132 next_tick: 0,
133 }
134 }
135
136 pub fn poll(&mut self, now: u64) -> Option<T> {
140 debug_assert!(now >= self.next_tick, "time advances monotonically");
141 loop {
142 self.advance_towards(now);
144 if let Some(value) = self.scan_bottom(now) {
146 return Some(value);
147 }
148 if self.next_tick >= now {
150 return None;
151 }
152 }
153 }
154
155 fn scan_bottom(&mut self, now: u64) -> Option<T> {
157 let index = self.levels[0].first_index()?;
158 if slot_start(self.next_tick, 0, index) > now {
159 return None;
160 }
161 let timer = self.levels[0].slots[index];
162 let state = self.timers.remove(timer.0);
163 debug_assert_eq!(state.prev, None, "head of list has no predecessor");
164 debug_assert!(state.expiry <= now);
165 if let Some(next) = state.next {
166 debug_assert_eq!(
167 self.timers[next.0].prev,
168 Some(timer),
169 "successor links to head"
170 );
171 self.timers[next.0].prev = None;
172 }
173 self.levels[0].set(index, state.next);
174 self.next_tick = state.expiry;
175 self.maybe_shrink();
176 Some(state.value)
177 }
178
179 fn advance_towards(&mut self, now: u64) {
181 for level in 0..LEVELS {
182 if let Some(slot) = self.levels[level].first_index() {
183 if slot_start(self.next_tick, level, slot) > now {
184 break;
185 }
186 self.advance_to(level, slot);
187 return;
188 }
189 }
190 self.next_tick = now;
191 }
192
193 fn advance_to(&mut self, level: usize, slot: usize) {
195 debug_assert!(
196 self.levels[..level].iter().all(|level| level.is_empty()),
197 "lower levels are empty"
198 );
199 debug_assert!(
200 self.levels[level].first_index().map_or(true, |x| x >= slot),
201 "lower slots in this level are empty"
202 );
203
204 self.next_tick = slot_start(self.next_tick, level, slot);
206
207 if level == 0 {
208 return;
210 }
211
212 while let Some(timer) = self.levels[level].take(slot) {
214 let next = self.timers[timer.0].next;
215 self.levels[level].set(slot, next);
216 if let Some(next) = next {
217 self.timers[next.0].prev = None;
218 }
219 self.list_unlink(timer);
220 self.schedule(timer);
221 }
222 }
223
224 fn schedule(&mut self, timer: Timer) {
226 debug_assert_eq!(
227 self.timers[timer.0].next, None,
228 "timer isn't already scheduled"
229 );
230 debug_assert_eq!(
231 self.timers[timer.0].prev, None,
232 "timer isn't already scheduled"
233 );
234 let (level, slot) = timer_index(self.next_tick, self.timers[timer.0].expiry);
235 let head = self.levels[level].get(slot);
237 self.timers[timer.0].next = head;
238 if let Some(head) = head {
239 self.timers[head.0].prev = Some(timer);
240 }
241 self.levels[level].set(slot, Some(timer));
242 }
243
244 pub fn next_timeout(&self) -> Option<u64> {
246 for level in 0..LEVELS {
247 let start = ((self.next_tick >> (level * LOG_2_SLOTS)) & (SLOTS - 1) as u64) as usize;
248 for slot in start..SLOTS {
249 if self.levels[level].get(slot).is_some() {
250 return Some(slot_start(self.next_tick, level, slot));
251 }
252 }
253 }
254 None
255 }
256
257 pub fn insert(&mut self, timeout: u64, value: T) -> Timer {
259 let timer = Timer(self.timers.insert(TimerState {
260 expiry: timeout.max(self.next_tick),
261 prev: None,
262 next: None,
263 value,
264 }));
265 self.schedule(timer);
266 timer
267 }
268
269 pub fn reset(&mut self, timer: Timer, timeout: u64) {
271 self.unlink(timer);
272 self.timers[timer.0].expiry = timeout.max(self.next_tick);
273 self.schedule(timer);
274 }
275
276 pub fn remove(&mut self, timer: Timer) -> T {
278 self.unlink(timer);
279 let state = self.timers.remove(timer.0);
280 self.maybe_shrink();
281 state.value
282 }
283
284 fn maybe_shrink(&mut self) {
286 if self.timers.capacity() / 16 > self.timers.len() {
287 self.timers.shrink_to_fit();
288 }
289 }
290
291 pub fn iter(&self) -> impl ExactSizeIterator<Item = (u64, &T)> {
293 self.timers.iter().map(|(_, x)| (x.expiry, &x.value))
294 }
295
296 pub fn iter_mut(&mut self) -> impl ExactSizeIterator<Item = (u64, &mut T)> {
298 self.timers
299 .iter_mut()
300 .map(|(_, x)| (x.expiry, &mut x.value))
301 }
302
303 pub fn get(&self, timer: Timer) -> &T {
305 &self.timers[timer.0].value
306 }
307
308 pub fn get_mut(&mut self, timer: Timer) -> &mut T {
310 &mut self.timers[timer.0].value
311 }
312
313 pub fn len(&self) -> usize {
315 self.timers.len()
316 }
317
318 pub fn is_empty(&self) -> bool {
320 self.timers.is_empty()
321 }
322
323 fn unlink(&mut self, timer: Timer) {
325 let (level, slot) = timer_index(self.next_tick, self.timers[timer.0].expiry);
326 let slot_head = self.levels[level].get(slot).unwrap();
329 if slot_head == timer {
330 self.levels[level].set(slot, self.timers[slot_head.0].next);
331 debug_assert_eq!(
332 self.timers[timer.0].prev, None,
333 "head of list has no predecessor"
334 );
335 }
336 self.list_unlink(timer);
338 }
339
340 fn list_unlink(&mut self, timer: Timer) {
342 let prev = self.timers[timer.0].prev.take();
343 let next = self.timers[timer.0].next.take();
344 if let Some(prev) = prev {
345 self.timers[prev.0].next = next;
347 }
348 if let Some(next) = next {
349 self.timers[next.0].prev = prev;
351 }
352 }
353}
354
355fn slot_start(base: u64, level: usize, slot: usize) -> u64 {
357 let shift = (level * LOG_2_SLOTS) as u64;
358 (base & ((!0 << shift) << LOG_2_SLOTS as u64)) | ((slot as u64) << shift)
360}
361
362fn timer_index(base: u64, expiry: u64) -> (usize, usize) {
364 let differing_bits = base ^ expiry;
367 let level = (63 - (differing_bits | 1).leading_zeros()) as usize / LOG_2_SLOTS;
368 debug_assert!(level < LEVELS, "every possible expiry is in range");
369
370 let slot_base = (base >> (level * LOG_2_SLOTS)) & (!0 << LOG_2_SLOTS);
374 let slot = (expiry >> (level * LOG_2_SLOTS)) - slot_base;
375 debug_assert!(slot < SLOTS as u64);
376
377 (level, slot as usize)
378}
379
380impl<T> Default for TimerQueue<T> {
381 fn default() -> Self {
382 Self::new()
383 }
384}
385
/// Per-timer bookkeeping stored in the queue's slab.
#[derive(Debug, Clone)]
struct TimerState<T> {
    /// Absolute tick at which the timer fires. Insert/reset clamp it to be at or after the
    /// queue's `next_tick` at that moment.
    expiry: u64,
    /// Caller payload, handed back by `poll` or `remove`.
    value: T,
    /// Previous timer in the same slot's list; `None` for the list head.
    prev: Option<Timer>,
    /// Next timer in the same slot's list; `None` for the list tail.
    next: Option<Timer>,
}
397
/// One wheel of the hierarchy: `SLOTS` list heads plus an occupancy bitmap.
///
/// Bit `i` of `occupied` is set exactly when `slots[i]` holds a real list head. Unoccupied
/// entries hold the placeholder `Timer(usize::MAX)`.
#[derive(Copy, Clone)]
struct Level {
    /// Head timer of each slot's list; meaningful only where the matching bit is set.
    slots: [Timer; SLOTS],
    /// One bit per slot, so `SLOTS` must not exceed 64.
    occupied: u64,
}
408
409impl Level {
410 const fn new() -> Self {
411 Self {
412 slots: [Timer(usize::MAX); SLOTS],
413 occupied: 0,
414 }
415 }
416
417 fn first_index(&self) -> Option<usize> {
418 let x = self.occupied.trailing_zeros() as usize;
419 if x == self.slots.len() {
420 return None;
421 }
422 Some(x)
423 }
424
425 fn get(&self, slot: usize) -> Option<Timer> {
426 if self.occupied & (1 << slot) == 0 {
427 return None;
428 }
429 Some(self.slots[slot])
430 }
431
432 fn take(&mut self, slot: usize) -> Option<Timer> {
433 let x = self.get(slot)?;
434 self.set(slot, None);
435 Some(x)
436 }
437
438 fn set(&mut self, slot: usize, timer: Option<Timer>) {
439 match timer {
440 None => {
441 self.slots[slot] = Timer(usize::MAX);
442 self.occupied &= !(1 << slot);
443 }
444 Some(x) => {
445 self.slots[slot] = x;
446 self.occupied |= 1 << slot;
447 }
448 }
449 }
450
451 fn is_empty(&self) -> bool {
452 self.occupied == 0
453 }
454}
455
456impl fmt::Debug for Level {
457 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458 let mut m = f.debug_map();
459 let numbered_nonempty_slots = self
460 .slots
461 .iter()
462 .enumerate()
463 .filter(|(i, _)| self.occupied & (1 << i) != 0);
464 for (i, Timer(t)) in numbered_nonempty_slots {
465 m.entry(&i, &t);
466 }
467 m.finish()
468 }
469}
470
/// Bits of the tick consumed by each level of the wheel.
const LOG_2_SLOTS: usize = 6;
/// Slots per level. Must not exceed 64 so one `u64` can serve as a level's occupancy bitmap.
const SLOTS: usize = 1 << LOG_2_SLOTS;
/// Levels needed so every `u64` expiry maps to some level. The top level only uses the
/// remaining high bits.
const LEVELS: usize = 64 / LOG_2_SLOTS + 1;
474
/// Handle to a timer stored in a [`TimerQueue`], as returned by [`TimerQueue::insert`].
///
/// A handle is meaningful only for the queue that issued it, and only until that timer is
/// returned by `poll` or removed. After that, its storage slot may be reused by a new timer.
/// `Hash` is derived alongside `Eq`, so handles can key hash maps and sets.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Timer(usize);
479
#[cfg(test)]
mod tests {
    extern crate std;

    use std::{vec::Vec, collections::HashMap};

    use super::*;
    use proptest::prelude::*;

    /// The largest representable expiry is stored in the top level and fires exactly at
    /// `u64::MAX`, not one tick earlier.
    #[test]
    fn max_timeout() {
        let mut queue = TimerQueue::new();
        queue.insert(u64::MAX, ());
        assert!(queue.poll(u64::MAX - 1).is_none());
        assert!(queue.poll(u64::MAX).is_some());
    }

    /// `slot_start` replaces the level's digit, keeps higher digits of the base, and zeroes
    /// lower ones. `wrapping_mul` mirrors the top level, whose slot start overflows.
    #[test]
    fn slot_starts() {
        for i in 0..SLOTS {
            assert_eq!(slot_start(0, 0, i), i as u64);
            assert_eq!(slot_start(SLOTS as u64, 0, i), SLOTS as u64 + i as u64);
            assert_eq!(slot_start(SLOTS as u64 + 1, 0, i), SLOTS as u64 + i as u64);
            for j in 1..LEVELS {
                assert_eq!(
                    slot_start(0, j, i),
                    (SLOTS as u64).pow(j as u32).wrapping_mul(i as u64)
                );
            }
        }
    }

    /// Boundary cases for `timer_index`: the first and last tick of each level.
    #[test]
    fn indexes() {
        assert_eq!(timer_index(0, 0), (0, 0));
        assert_eq!(timer_index(0, SLOTS as u64 - 1), (0, SLOTS - 1));
        assert_eq!(
            timer_index(SLOTS as u64 - 1, SLOTS as u64 - 1),
            (0, SLOTS - 1)
        );
        assert_eq!(timer_index(0, SLOTS as u64), (1, 1));
        for i in 0..LEVELS {
            assert_eq!(timer_index(0, (SLOTS as u64).pow(i as u32)), (i, 1));
            if i < LEVELS - 1 {
                assert_eq!(
                    timer_index(0, (SLOTS as u64).pow(i as u32 + 1) - 1),
                    (i, SLOTS - 1)
                );
                assert_eq!(
                    timer_index(SLOTS as u64 - 1, (SLOTS as u64).pow(i as u32 + 1) - 1),
                    (i, SLOTS - 1)
                );
            }
        }
    }

    /// `next_timeout` tracks insertions and removals. For a coarse-level timer (1234) it
    /// reports only a slot start, hence the inequality.
    #[test]
    fn next_timeout() {
        let mut queue = TimerQueue::new();
        assert_eq!(queue.next_timeout(), None);
        let k = queue.insert(0, ());
        assert_eq!(queue.next_timeout(), Some(0));
        queue.remove(k);
        assert_eq!(queue.next_timeout(), None);
        queue.insert(1234, ());
        assert!(queue.next_timeout().unwrap() > 12);
        queue.insert(12, ());
        assert_eq!(queue.next_timeout(), Some(12));
    }

    /// Timers on either side of the level-0/level-1 boundary fire at exactly their ticks.
    #[test]
    fn poll_boundary() {
        let mut queue = TimerQueue::new();
        queue.insert(SLOTS as u64 - 1, 'a');
        queue.insert(SLOTS as u64, 'b');
        assert_eq!(queue.poll(SLOTS as u64 - 2), None);
        assert_eq!(queue.poll(SLOTS as u64 - 1), Some('a'));
        assert_eq!(queue.poll(SLOTS as u64 - 1), None);
        assert_eq!(queue.poll(SLOTS as u64), Some('b'));
    }

    /// Resetting a timer from the middle of a slot's list relinks its neighbours.
    /// `schedule` pushes at the head, so the list starts as c -> b -> a.
    #[test]
    fn reset_list_middle() {
        let mut queue = TimerQueue::new();
        let slot = SLOTS as u64 / 2;
        let a = queue.insert(slot, ());
        let b = queue.insert(slot, ());
        let c = queue.insert(slot, ());

        queue.reset(b, slot + 1);

        assert_eq!(queue.levels[0].get(slot as usize + 1), Some(b));
        assert_eq!(queue.timers[b.0].prev, None);
        assert_eq!(queue.timers[b.0].next, None);

        assert_eq!(queue.levels[0].get(slot as usize), Some(c));
        assert_eq!(queue.timers[c.0].prev, None);
        assert_eq!(queue.timers[c.0].next, Some(a));
        assert_eq!(queue.timers[a.0].prev, Some(c));
        assert_eq!(queue.timers[a.0].next, None);
    }

    proptest! {
        // Every timer fires exactly at its expiry: none is due one tick early, and all
        // timers sharing an expiry come out together.
        #[test]
        fn poll(ts in times()) {
            let mut queue = TimerQueue::new();
            let mut time_values = HashMap::<u64, Vec<usize>>::new();
            for (i, t) in ts.into_iter().enumerate() {
                queue.insert(t, i);
                time_values.entry(t).or_default().push(i);
            }
            let mut time_values = time_values.into_iter().collect::<Vec<(u64, Vec<usize>)>>();
            time_values.sort_unstable_by_key(|&(t, _)| t);
            for &(t, ref is) in &time_values {
                assert!(queue.next_timeout().unwrap() <= t);
                if t > 0 {
                    assert_eq!(queue.poll(t-1), None);
                }
                let mut values = Vec::new();
                while let Some(i) = queue.poll(t) {
                    values.push(i);
                }
                assert_eq!(values.len(), is.len());
                for i in is {
                    assert!(values.contains(i));
                }
            }
        }

        // Resetting every timer to an arbitrary new time neither loses nor duplicates any.
        #[test]
        fn reset(ts_a in times(), ts_b in times()) {
            let mut queue = TimerQueue::new();
            let timers = ts_a.map(|t| queue.insert(t, ()));
            for (timer, t) in timers.into_iter().zip(ts_b) {
                queue.reset(timer, t);
            }
            let mut n = 0;
            while let Some(()) = queue.poll(u64::MAX) {
                n += 1;
            }
            assert_eq!(n, timers.len());
        }

        // The slot `timer_index` picks actually contains `t`, measured with `slot_start`.
        #[test]
        fn index_start_consistency(a in time(), b in time()) {
            let base = a.min(b);
            let t = a.max(b);
            let (level, slot) = timer_index(base, t);
            let start = slot_start(base, level, slot);
            assert!(start <= t);
            if let Some(end) = start.checked_add((SLOTS as u64).pow(level as u32)) {
                assert!(end > t);
            } else {
                // The slot's end overflows u64, so this must be the last slot of the top
                // level. The top level has only 64 - 6 * 10 = 4 bits, i.e. 16 slots, so its
                // last index is 15 (this hard-codes LOG_2_SLOTS == 6).
                assert!(start >= slot_start(0, LEVELS - 1, 15));
                if level == LEVELS - 1 {
                    assert_eq!(slot, 15);
                } else {
                    assert_eq!(slot, SLOTS - 1);
                }
            }
        }
    }

    /// Generates a tick by picking a level and slot, then a uniform point within that slot.
    /// Every level is sampled roughly equally, rather than almost always the top one.
    fn time() -> impl Strategy<Value = u64> {
        ((0..LEVELS as u32), (0..SLOTS as u64)).prop_perturb(|(level, mut slot), mut rng| {
            // The top level has only 16 usable slots (see `index_start_consistency`).
            if level == LEVELS as u32 - 1 {
                slot %= 16;
            }
            let slot_size = (SLOTS as u64).pow(level);
            let slot_start = slot * slot_size;
            let slot_end = (slot + 1).saturating_mul(slot_size);
            rng.gen_range(slot_start..slot_end)
        })
    }

    /// Sixteen independent ticks drawn from `time()`.
    #[rustfmt::skip]
    fn times() -> impl Strategy<Value = [u64; 16]> {
        [time(), time(), time(), time(), time(), time(), time(), time(),
         time(), time(), time(), time(), time(), time(), time(), time()]
    }
}