s2n_quic_core/time/token_bucket.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

4use crate::time::{
5    timer::{self, Provider as _},
6    Timer, Timestamp,
7};
8use core::time::Duration;
9
10#[derive(Debug)]
11pub struct TokenBucket {
12    /// The current number of tokens
13    current: u64,
14    /// The number of tokens to refill per interval
15    refill_amount: u64,
16    /// The rate to refill
17    refill_interval: Duration,
18    /// The current pending refill
19    refill_timer: Timer,
20    /// The maximum number of tokens for the bucket
21    max: u64,
22}
23
24impl Default for TokenBucket {
25    #[inline]
26    fn default() -> Self {
27        Self::builder().build()
28    }
29}
30
31impl TokenBucket {
32    #[inline]
33    pub fn builder() -> Builder {
34        Builder::default()
35    }
36
37    #[inline]
38    pub fn take(&mut self, amount: u64, now: Timestamp) -> u64 {
39        if amount == 0 {
40            self.on_timeout(now);
41            return 0;
42        }
43
44        // try to refill the bucket if we couldn't take the whole thing
45        if self.current < amount {
46            self.on_timeout(now);
47        }
48
49        let credits = amount.min(self.current);
50        self.current -= credits;
51
52        self.on_timeout(now);
53
54        credits
55    }
56
57    #[inline]
58    pub fn set_refill_interval(&mut self, new_interval: Duration) {
59        // if the value didn't change, then no need to update
60        if self.refill_interval == new_interval {
61            return;
62        }
63
64        // replace the previous with the new one
65        let prev_interval = core::mem::replace(&mut self.refill_interval, new_interval);
66
67        // recalibrate the refill timer with the new interval
68        if let Some(target) = self.refill_timer.next_expiration() {
69            if let Some(now) = target.checked_sub(prev_interval) {
70                self.refill_timer.set(now + new_interval);
71            }
72        }
73    }
74
75    #[inline]
76    pub fn on_timeout(&mut self, now: Timestamp) {
77        while self.current < self.max {
78            if let Some(target) = self.refill_timer.next_expiration() {
79                // the target hasn't expired yet
80                if !target.has_elapsed(now) {
81                    break;
82                }
83
84                // increase the allowed amount of credits
85                self.current = self
86                    .max
87                    .min(self.current.saturating_add(self.refill_amount));
88
89                // no need to keep looping if we're at the max
90                if self.current == self.max {
91                    self.refill_timer.cancel();
92                    break;
93                }
94
95                // reset the timer to the refill interval and loop back around to see if we can
96                // issue more, just in case we were late to query the timer
97                self.refill_timer.set(target + self.refill_interval);
98            } else {
99                // we haven't set a timer yet so set it now
100                self.refill_timer.set(now + self.refill_interval);
101                break;
102            }
103        }
104
105        self.invariants();
106    }
107
108    #[inline]
109    pub fn cancel(&mut self) {
110        self.refill_timer.cancel();
111    }
112
113    #[inline]
114    fn invariants(&self) {
115        if cfg!(debug_assertions) {
116            assert!(self.current <= self.max);
117            assert_eq!(
118                self.refill_timer.is_armed(),
119                self.current < self.max,
120                "timer should be armed ({}) if current ({}) is less than max ({})",
121                self.refill_timer.is_armed(),
122                self.current,
123                self.max,
124            );
125        }
126    }
127}
128
129impl timer::Provider for TokenBucket {
130    #[inline]
131    fn timers<Q: timer::Query>(&self, query: &mut Q) -> timer::Result {
132        self.refill_timer.timers(query)?;
133        Ok(())
134    }
135}
136
/// Configuration for a [`TokenBucket`].
///
/// Obtain one with `TokenBucket::builder()` or `Builder::default()`, adjust it with the
/// `with_*` methods, then call `build`.
#[derive(Clone, Debug)]
#[must_use = "a Builder does nothing until `build` is called"]
pub struct Builder {
    /// Capacity of the bucket; a freshly built bucket starts with this many tokens
    max: u64,
    /// Time between refills
    refill_interval: Duration,
    /// Tokens added back per refill
    refill_amount: u64,
}
142
143impl Default for Builder {
144    fn default() -> Self {
145        Self {
146            max: 100,
147            refill_amount: 5,
148            refill_interval: Duration::from_secs(1),
149        }
150    }
151}
152
153impl Builder {
154    #[inline]
155    pub fn with_max(mut self, max: u64) -> Self {
156        self.max = max;
157        self
158    }
159
160    #[inline]
161    pub fn with_refill_amount(mut self, amount: u64) -> Self {
162        self.refill_amount = amount;
163        self
164    }
165
166    #[inline]
167    pub fn with_refill_interval(mut self, interval: Duration) -> Self {
168        self.refill_interval = interval;
169        self
170    }
171
172    #[inline]
173    pub fn build(self) -> TokenBucket {
174        let Self {
175            max,
176            refill_interval,
177            refill_amount,
178        } = self;
179
180        TokenBucket {
181            current: max,
182            max,
183            refill_amount,
184            refill_interval,
185            refill_timer: Default::default(),
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time::{testing::Clock, Clock as _};

    /// A default bucket (100 tokens, +5 per second) first hands out everything it holds.
    /// After that it only hands out what has been refilled, including several missed
    /// refills caught up in a single call.
    #[test]
    fn example_test() {
        // Each step is (seconds to advance the clock first, tokens requested,
        // tokens expected, whether to assert the refill timer is armed afterwards).
        let steps: [(u64, u64, u64, bool); 5] = [
            (0, 1, 1, true),
            (0, 100, 99, false),
            (0, 1, 0, false),
            (1, 100, 5, true),
            (3, 100, 15, true),
        ];

        let mut bucket = TokenBucket::default();
        let mut clock = Clock::default();

        for &(advance, request, expected, check_armed) in steps.iter() {
            if advance > 0 {
                clock.inc_by(Duration::from_secs(advance));
            }

            assert_eq!(bucket.take(request, clock.get_time()), expected);

            if check_armed {
                assert!(bucket.refill_timer.is_armed());
            }
        }
    }
}