s2n_quic_core/task/cooldown.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use core::{
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8};
9use pin_project_lite::pin_project;
10
/// Tracks how many consecutive `Pending` poll results a task may absorb by immediately
/// re-polling before it should go to sleep.
///
/// The `Default` value has a `limit` of zero, so it never allows looping.
#[derive(Debug, Default, Clone)]
pub struct Cooldown {
    /// Credits remaining before the task should sleep; refilled to `limit` on `on_ready`.
    credits: u16,
    /// Number of credits granted by `new` and restored by `on_ready`.
    limit: u16,
}
16
/// Decision produced by a [`Cooldown`] after a `Pending` poll result is reported.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Outcome {
    /// Keep going: the task should loop and poll again.
    Loop,
    /// Back off: the task should return `Pending` and wait for an actual wake notification.
    Sleep,
}
24
25impl Outcome {
26    #[inline]
27    pub fn is_loop(&self) -> bool {
28        matches!(self, Self::Loop)
29    }
30
31    #[inline]
32    pub fn is_sleep(&self) -> bool {
33        matches!(self, Self::Sleep)
34    }
35}
36
37impl Cooldown {
38    #[inline]
39    pub fn new(limit: u16) -> Self {
40        Self {
41            limit,
42            credits: limit,
43        }
44    }
45
46    #[inline]
47    pub fn state(&self) -> Outcome {
48        if self.credits > 0 {
49            Outcome::Loop
50        } else {
51            Outcome::Sleep
52        }
53    }
54
55    /// Notifies the cooldown that the poll operation was ready
56    ///
57    /// This resets the cooldown period until another `Pending` result.
58    #[inline]
59    pub fn on_ready(&mut self) {
60        // reset the pending count
61        self.credits = self.limit;
62    }
63
64    /// Notifies the cooldown that the poll operation was pending
65    ///
66    /// This consumes a cooldown credit until they are exhausted at which point the task should
67    /// sleep.
68    #[inline]
69    pub fn on_pending(&mut self) -> Outcome {
70        if self.credits > 0 {
71            self.credits -= 1;
72            return Outcome::Loop;
73        }
74
75        Outcome::Sleep
76    }
77
78    #[inline]
79    pub fn on_pending_task(&mut self, cx: &mut core::task::Context) -> Outcome {
80        let outcome = self.on_pending();
81
82        if outcome.is_loop() {
83            cx.waker().wake_by_ref();
84        }
85
86        outcome
87    }
88
89    #[inline]
90    pub async fn wrap<F>(&mut self, fut: F) -> F::Output
91    where
92        F: Future + Unpin,
93    {
94        Wrapped {
95            fut,
96            cooldown: self,
97        }
98        .await
99    }
100}
101
102pin_project!(
103    struct Wrapped<'a, F>
104    where
105        F: core::future::Future,
106    {
107        #[pin]
108        fut: F,
109        cooldown: &'a mut Cooldown,
110    }
111);
112
113impl<F> Future for Wrapped<'_, F>
114where
115    F: Future,
116{
117    type Output = F::Output;
118
119    #[inline]
120    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
121        let this = self.project();
122        match this.fut.poll(cx) {
123            Poll::Ready(v) => {
124                this.cooldown.on_ready();
125                Poll::Ready(v)
126            }
127            Poll::Pending => {
128                this.cooldown.on_pending_task(cx);
129                Poll::Pending
130            }
131        }
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports `loops + sleeps` pending results to `cooldown`. The first `loops` must yield
    /// `Outcome::Loop` and the remaining `sleeps` must yield `Outcome::Sleep`.
    fn assert_pending_sequence(cooldown: &mut Cooldown, loops: usize, sleeps: usize) {
        let expected = core::iter::repeat(Outcome::Loop)
            .take(loops)
            .chain(core::iter::repeat(Outcome::Sleep).take(sleeps));
        for want in expected {
            assert_eq!(cooldown.on_pending(), want);
        }
    }

    #[test]
    fn cooldown_test() {
        let mut cooldown = Cooldown::new(2);

        // two credits are spent looping, then every further pending result sleeps
        assert_pending_sequence(&mut cooldown, 2, 2);

        // on_ready restores the credits after they were exhausted
        cooldown.on_ready();
        assert_pending_sequence(&mut cooldown, 2, 2);

        // on_ready also restores the credits while we're still looping
        cooldown.on_ready();
        assert_pending_sequence(&mut cooldown, 1, 0);
        cooldown.on_ready();
        assert_pending_sequence(&mut cooldown, 2, 2);
    }

    #[test]
    fn disabled_test() {
        // with cooldown disabled, it should always return sleep
        let mut cooldown = Cooldown::new(0);
        assert_pending_sequence(&mut cooldown, 0, 1);

        cooldown.on_ready();
        assert_pending_sequence(&mut cooldown, 0, 1);
    }
}
178}