s2n_quic_core/time/
token_bucket.rs1use crate::time::{
5 timer::{self, Provider as _},
6 Timer, Timestamp,
7};
8use core::time::Duration;
9
10#[derive(Debug)]
11pub struct TokenBucket {
12 current: u64,
14 refill_amount: u64,
16 refill_interval: Duration,
18 refill_timer: Timer,
20 max: u64,
22}
23
24impl Default for TokenBucket {
25 #[inline]
26 fn default() -> Self {
27 Self::builder().build()
28 }
29}
30
31impl TokenBucket {
32 #[inline]
33 pub fn builder() -> Builder {
34 Builder::default()
35 }
36
37 #[inline]
38 pub fn take(&mut self, amount: u64, now: Timestamp) -> u64 {
39 if amount == 0 {
40 self.on_timeout(now);
41 return 0;
42 }
43
44 if self.current < amount {
46 self.on_timeout(now);
47 }
48
49 let credits = amount.min(self.current);
50 self.current -= credits;
51
52 self.on_timeout(now);
53
54 credits
55 }
56
57 #[inline]
58 pub fn set_refill_interval(&mut self, new_interval: Duration) {
59 if self.refill_interval == new_interval {
61 return;
62 }
63
64 let prev_interval = core::mem::replace(&mut self.refill_interval, new_interval);
66
67 if let Some(target) = self.refill_timer.next_expiration() {
69 if let Some(now) = target.checked_sub(prev_interval) {
70 self.refill_timer.set(now + new_interval);
71 }
72 }
73 }
74
75 #[inline]
76 pub fn on_timeout(&mut self, now: Timestamp) {
77 while self.current < self.max {
78 if let Some(target) = self.refill_timer.next_expiration() {
79 if !target.has_elapsed(now) {
81 break;
82 }
83
84 self.current = self
86 .max
87 .min(self.current.saturating_add(self.refill_amount));
88
89 if self.current == self.max {
91 self.refill_timer.cancel();
92 break;
93 }
94
95 self.refill_timer.set(target + self.refill_interval);
98 } else {
99 self.refill_timer.set(now + self.refill_interval);
101 break;
102 }
103 }
104
105 self.invariants();
106 }
107
108 #[inline]
109 pub fn cancel(&mut self) {
110 self.refill_timer.cancel();
111 }
112
113 #[inline]
114 fn invariants(&self) {
115 if cfg!(debug_assertions) {
116 assert!(self.current <= self.max);
117 assert_eq!(
118 self.refill_timer.is_armed(),
119 self.current < self.max,
120 "timer should be armed ({}) if current ({}) is less than max ({})",
121 self.refill_timer.is_armed(),
122 self.current,
123 self.max,
124 );
125 }
126 }
127}
128
129impl timer::Provider for TokenBucket {
130 #[inline]
131 fn timers<Q: timer::Query>(&self, query: &mut Q) -> timer::Result {
132 self.refill_timer.timers(query)?;
133 Ok(())
134 }
135}
136
/// Configuration for a [`TokenBucket`]; create one with `TokenBucket::builder()`.
#[derive(Clone, Debug)]
pub struct Builder {
    /// Maximum number of tokens the bucket can hold.
    max: u64,
    /// Time between refills.
    refill_interval: Duration,
    /// Tokens added per refill.
    refill_amount: u64,
}

143impl Default for Builder {
144 fn default() -> Self {
145 Self {
146 max: 100,
147 refill_amount: 5,
148 refill_interval: Duration::from_secs(1),
149 }
150 }
151}
152
153impl Builder {
154 #[inline]
155 pub fn with_max(mut self, max: u64) -> Self {
156 self.max = max;
157 self
158 }
159
160 #[inline]
161 pub fn with_refill_amount(mut self, amount: u64) -> Self {
162 self.refill_amount = amount;
163 self
164 }
165
166 #[inline]
167 pub fn with_refill_interval(mut self, interval: Duration) -> Self {
168 self.refill_interval = interval;
169 self
170 }
171
172 #[inline]
173 pub fn build(self) -> TokenBucket {
174 let Self {
175 max,
176 refill_interval,
177 refill_amount,
178 } = self;
179
180 TokenBucket {
181 current: max,
182 max,
183 refill_amount,
184 refill_interval,
185 refill_timer: Default::default(),
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time::{testing::Clock, Clock as _};

    /// Drains a default bucket (max 100, refilled by 5 every second) and checks
    /// that refills accrue once per elapsed interval.
    #[test]
    fn example_test() {
        let mut clock = Clock::default();
        let mut bucket = TokenBucket::default();

        // The bucket starts full; dipping below max arms the refill timer.
        assert_eq!(bucket.take(1, clock.get_time()), 1);
        assert!(bucket.refill_timer.is_armed());

        // Only 99 tokens remain, so a request for 100 is partially satisfied...
        assert_eq!(bucket.take(100, clock.get_time()), 99);
        // ...leaving the bucket empty.
        assert_eq!(bucket.take(1, clock.get_time()), 0);

        // After one interval, a single refill of 5 tokens is available.
        clock.inc_by(Duration::from_secs(1));
        assert_eq!(bucket.take(100, clock.get_time()), 5);
        assert!(bucket.refill_timer.is_armed());

        // Three further intervals accrue three further refills.
        clock.inc_by(Duration::from_secs(3));
        assert_eq!(bucket.take(100, clock.get_time()), 15);
        assert!(bucket.refill_timer.is_armed());
    }
}