Skip to main content

sozu_lib/
timer.rs

1//! Timer based on timing wheels
2//!
3//! code imported from mio-extras
4//! License: MIT or Apache 2.0
5use std::{
6    cmp,
7    fmt::Display,
8    time::{Duration, Instant},
9};
10
11use mio::Token;
12use slab::Slab;
13
14use crate::server::TIMER;
15
// Conversion utilities
mod convert {
    use std::time::Duration;

    /// Number of nanoseconds in one millisecond.
    const NANOS_PER_MILLI: u128 = 1_000_000;

    /// Convert a `Duration` to milliseconds, rounding up and saturating at
    /// `u64::MAX`.
    ///
    /// Rounding up means any non-zero duration maps to at least one
    /// millisecond, so a sub-millisecond tick length never collapses to `0`.
    ///
    /// The saturating is fine because `u64::MAX` milliseconds are still many
    /// million years.
    pub fn millis(duration: Duration) -> u64 {
        // `Duration::as_millis` truncates, so take the ceiling of the
        // nanosecond count to honour the documented "rounding up" contract.
        let millis = duration.as_nanos().div_ceil(NANOS_PER_MILLI);
        u64::try_from(millis).unwrap_or(u64::MAX)
    }
}
29
/// A timer built on a hashed timing wheel.
///
/// Typical usage goes like this:
///
/// * register the timer with a `mio::Poll`.
/// * set a timeout, by calling `Timer::set_timeout`.  Here you provide some
///   state to be associated with this timeout.
/// * poll the `Poll`, to learn when a timeout has occurred.
/// * retrieve state associated with the timeout by calling `Timer::poll`.
///
/// You can omit use of the `Poll` altogether, if you like, and just poll the
/// `Timer` directly.
///
/// NOTE(review): no `mio::Poll` registration hook is visible in this file;
/// the `Poll`-related steps above may be inherited from mio-extras — confirm.
pub struct Timer<T> {
    // Size of each tick in milliseconds
    tick_ms: u64,
    // Slab of timeout entries; a `Token` wraps an index into this slab
    entries: Slab<Entry<T>>,
    // Timeout wheel. Each tick, the timer will look at the next slot for
    // timeouts that match the current tick.
    wheel: Vec<WheelEntry>,
    // Tick 0's time instant
    start: Instant,
    // The current tick
    tick: Tick,
    // The next entry to possibly timeout (`EMPTY` when the cursor is not
    // inside a slot's list)
    next: Token,
    // Masks the target tick to get the slot (number of slots minus one, the
    // slot count being a power of two)
    mask: u64,
}
59
/// Used to create a `Timer`.
///
/// Start from `Builder::default()`, adjust it with the chaining setters and
/// finish with `Builder::build`.
pub struct Builder {
    // Approximate duration of each tick
    tick: Duration,
    // Number of slots in the timer wheel (rounded up to a power of two when
    // the timer is built)
    num_slots: usize,
    // Max number of timeouts that can be in flight at a given time.
    // Passed to `Slab::with_capacity` when the timer is built.
    capacity: usize,
}
69
/// A timeout, as returned by `Timer::set_timeout`.
///
/// Use this as the argument to `Timer::cancel_timeout`, to cancel this timeout.
#[derive(Clone, Debug)]
pub struct Timeout {
    // Reference into the timer entry slab
    token: Token,
    // Tick that it should match up with; `Timer::cancel_timeout` compares it
    // with the stored entry's tick and refuses to cancel on a mismatch
    tick: u64,
}
80
/// Owns at most one pending timeout registered in the global `TIMER`,
/// together with the duration and token needed to (re)arm it.
///
/// Dropping the container cancels a still-pending timeout.
#[derive(Clone, Debug)]
pub struct TimeoutContainer {
    // mark it as an option, so we do not try to cancel a timeout multiple times
    timeout: Option<Timeout>,
    // delay used every time the timeout is armed or reset
    duration: Duration,
    // state registered with the timeout; `None` until a token has been set
    token: Option<Token>,
}
88
89impl TimeoutContainer {
90    pub fn new(duration: Duration, token: Token) -> TimeoutContainer {
91        let timeout = TIMER.with(|timer| timer.borrow_mut().set_timeout(duration, token));
92        TimeoutContainer {
93            timeout: Some(timeout),
94            duration,
95            token: Some(token),
96        }
97    }
98
99    pub fn new_empty(duration: Duration) -> TimeoutContainer {
100        TimeoutContainer {
101            timeout: None,
102            duration,
103            token: None,
104        }
105    }
106
107    pub fn take(&mut self) -> TimeoutContainer {
108        TimeoutContainer {
109            timeout: self.timeout.take(),
110            duration: self.duration,
111            token: self.token.take(),
112        }
113    }
114
115    /// must be called when a timeout was triggered, to prevent errors when canceling
116    pub fn triggered(&mut self) {
117        let _ = self.timeout.take();
118    }
119
120    pub fn set(&mut self, token: Token) {
121        if let Some(timeout) = self.timeout.take() {
122            TIMER.with(|timer| timer.borrow_mut().cancel_timeout(&timeout));
123        }
124
125        let timeout = TIMER.with(|timer| timer.borrow_mut().set_timeout(self.duration, token));
126
127        self.timeout = Some(timeout);
128        self.token = Some(token);
129    }
130
131    /// warning: this does not reset the timer
132    pub fn set_duration(&mut self, duration: Duration) {
133        self.duration = duration;
134
135        if let Some(timeout) = self.timeout.take() {
136            TIMER.with(|timer| timer.borrow_mut().cancel_timeout(&timeout));
137        }
138
139        if let Some(token) = self.token {
140            self.timeout =
141                Some(TIMER.with(|timer| timer.borrow_mut().set_timeout(self.duration, token)));
142        }
143    }
144
145    pub fn duration(&self) -> Duration {
146        self.duration
147    }
148
149    pub fn cancel(&mut self) -> bool {
150        match self.timeout.take() {
151            None => {
152                //error!("cannot cancel non existing timeout");
153                //error!("self.duration was {:?}", self.duration);
154                false
155            }
156            Some(timeout) => {
157                TIMER.with(|timer| timer.borrow_mut().cancel_timeout(&timeout));
158                true
159            }
160        }
161    }
162
163    // Reset the timeout to its optional timeout, or to its defined duration
164    pub fn reset(&mut self) -> bool {
165        match self.timeout.take() {
166            None => {
167                if let Some(token) = self.token {
168                    self.timeout = Some(
169                        TIMER.with(|timer| timer.borrow_mut().set_timeout(self.duration, token)),
170                    );
171                } else {
172                    //error!("cannot reset non existing timeout");
173                    return false;
174                }
175            }
176            Some(timeout) => {
177                self.timeout =
178                    TIMER.with(|timer| timer.borrow_mut().reset_timeout(&timeout, self.duration));
179            }
180        };
181        self.timeout.is_some()
182    }
183}
184
185impl Display for TimeoutContainer {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        write!(f, "{:?}", self.duration)
188    }
189}
190
191impl std::ops::Drop for TimeoutContainer {
192    fn drop(&mut self) {
193        if self.cancel() {
194            debug!(
195                "Cancel a dangling timeout that haven't be handled in session lifecycle, token ({:?}), duration {}",
196                self.token, self
197            );
198        }
199    }
200}
201
// One slot of the timing wheel.
#[derive(Copy, Clone, Debug)]
struct WheelEntry {
    // Earliest tick known to be pending in this slot, or `TICK_MAX` if none
    next_tick: Tick,
    // First entry of the slot's linked list, or `EMPTY`
    head: Token,
}
207
// Doubly linked list of timer entries. Allows for efficient insertion /
// removal of timeouts.
struct Entry<T> {
    // User state handed back when the timeout fires or is cancelled
    state: T,
    // Target tick and neighbours within the slot's list
    links: EntryLinks,
}
214
// Position of an entry: the tick it expires at and its list neighbours.
#[derive(Copy, Clone)]
struct EntryLinks {
    // Tick at which the entry expires
    tick: Tick,
    // Previous entry in the slot's list, or `EMPTY` when this is the head
    prev: Token,
    // Next entry in the slot's list, or `EMPTY` when this is the tail
    next: Token,
}
221
// Tick counter; tick 0 corresponds to the timer's `start` instant.
type Tick = u64;
// Sentinel `next_tick` for a wheel slot with no known pending timeout.
const TICK_MAX: Tick = u64::MAX;
// Sentinel token meaning "no entry" (empty slot or end of a list).
const EMPTY: Token = Token(usize::MAX);
225
226impl Builder {
227    /// Set the tick duration.  Default is 100ms.
228    pub fn tick_duration(mut self, duration: Duration) -> Builder {
229        self.tick = duration;
230        self
231    }
232
233    /// Set the number of slots.  Default is 256.
234    pub fn num_slots(mut self, num_slots: usize) -> Builder {
235        self.num_slots = num_slots;
236        self
237    }
238
239    /// Set the capacity.  Default is 65536.
240    pub fn capacity(mut self, capacity: usize) -> Builder {
241        self.capacity = capacity;
242        self
243    }
244
245    /// Build a `Timer` with the parameters set on this `Builder`.
246    pub fn build<T>(self) -> Timer<T> {
247        Timer::new(
248            convert::millis(self.tick),
249            self.num_slots,
250            self.capacity,
251            Instant::now(),
252        )
253    }
254}
255
256impl Default for Builder {
257    fn default() -> Builder {
258        Builder {
259            tick: Duration::from_millis(100),
260            num_slots: 1 << 8,
261            capacity: 1 << 16,
262        }
263    }
264}
265
266impl<T> Timer<T> {
267    fn new(tick_ms: u64, num_slots: usize, capacity: usize, start: Instant) -> Timer<T> {
268        let num_slots = num_slots.next_power_of_two();
269        let capacity = capacity.next_power_of_two();
270        let mask = (num_slots as u64) - 1;
271        let wheel = std::iter::repeat_n(
272            WheelEntry {
273                next_tick: TICK_MAX,
274                head: EMPTY,
275            },
276            num_slots,
277        )
278        .collect();
279
280        Timer {
281            tick_ms,
282            entries: Slab::with_capacity(capacity),
283            wheel,
284            start,
285            tick: 0,
286            next: EMPTY,
287            mask,
288        }
289    }
290
291    /// Set a timeout.
292    ///
293    /// When the timeout occurs, the given state becomes available via `poll`.
294    pub fn set_timeout(&mut self, delay_from_now: Duration, state: T) -> Timeout {
295        let delay_from_start = self.start.elapsed() + delay_from_now;
296        self.set_timeout_at(delay_from_start, state)
297    }
298
299    fn set_timeout_at(&mut self, delay_from_start: Duration, state: T) -> Timeout {
300        let mut tick = duration_to_tick(delay_from_start, self.tick_ms);
301        trace!(
302            "setting timeout; delay={:?}; tick={:?}; current-tick={:?}",
303            delay_from_start, tick, self.tick
304        );
305
306        // Always target at least 1 tick in the future
307        if tick <= self.tick {
308            tick = self.tick + 1;
309        }
310
311        self.insert(tick, state)
312    }
313
314    fn insert(&mut self, tick: Tick, state: T) -> Timeout {
315        // Get the slot for the requested tick
316        let slot = (tick & self.mask) as usize;
317        let curr = self.wheel[slot];
318
319        // Insert the new entry
320        let entry = Entry::new(state, tick, curr.head);
321        let token = Token(self.entries.insert(entry));
322
323        if curr.head != EMPTY {
324            // If there was a previous entry, set its prev pointer to the new
325            // entry
326            self.entries[curr.head.into()].links.prev = token;
327        }
328
329        // Update the head slot
330        self.wheel[slot] = WheelEntry {
331            next_tick: cmp::min(tick, curr.next_tick),
332            head: token,
333        };
334
335        trace!("inserted timeout; slot={}; token={:?}", slot, token);
336
337        // Return the new timeout
338        Timeout { token, tick }
339    }
340
341    /// Resets a timeout.
342    ///
343    pub fn reset_timeout(
344        &mut self,
345        timeout: &Timeout,
346        delay_from_now: Duration,
347    ) -> Option<Timeout> {
348        self.cancel_timeout(timeout)
349            .map(|state| self.set_timeout(delay_from_now, state))
350    }
351
352    // TODO: return Result with context
353    /// Cancel a timeout.
354    ///
355    /// If the timeout has not yet occurred, the return value holds the
356    /// associated state.
357    pub fn cancel_timeout(&mut self, timeout: &Timeout) -> Option<T> {
358        let links = match self.entries.get(timeout.token.into()) {
359            Some(e) => e.links,
360            None => {
361                debug!("timeout token {:?} not found", timeout.token);
362                return None;
363            }
364        };
365
366        // Sanity check
367        if links.tick != timeout.tick {
368            return None;
369        }
370
371        self.unlink(&links, timeout.token);
372        Some(self.entries.remove(timeout.token.into()).state)
373    }
374
375    /// Poll for an expired timer.
376    ///
377    /// The return value holds the state associated with the first expired
378    /// timer, if any.
379    pub fn poll(&mut self) -> Option<T> {
380        let target_tick = current_tick(self.start, self.tick_ms);
381        self.poll_to(target_tick)
382    }
383
    /// Advances the wheel towards `target_tick` and returns the state of the
    /// first expired entry found, or `None` once the wheel has been advanced
    /// past `target_tick` without finding one.
    ///
    /// The scan position (`self.tick` and the `self.next` cursor) is kept
    /// between calls, so calling again with the same target resumes where the
    /// previous call returned. A target in the past is treated as the
    /// current tick.
    fn poll_to(&mut self, mut target_tick: Tick) -> Option<T> {
        trace!(
            "tick_to; target_tick={}; current_tick={}",
            target_tick, self.tick
        );

        // Never move backwards
        if target_tick < self.tick {
            target_tick = self.tick;
        }

        while self.tick <= target_tick {
            let curr = self.next;

            //info!("ticking; curr={:?}", curr);

            if curr == EMPTY {
                // Current slot exhausted: move to the next tick and start at
                // the head of its slot
                self.tick += 1;

                let slot = self.slot_for(self.tick);
                self.next = self.wheel[slot].head;

                // Handle the case when a slot has a single timeout which gets
                // canceled before the timeout expires. In this case, the
                // slot's head is EMPTY but there is a value for next_tick. Not
                // resetting next_tick here causes the timer to get stuck in a
                // loop.
                if self.next == EMPTY {
                    self.wheel[slot].next_tick = TICK_MAX;
                }
            } else {
                let slot = self.slot_for(self.tick);

                // Starting a pass over this slot: next_tick is recomputed
                // below from the entries that stay in the list
                if curr == self.wheel[slot].head {
                    self.wheel[slot].next_tick = TICK_MAX;
                }

                let links = self.entries[curr.into()].links;

                if links.tick <= self.tick {
                    trace!("triggering; token={:?}", curr);

                    // Unlink will also advance self.next
                    self.unlink(&links, curr);

                    // Remove and return the token
                    return Some(self.entries.remove(curr.into()).state);
                } else {
                    // Entry belongs to a later revolution of the wheel: keep
                    // it, remember its tick and move the cursor forward
                    let next_tick = self.wheel[slot].next_tick;
                    self.wheel[slot].next_tick = cmp::min(next_tick, links.tick);
                    self.next = links.next;
                }
            }
        }

        None
    }
440
441    fn unlink(&mut self, links: &EntryLinks, token: Token) {
442        trace!(
443            "unlinking timeout; slot={}; token={:?}",
444            self.slot_for(links.tick),
445            token
446        );
447
448        if links.prev == EMPTY {
449            let slot = self.slot_for(links.tick);
450            self.wheel[slot].head = links.next;
451        } else {
452            self.entries[links.prev.into()].links.next = links.next;
453        }
454
455        if links.next != EMPTY {
456            self.entries[links.next.into()].links.prev = links.prev;
457
458            if token == self.next {
459                self.next = links.next;
460            }
461        } else if token == self.next {
462            self.next = EMPTY;
463        }
464    }
465
466    // Next tick containing a timeout
467    fn next_tick(&self) -> Option<Tick> {
468        if self.next != EMPTY {
469            let slot = self.slot_for(self.entries[self.next.into()].links.tick);
470
471            if self.wheel[slot].next_tick == self.tick {
472                // There is data ready right now
473                return Some(self.tick);
474            }
475        }
476
477        self.wheel.iter().map(|e| e.next_tick).min()
478    }
479
480    pub fn next_poll_date(&self) -> Option<Instant> {
481        self.next_tick()
482            .map(|tick| self.start + Duration::from_millis(self.tick_ms.saturating_mul(tick)))
483    }
484
485    fn slot_for(&self, tick: Tick) -> usize {
486        (self.mask & tick) as usize
487    }
488}
489
490impl<T> Default for Timer<T> {
491    fn default() -> Timer<T> {
492        Builder::default().build()
493    }
494}
495
496fn duration_to_tick(elapsed: Duration, tick_ms: u64) -> Tick {
497    // Calculate tick rounding up to the closest one
498    let elapsed_ms = convert::millis(elapsed);
499    elapsed_ms.saturating_add(tick_ms / 2) / tick_ms
500}
501
502fn current_tick(start: Instant, tick_ms: u64) -> Tick {
503    duration_to_tick(start.elapsed(), tick_ms)
504}
505
506impl<T> Entry<T> {
507    fn new(state: T, tick: u64, next: Token) -> Entry<T> {
508        Entry {
509            state,
510            links: EntryLinks {
511                tick,
512                prev: EMPTY,
513                next,
514            },
515        }
516    }
517}
518
519#[cfg(test)]
520mod test {
521    use std::time::{Duration, Instant};
522
523    use super::*;
524
525    #[test]
526    pub fn test_timeout_next_tick() {
527        let mut t = timer();
528
529        t.set_timeout_at(Duration::from_millis(100), "a");
530
531        let mut tick = ms_to_tick(&t, 50);
532        assert_eq!(None, t.poll_to(tick));
533
534        tick = ms_to_tick(&t, 100);
535        assert_eq!(Some("a"), t.poll_to(tick));
536        assert_eq!(None, t.poll_to(tick));
537
538        tick = ms_to_tick(&t, 150);
539        assert_eq!(None, t.poll_to(tick));
540
541        tick = ms_to_tick(&t, 200);
542        assert_eq!(None, t.poll_to(tick));
543
544        assert_eq!(count(&t), 0);
545    }
546
547    #[test]
548    pub fn test_clearing_timeout() {
549        let mut t = timer();
550
551        let to = t.set_timeout_at(Duration::from_millis(100), "a");
552        assert_eq!("a", t.cancel_timeout(&to).unwrap());
553
554        let mut tick = ms_to_tick(&t, 100);
555        assert_eq!(None, t.poll_to(tick));
556
557        tick = ms_to_tick(&t, 200);
558        assert_eq!(None, t.poll_to(tick));
559
560        assert_eq!(count(&t), 0);
561    }
562
563    #[test]
564    pub fn test_multiple_timeouts_same_tick() {
565        let mut t = timer();
566
567        t.set_timeout_at(Duration::from_millis(100), "a");
568        t.set_timeout_at(Duration::from_millis(100), "b");
569
570        let mut rcv = vec![];
571
572        let mut tick = ms_to_tick(&t, 100);
573        rcv.push(t.poll_to(tick).unwrap());
574        rcv.push(t.poll_to(tick).unwrap());
575
576        assert_eq!(None, t.poll_to(tick));
577
578        rcv.sort_unstable();
579        assert!(rcv == ["a", "b"], "actual={rcv:?}");
580
581        tick = ms_to_tick(&t, 200);
582        assert_eq!(None, t.poll_to(tick));
583
584        assert_eq!(count(&t), 0);
585    }
586
587    #[test]
588    pub fn test_multiple_timeouts_diff_tick() {
589        let mut t = timer();
590
591        t.set_timeout_at(Duration::from_millis(110), "a");
592        t.set_timeout_at(Duration::from_millis(220), "b");
593        t.set_timeout_at(Duration::from_millis(230), "c");
594        t.set_timeout_at(Duration::from_millis(440), "d");
595        t.set_timeout_at(Duration::from_millis(560), "e");
596
597        let mut tick = ms_to_tick(&t, 100);
598        assert_eq!(Some("a"), t.poll_to(tick));
599        assert_eq!(None, t.poll_to(tick));
600
601        tick = ms_to_tick(&t, 200);
602        assert_eq!(Some("c"), t.poll_to(tick));
603        assert_eq!(Some("b"), t.poll_to(tick));
604        assert_eq!(None, t.poll_to(tick));
605
606        tick = ms_to_tick(&t, 300);
607        assert_eq!(None, t.poll_to(tick));
608
609        tick = ms_to_tick(&t, 400);
610        assert_eq!(Some("d"), t.poll_to(tick));
611        assert_eq!(None, t.poll_to(tick));
612
613        tick = ms_to_tick(&t, 500);
614        assert_eq!(None, t.poll_to(tick));
615
616        tick = ms_to_tick(&t, 600);
617        assert_eq!(Some("e"), t.poll_to(tick));
618        assert_eq!(None, t.poll_to(tick));
619    }
620
621    #[test]
622    pub fn test_catching_up() {
623        let mut t = timer();
624
625        t.set_timeout_at(Duration::from_millis(110), "a");
626        t.set_timeout_at(Duration::from_millis(220), "b");
627        t.set_timeout_at(Duration::from_millis(230), "c");
628        t.set_timeout_at(Duration::from_millis(440), "d");
629
630        let tick = ms_to_tick(&t, 600);
631        assert_eq!(Some("a"), t.poll_to(tick));
632        assert_eq!(Some("c"), t.poll_to(tick));
633        assert_eq!(Some("b"), t.poll_to(tick));
634        assert_eq!(Some("d"), t.poll_to(tick));
635        assert_eq!(None, t.poll_to(tick));
636    }
637
638    #[test]
639    pub fn test_timeout_hash_collision() {
640        let mut t = timer();
641
642        t.set_timeout_at(Duration::from_millis(100), "a");
643        t.set_timeout_at(Duration::from_millis(100 + TICK * SLOTS as u64), "b");
644
645        let mut tick = ms_to_tick(&t, 100);
646        assert_eq!(Some("a"), t.poll_to(tick));
647        assert_eq!(1, count(&t));
648
649        tick = ms_to_tick(&t, 200);
650        assert_eq!(None, t.poll_to(tick));
651        assert_eq!(1, count(&t));
652
653        tick = ms_to_tick(&t, 100 + TICK * SLOTS as u64);
654        assert_eq!(Some("b"), t.poll_to(tick));
655        assert_eq!(0, count(&t));
656    }
657
658    #[test]
659    pub fn test_clearing_timeout_between_triggers() {
660        let mut t = timer();
661
662        let a = t.set_timeout_at(Duration::from_millis(100), "a");
663        let _ = t.set_timeout_at(Duration::from_millis(100), "b");
664        let _ = t.set_timeout_at(Duration::from_millis(200), "c");
665
666        let mut tick = ms_to_tick(&t, 100);
667        assert_eq!(Some("b"), t.poll_to(tick));
668        assert_eq!(2, count(&t));
669
670        t.cancel_timeout(&a);
671        assert_eq!(1, count(&t));
672
673        assert_eq!(None, t.poll_to(tick));
674
675        tick = ms_to_tick(&t, 200);
676        assert_eq!(Some("c"), t.poll_to(tick));
677        assert_eq!(0, count(&t));
678    }
679
    // Tick length, in milliseconds, of the timers built by `timer()`
    const TICK: u64 = 100;
    // Number of wheel slots of the timers built by `timer()`
    const SLOTS: usize = 16;
    // Slab capacity of the timers built by `timer()`
    const CAPACITY: usize = 32;
683
684    fn count<T>(timer: &Timer<T>) -> usize {
685        timer.entries.len()
686    }
687
688    fn timer() -> Timer<&'static str> {
689        Timer::new(TICK, SLOTS, CAPACITY, Instant::now())
690    }
691
692    fn ms_to_tick<T>(timer: &Timer<T>, ms: u64) -> u64 {
693        ms / timer.tick_ms
694    }
695}