s2n_quic_core/time/
timer.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::time::timestamp::Timestamp;
5use core::task::Poll;
6
7/// A timer that does not trigger an update in a timer
8/// list. These are usually owned by individual components
9/// and needs to be explicitly polled.
10///
11/// Note: The timer doesn't implement Copy to ensure it isn't accidentally moved
12///       and have the expiration discarded.
13#[derive(Clone, Debug, Default, PartialEq, Eq)]
14pub struct Timer {
15    expiration: Option<Timestamp>,
16}
17
18impl Timer {
19    /// Sets the timer to expire at the given timestamp
20    #[inline]
21    pub fn set(&mut self, time: Timestamp) {
22        self.expiration = Some(time);
23    }
24
25    /// Cancels the timer.
26    /// After cancellation, a timer will no longer report as expired.
27    #[inline]
28    pub fn cancel(&mut self) {
29        self.expiration = None;
30    }
31
32    /// Returns true if the timer has expired
33    #[inline]
34    pub fn is_expired(&self, current_time: Timestamp) -> bool {
35        match self.expiration {
36            Some(timeout) => timeout.has_elapsed(current_time),
37            _ => false,
38        }
39    }
40
41    /// Returns true if the timer is armed
42    #[inline]
43    pub fn is_armed(&self) -> bool {
44        self.expiration.is_some()
45    }
46
47    /// Notifies the timer of the current time.
48    /// If the timer's expiration occurs before the current time, it will be cancelled.
49    /// The method returns whether the timer was expired and had been
50    /// cancelled.
51    #[inline]
52    pub fn poll_expiration(&mut self, current_time: Timestamp) -> Poll<()> {
53        if self.is_expired(current_time) {
54            self.cancel();
55            Poll::Ready(())
56        } else {
57            Poll::Pending
58        }
59    }
60}
61
62impl From<Option<Timestamp>> for Timer {
63    #[inline]
64    fn from(expiration: Option<Timestamp>) -> Self {
65        Self { expiration }
66    }
67}
68
/// Signal a [`Query`] returns to stop a timer traversal early.
#[derive(Clone, Copy, Debug, Default)]
pub struct QueryBreak;

/// Result type used by [`Provider::timers`] and [`Query::on_timer`].
///
/// `Err(QueryBreak)` means the traversal was stopped early on purpose, not that
/// an error occurred.
pub type Result<T = (), E = QueryBreak> = core::result::Result<T, E>;
75
76/// A trait for a components that owns at least one timer
77pub trait Provider {
78    /// Notifies the query of any timers owned by the provider
79    ///
80    /// The provider should also delegate to subcomponents that own timers as well.
81    fn timers<Q: Query>(&self, query: &mut Q) -> Result;
82
83    /// Returns the next `Timestamp` at which the earliest timer is armed in the provider
84    #[inline]
85    fn next_expiration(&self) -> Option<Timestamp> {
86        let mut timeout: Option<Timestamp> = None;
87        let _ = self.timers(&mut timeout);
88        timeout
89    }
90
91    /// Returns `true` if there are any timers armed
92    #[inline]
93    fn is_armed(&self) -> bool {
94        let mut is_armed = IsArmed::default();
95        let _ = self.timers(&mut is_armed);
96        is_armed.0
97    }
98
99    /// Counts the number of armed timers in the provider
100    #[inline]
101    fn armed_timer_count(&self) -> usize {
102        let mut count = ArmedCount::default();
103        let _ = self.timers(&mut count);
104        count.0
105    }
106
107    /// Iterates over each timer in the provider and calls the provided function
108    #[inline]
109    fn for_each_timer<F: FnMut(&Timer) -> Result>(&self, f: F) {
110        let mut for_each = ForEach(f);
111        let _ = self.timers(&mut for_each);
112    }
113}
114
115impl Provider for Timer {
116    #[inline]
117    fn timers<Q: Query>(&self, query: &mut Q) -> Result {
118        query.on_timer(self)
119    }
120}
121
122impl<T: Provider> Provider for &T {
123    #[inline]
124    fn timers<Q: Query>(&self, query: &mut Q) -> Result {
125        (**self).timers(query)
126    }
127}
128
129impl<T: Provider> Provider for &mut T {
130    #[inline]
131    fn timers<Q: Query>(&self, query: &mut Q) -> Result {
132        (**self).timers(query)
133    }
134}
135
136/// Implement Provider for a 2-element tuple to make it easy to do joins
137impl<A: Provider, B: Provider> Provider for (A, B) {
138    #[inline]
139    fn timers<Q: Query>(&self, query: &mut Q) -> Result {
140        self.0.timers(query)?;
141        self.1.timers(query)?;
142        Ok(())
143    }
144}
145
146impl<T: Provider> Provider for Option<T> {
147    #[inline]
148    fn timers<Q: Query>(&self, query: &mut Q) -> Result {
149        if let Some(t) = self.as_ref() {
150            t.timers(query)?;
151        }
152        Ok(())
153    }
154}
155
156/// A query to be executed against a provider
157pub trait Query {
158    /// Called for each timer owned by the provider
159    fn on_timer(&mut self, timer: &Timer) -> Result;
160}
161
162/// Implement Query for `Option<Timestamp>` to make it easy to get the earliest armed timestamp
163impl Query for Option<Timestamp> {
164    #[inline]
165    fn on_timer(&mut self, timer: &Timer) -> Result {
166        match (self, timer.expiration) {
167            // Take the minimum of the two timers
168            (Some(a), Some(b)) => *a = (*a).min(b),
169            // We don't have a time yet so just assign the expiration of the other
170            (a @ None, b) => *a = b,
171            // do nothing for everything else
172            _ => {}
173        }
174        Ok(())
175    }
176}
177
178/// Counts all of the armed timers
179#[derive(Debug, Default)]
180pub struct ArmedCount(pub usize);
181
182impl Query for ArmedCount {
183    #[inline]
184    fn on_timer(&mut self, timer: &Timer) -> Result {
185        if timer.is_armed() {
186            self.0 += 1;
187        }
188        Ok(())
189    }
190}
191
192/// Returns `true` if any of the timers are armed
193#[derive(Debug, Default)]
194pub struct IsArmed(pub bool);
195
196impl Query for IsArmed {
197    #[inline]
198    fn on_timer(&mut self, timer: &Timer) -> Result {
199        if timer.is_armed() {
200            self.0 = true;
201            return Err(QueryBreak);
202        }
203        Ok(())
204    }
205}
206
207/// Iterates over each timer in the provider and calls a function
208#[derive(Debug, Default)]
209pub struct ForEach<F: FnMut(&Timer) -> Result>(F);
210
211impl<F: FnMut(&Timer) -> Result> Query for ForEach<F> {
212    #[inline]
213    fn on_timer(&mut self, timer: &Timer) -> Result {
214        (self.0)(timer)
215    }
216}
217
/// Query that emits a `tracing` trace event for every visited timer.
///
/// `on_timer` is `#[track_caller]`, so the logged `location` is the site that
/// called `on_timer`. The location propagates further up through any callers
/// that are themselves `#[track_caller]`.
#[cfg(feature = "tracing")]
pub struct Debugger;

#[cfg(feature = "tracing")]
impl Query for Debugger {
    #[inline]
    #[track_caller]
    fn on_timer(&mut self, timer: &Timer) -> Result {
        let location = core::panic::Location::caller();
        tracing::trace!(location = %location, timer = ?timer);
        // Debugging should observe every timer, so never break.
        Ok(())
    }
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time::clock::{Clock, NoopClock};
    use core::time::Duration;

    #[test]
    fn is_armed_test() {
        let now = NoopClock.get_time();
        let mut timer = Timer::default();

        assert!(!timer.is_armed());

        timer.set(now);
        assert!(timer.is_armed());

        timer.cancel();
        assert!(!timer.is_armed());
    }

    #[test]
    fn is_expired_test() {
        let mut now = NoopClock.get_time();
        let mut timer = Timer::default();

        assert!(!timer.is_expired(now));

        timer.set(now + Duration::from_millis(100));

        now += Duration::from_millis(99);
        assert!(!timer.is_expired(now));

        assert!(
            timer.is_expired(now + Duration::from_micros(1)),
            "if a timer is less than 1ms in the future it should expire"
        );

        now += Duration::from_millis(1);
        assert!(timer.is_expired(now));

        timer.cancel();
        assert!(!timer.is_expired(now));
    }

    #[test]
    fn poll_expiration_test() {
        let mut now = NoopClock.get_time();
        let mut timer = Timer::default();

        timer.set(now + Duration::from_millis(100));

        assert!(!timer.poll_expiration(now).is_ready());
        assert!(timer.is_armed());

        now += Duration::from_millis(100);

        assert!(timer.poll_expiration(now).is_ready());
        assert!(!timer.is_armed());
    }

    #[test]
    fn from_option_test() {
        let now = NoopClock.get_time();

        assert!(Timer::from(Some(now)).is_armed());
        assert!(!Timer::from(None::<Timestamp>).is_armed());
    }

    #[test]
    fn provider_queries_test() {
        let now = NoopClock.get_time();
        let mut first = Timer::default();
        let mut second = Timer::default();

        // Nothing is armed yet.
        assert_eq!((&first, &second).next_expiration(), None);
        assert!(!(&first, &second).is_armed());
        assert_eq!((&first, &second).armed_timer_count(), 0);

        first.set(now + Duration::from_millis(10));
        second.set(now + Duration::from_millis(5));

        // The earliest expiration wins regardless of position in the tuple.
        assert_eq!(
            (&first, &second).next_expiration(),
            Some(now + Duration::from_millis(5))
        );
        assert!((&first, &second).is_armed());
        assert_eq!((&first, &second).armed_timer_count(), 2);

        second.cancel();
        assert_eq!(
            (&first, &second).next_expiration(),
            Some(now + Duration::from_millis(10))
        );
        assert_eq!((&first, &second).armed_timer_count(), 1);

        // Option providers forward to the inner provider, or report nothing.
        assert_eq!(Some(&first).armed_timer_count(), 1);
        assert_eq!(None::<Timer>.next_expiration(), None);
    }

    #[test]
    fn for_each_timer_test() {
        let first = Timer::default();
        let second = Timer::default();

        let mut visited = 0;
        (&first, &second).for_each_timer(|_timer| {
            visited += 1;
            Ok(())
        });
        assert_eq!(visited, 2);

        let mut visited = 0;
        (&first, &second).for_each_timer(|_timer| {
            visited += 1;
            Err(QueryBreak)
        });
        assert_eq!(visited, 1, "returning QueryBreak should stop the traversal");
    }
}