tor_basic_utils/token_bucket.rs
1//! A token bucket implementation.
2
3use std::fmt::Debug;
4use web_time_compat::{Duration, Instant};
5
/// A token bucket rate limiter.
///
/// All time arithmetic is done at microsecond granularity. Before inspecting the bucket or
/// taking tokens from it, callers will usually want to [`refill()`](Self::refill) it first.
///
/// The design takes some inspiration from tor's `token_bucket_ctr_t`, but differs from it
/// substantially: values are wider (for example `u64`), and refills are accounted for so that
/// they do not drift when they happen at times that don't line up with the refill period.
///
/// Memory use and computation could likely be reduced by loosening these guarantees, but doing
/// so would make the code harder to reason about (and possibly more complex), so it should only
/// be done if it turns out to be necessary.
#[derive(Debug)]
pub struct TokenBucket<I> {
    /// Refill rate, in tokens per second.
    rate: u64,
    /// Capacity of the bucket; how many tokens it can hold at most.
    /// This is usually called the "burst".
    bucket_max: u64,
    /// Number of tokens the bucket currently holds.
    // A future version might allow this to go negative, for instance to permit slightly
    // exceeding the limit when doing so lets us send a complete TLS record.
    bucket: u64,
    /// The time at which the most recent token was added to the bucket.
    ///
    /// This is not quite "the time of the last refill". For example, with one token added every
    /// 100 ms (starting from time 0), a refill performed at 510 ms adds 5 tokens and records
    /// 500 ms here, so the partial 10 ms is carried over to the next refill.
    added_tokens_at: I,
}
41
42impl<I: TokenBucketInstant> TokenBucket<I> {
43 /// A new [`TokenBucket`] with a given `rate` in tokens/second and a `max` token limit.
44 ///
45 /// The bucket will initially be full.
46 /// The value `max` is commonly referred to as the "burst".
47 pub fn new(config: &TokenBucketConfig, now: I) -> Self {
48 Self {
49 rate: config.rate,
50 bucket_max: config.bucket_max,
51 bucket: config.bucket_max,
52 added_tokens_at: now,
53 }
54 }
55
56 /// Are there no tokens in the bucket?
57 pub fn is_empty(&self) -> bool {
58 self.bucket == 0
59 }
60
61 /// The maximum number of tokens that this bucket can hold.
62 pub fn max(&self) -> u64 {
63 self.bucket_max
64 }
65
66 /// Remove `count` tokens from the bucket.
67 pub fn drain(&mut self, count: u64) -> Result<BecameEmpty, InsufficientTokensError> {
68 Ok(self.claim(count)?.commit())
69 }
70
71 /// Claim a number of tokens.
72 ///
73 /// The claim will be held by the returned [`ClaimedTokens`], and committed when dropped.
74 ///
75 /// **Note:** You probably want to call [`refill()`](Self::refill) before this.
76 // Since the `ClaimedTokens` holds a `&mut` to this `TokenBucket`, we don't need to worry about
77 // other calls accessing the `TokenBucket` before the `ClaimedTokens` are committed.
78 pub fn claim(&mut self, count: u64) -> Result<ClaimedTokens<I>, InsufficientTokensError> {
79 if count > self.bucket {
80 return Err(InsufficientTokensError {
81 available: self.bucket,
82 });
83 }
84
85 Ok(ClaimedTokens::new(self, count))
86 }
87
88 /// Adjust the refill rate and max tokens of the bucket.
89 ///
90 /// The token bucket is refilled up to `now` before changing the rate.
91 ///
92 /// If the new max is smaller than the existing number of tokens,
93 /// the number of tokens will be reduced to the new max.
94 ///
95 /// A rate and/or max of 0 is allowed.
96 pub fn adjust(&mut self, now: I, config: &TokenBucketConfig) {
97 // make sure that the bucket gets the tokens it is owed before we change the rate
98 self.refill(now);
99
100 // If the old rate was small (or 0), the `refill()` might not have updated
101 // `added_tokens_at`.
102 //
103 // For example if the bucket has a rate of 0 and was last refilled 10 seconds ago, it will
104 // not have gained any tokens in the last 10 seconds. If we were to only update the rate to
105 // 100 tokens/second now, the bucket would immediately become eligible to refill 1000
106 // tokens. We only want the rate change to become effective now, not in the past, so we
107 // ensure this by resetting `added_tokens_at`.
108 self.added_tokens_at = std::cmp::max(self.added_tokens_at, now);
109
110 self.rate = config.rate;
111 self.bucket_max = config.bucket_max;
112 self.bucket = std::cmp::min(self.bucket, self.bucket_max);
113 }
114
115 /// An estimated time at which the bucket will have `tokens` available.
116 ///
117 /// It is not guaranteed that `tokens` will be available at the returned time.
118 ///
119 /// If there are already enough tokens available, a time in the past may be returned.
120 ///
121 /// A value of `None` implies "never",
122 /// for example if the refill rate is 0,
123 /// the bucket max is too small,
124 /// or the time is too large to be represented as an `I`.
125 pub fn tokens_available_at(&self, tokens: u64) -> Result<I, NeverEnoughTokensError> {
126 let tokens_needed = tokens.saturating_sub(self.bucket);
127
128 // check if we currently have enough tokens before considering refilling
129 if tokens_needed == 0 {
130 return Ok(self.added_tokens_at);
131 }
132
133 // if the rate is 0, we'll never get more tokens
134 if self.rate == 0 {
135 return Err(NeverEnoughTokensError::ZeroRate);
136 }
137
138 // if more tokens are wanted than the capacity of the bucket, we'll never get enough
139 if tokens > self.bucket_max {
140 return Err(NeverEnoughTokensError::ExceedsMaxTokens);
141 }
142
143 // this may underestimate the time if either argument is very large
144 let time_needed = Self::tokens_to_duration(tokens_needed, self.rate)
145 .ok_or(NeverEnoughTokensError::ZeroRate)?;
146
147 // Always return at least 1 microsecond since:
148 // 1. We don't want to return `Duration::ZERO` if the tokens aren't ready,
149 // which may occur if the rate is very large (<1 ns/token).
150 // 2. Clocks generally don't operate at <1 us resolution.
151 let time_needed = std::cmp::max(time_needed, Duration::from_micros(1));
152
153 self.added_tokens_at
154 .checked_add(time_needed)
155 .ok_or(NeverEnoughTokensError::InstantNotRepresentable)
156 }
157
158 /// Refill the bucket.
159 pub fn refill(&mut self, now: I) -> BecameNonEmpty {
160 // time since we last added tokens
161 let elapsed = now.saturating_duration_since(self.added_tokens_at);
162
163 // If we exceeded the threshold, update the timestamp and return.
164 // This is taken from tor, which has the comment below:
165 //
166 // > Skip over updates that include an overflow or a very large jump. This can happen for
167 // > platform specific reasons, such as the old ~48 day windows timer.
168 //
169 // It's unclear if this type of OS bug is still common enough that this check is useful,
170 // but it shouldn't hurt.
171 if elapsed > I::IGNORE_THRESHOLD {
172 tracing::debug!(
173 "Time jump of {elapsed:?} is larger than {:?}; not refilling token bucket",
174 I::IGNORE_THRESHOLD,
175 );
176 self.added_tokens_at = now;
177 return BecameNonEmpty::No;
178 }
179
180 let old_bucket = self.bucket;
181
182 // Compute how much we should increment the bucket by.
183 // This may be underestimated in some cases.
184 let bucket_inc = Self::duration_to_tokens(elapsed, self.rate);
185
186 self.bucket = std::cmp::min(self.bucket_max, self.bucket.saturating_add(bucket_inc));
187
188 // Compute how much we should increment the `last_added_tokens` time by. This avoids
189 // drifting if the `bucket_inc` was underestimated, and avoids rounding errors which could
190 // cause the token bucket to effectively use a lower rate. For example if the rate was
191 // "1 token / sec" and the elapsed time was "1.2 sec", we only want to refill 1 token and
192 // increment the time by 1 second.
193 //
194 // While the docs for `tokens_to_duration` say that a smaller than expected duration may be
195 // returned, we have a test `test_duration_token_round_trip` which ensures that
196 // `tokens_to_duration` returns the expected value when used with the result from
197 // `duration_to_tokens`.
198 let added_tokens_at_inc =
199 Self::tokens_to_duration(bucket_inc, self.rate).unwrap_or(Duration::ZERO);
200
201 self.added_tokens_at = self
202 .added_tokens_at
203 .checked_add(added_tokens_at_inc)
204 .expect("overflowed time");
205 debug_assert!(self.added_tokens_at <= now);
206
207 if old_bucket == 0 && self.bucket != 0 {
208 BecameNonEmpty::Yes
209 } else {
210 BecameNonEmpty::No
211 }
212 }
213
214 /// How long would it take to refill `tokens` at `rate`?
215 ///
216 /// The result is rounded up to the nearest microsecond.
217 /// If the number of `tokens` is large,
218 /// the result may be much lower than the expected duration due to saturating 64-bit arithmetic.
219 ///
220 /// `None` will be returned if the `rate` is 0.
221 fn tokens_to_duration(tokens: u64, rate: u64) -> Option<Duration> {
222 // Perform the calculation in microseconds rather than nanoseconds since timers typically
223 // have microsecond granularity, and it lowers the chance that the calculation overflows the
224 // `u64::MAX` limit compared to nanoseconds. In the case that the calculation saturates, the
225 // returned duration will be shorter than the real value.
226 //
227 // For example with `tokens = u64::MAX` and `rate = u64::MAX` we'd expect a result of 1
228 // second, but:
229 // u64::MAX.saturating_mul(1000 * 1000).div_ceil(u64::MAX) = 1 microsecond
230 //
231 // The `div_ceil` ensures we always round up to the nearest microsecond.
232 //
233 // dimensional analysis:
234 // (tokens) * (microseconds / second) / (tokens / second) = microseconds
235 if rate == 0 {
236 return None;
237 }
238 let micros = tokens.saturating_mul(1000 * 1000).div_ceil(rate);
239 Some(Duration::from_micros(micros))
240 }
241
242 /// How many tokens would be refilled within `time` at `rate`?
243 ///
244 /// The `time` is truncated to microsecond granularity.
245 /// If the `time` or `rate` is large,
246 /// the result may be much lower than the expected number of tokens due to saturating 64-bit
247 /// arithmetic.
248 fn duration_to_tokens(time: Duration, rate: u64) -> u64 {
249 let micros = u64::try_from(time.as_micros()).unwrap_or(u64::MAX);
250 // dimensional analysis:
251 // (tokens / second) * (microseconds) / (microseconds / second) = tokens
252 rate.saturating_mul(micros) / (1000 * 1000)
253 }
254}
255
/// Settings for a [`TokenBucket`]: how quickly it refills and how many tokens it can hold.
#[derive(Clone, Debug)]
#[allow(clippy::exhaustive_structs)] // constructed directly by callers configuring the bucket
pub struct TokenBucketConfig {
    /// How many tokens are added to the bucket per second.
    pub rate: u64,
    /// The upper limit on the number of tokens the bucket can hold
    /// (often called the "burst").
    pub bucket_max: u64,
}
266
267/// A handle to a number of claimed tokens.
268///
269/// Dropping this handle will commit the claim.
270#[derive(Debug)]
271pub struct ClaimedTokens<'a, I> {
272 /// The bucket that the claim is for.
273 bucket: &'a mut TokenBucket<I>,
274 /// How many tokens to remove from the bucket.
275 count: u64,
276}
277
278impl<'a, I> ClaimedTokens<'a, I> {
279 /// Create a new [`ClaimedTokens`] that will remove `count` tokens from the token `bucket` when
280 /// dropped.
281 fn new(bucket: &'a mut TokenBucket<I>, count: u64) -> Self {
282 Self { bucket, count }
283 }
284
285 /// Commit the claimed tokens.
286 ///
287 /// This is equivalent to just dropping the [`ClaimedTokens`], but also returns whether the
288 /// token bucket became empty or not.
289 pub fn commit(mut self) -> BecameEmpty {
290 self.commit_impl()
291 }
292
293 /// Reduce the claim to a fewer number of tokens than the original claim.
294 ///
295 /// If `count` is larger than the original claim, an error will be returned containing the
296 /// current number of claimed tokens.
297 pub fn reduce(&mut self, count: u64) -> Result<(), InsufficientTokensError> {
298 if count > self.count {
299 return Err(InsufficientTokensError {
300 available: self.count,
301 });
302 }
303
304 self.count = count;
305 Ok(())
306 }
307
308 /// Discard the claim.
309 ///
310 /// This does not remove any tokens from the token bucket.
311 pub fn discard(mut self) {
312 self.count = 0;
313 }
314
315 /// The commit implementation.
316 ///
317 /// After calling [`commit_impl()`](Self::commit_impl),
318 /// the [`ClaimedTokens`] should no longer be used and should be dropped immediately.
319 fn commit_impl(&mut self) -> BecameEmpty {
320 // when the `ClaimedTokens` was created by the `TokenBucket`, it should have ensured that
321 // there were enough tokens
322 self.bucket.bucket = self
323 .bucket
324 .bucket
325 .checked_sub(self.count)
326 .unwrap_or_else(|| {
327 panic!(
328 "claim commit failed: {}, {}",
329 self.count, self.bucket.bucket,
330 )
331 });
332
333 // when `self` is dropped some time after this function ends,
334 // we don't want to subtract again
335 self.count = 0;
336
337 if self.bucket.bucket > 0 {
338 BecameEmpty::No
339 } else {
340 BecameEmpty::Yes
341 }
342 }
343}
344
345impl<'a, I> std::ops::Drop for ClaimedTokens<'a, I> {
346 fn drop(&mut self) {
347 self.commit_impl();
348 }
349}
350
351/// An operation was attempted to reduce the number of tokens,
352/// but the token bucket did not have enough tokens.
353#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
354#[error("insufficient tokens for operation")]
355pub struct InsufficientTokensError {
356 /// The number of tokens that are available to drain/commit.
357 available: u64,
358}
359
360impl InsufficientTokensError {
361 /// Get the number of tokens that are available to drain/commit.
362 pub fn available_tokens(&self) -> u64 {
363 self.available
364 }
365}
366
367/// The token bucket will never have the requested number of tokens.
368#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
369#[allow(clippy::exhaustive_enums)] // callers exhaustively match on these variants
370#[error("there will never be enough tokens for this operation")]
371pub enum NeverEnoughTokensError {
372 /// The request exceeds the bucket's maximum number of tokens.
373 ExceedsMaxTokens,
374 /// The refill rate is 0.
375 ZeroRate,
376 /// The time is not representable.
377 ///
378 /// For example the if the rate is low and a large number of tokens were requested, it may be
379 /// too far in the future that it cannot be represented as a time value.
380 InstantNotRepresentable,
381}
382
/// Whether the token bucket transitioned from "empty" to "non-empty".
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[allow(clippy::exhaustive_enums)] // a simple yes/no status that callers match on
pub enum BecameNonEmpty {
    /// The token bucket was empty and now holds at least one token.
    Yes,
    /// No such transition happened: the token bucket either remains empty,
    /// or was already non-empty.
    No,
}
392
/// Whether the token bucket is empty after tokens were removed from it.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[allow(clippy::exhaustive_enums)] // a simple yes/no status that callers match on
pub enum BecameEmpty {
    /// The token bucket is now empty.
    ///
    /// This is also reported if the bucket was already empty and zero tokens were removed.
    Yes,
    /// The token bucket remains non-empty.
    No,
}
402
/// A point in time usable by a [`TokenBucket`].
///
/// Implementors must represent readings of a monotonically nondecreasing clock.
pub trait TokenBucketInstant: Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord {
    /// A time jump so large that it is considered unrealistic.
    ///
    /// Jumps larger than this are taken as a sign of a broken monotonic clock, and the bucket
    /// is not refilled when one is observed.
    const IGNORE_THRESHOLD: Duration;

    /// Equivalent to [`Instant::checked_add`].
    fn checked_add(&self, duration: Duration) -> Option<Self>;

    /// Equivalent to [`Instant::checked_duration_since`].
    fn checked_duration_since(&self, earlier: Self) -> Option<Duration>;

    /// Equivalent to [`Instant::saturating_duration_since`].
    ///
    /// The default implementation returns a zero duration whenever
    /// [`checked_duration_since`](Self::checked_duration_since) returns `None`.
    fn saturating_duration_since(&self, earlier: Self) -> Duration {
        match self.checked_duration_since(earlier) {
            Some(elapsed) => elapsed,
            None => Duration::ZERO,
        }
    }
}
423
424impl TokenBucketInstant for Instant {
425 // This value is taken from tor (see `elapsed_ticks <= UINT32_MAX/4` in
426 // `src/lib/evloop/token_bucket.c`).
427 const IGNORE_THRESHOLD: Duration = Duration::from_secs((u32::MAX / 4) as u64);
428
429 #[inline]
430 fn checked_add(&self, duration: Duration) -> Option<Self> {
431 self.checked_add(duration)
432 }
433
434 #[inline]
435 fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
436 self.checked_duration_since(earlier)
437 }
438
439 #[inline]
440 fn saturating_duration_since(&self, earlier: Self) -> Duration {
441 self.saturating_duration_since(earlier)
442 }
443}
444
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use rand::RngExt;

    /// A fake clock reading with millisecond resolution, for deterministic tests.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct MillisTimestamp(u64);

    impl TokenBucketInstant for MillisTimestamp {
        const IGNORE_THRESHOLD: Duration = Duration::from_millis(1_000_000_000);

        fn checked_add(&self, duration: Duration) -> Option<Self> {
            let duration = u64::try_from(duration.as_millis()).ok()?;
            self.0.checked_add(duration).map(Self)
        }

        fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
            Some(Duration::from_millis(self.0.checked_sub(earlier.0)?))
        }
    }

    #[test]
    fn adjust_now() {
        let time = MillisTimestamp(100);

        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, time);
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 40,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 40);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 200,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 200);
    }

    #[test]
    fn adjust_future() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);

        // at 300 ms: increase rate and max; bucket was already full, so doesn't gain any tokens
        tb.adjust(
            MillisTimestamp(300),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);

        // at 500 ms: no changes; bucket is refilled during `adjust()`, so gains 4 tokens
        tb.adjust(
            MillisTimestamp(500),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 104);
        assert_eq!(tb.bucket_max, 200);

        // at 700 ms: lower rate and max; bucket is lowered to new max, so loses 4 tokens
        tb.adjust(
            MillisTimestamp(700),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);

        // at 900 ms: raise rate and max; rate was previously 0 so doesn't gain any tokens
        tb.adjust(
            MillisTimestamp(900),
            &TokenBucketConfig {
                rate: 100,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
    }

    #[test]
    fn adjust_zero() {
        let time = MillisTimestamp(100);

        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
        assert_eq!(tb.rate, 0);
        // bucket should not increase
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 100);

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 10);
        // bucket should stay empty
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 0);
        // bucket should stay empty
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);
    }

    #[test]
    fn is_empty() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert!(!tb.is_empty());

        tb.drain(99).unwrap();
        assert!(!tb.is_empty());

        tb.drain(1).unwrap();
        assert!(tb.is_empty());

        tb.refill(MillisTimestamp(199));
        assert!(tb.is_empty());

        tb.refill(MillisTimestamp(200));
        assert!(!tb.is_empty());
    }

    #[test]
    fn correctness() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 50);

        tb.refill(MillisTimestamp(1100));
        assert_eq!(tb.bucket, 60);

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 10);

        tb.refill(MillisTimestamp(2100));
        assert_eq!(tb.bucket, 20);

        tb.refill(MillisTimestamp(2101));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2199));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2200));
        assert_eq!(tb.bucket, 21);
    }

    #[test]
    fn rounding() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));
        tb.drain(100).unwrap();

        // ensure that refilling at 150 ms does not change the `added_tokens_at` time to 150 ms,
        // otherwise the next refill wouldn't occur until 250 ms instead of 200 ms
        tb.refill(MillisTimestamp(99));
        assert_eq!(tb.bucket, 0);
        tb.refill(MillisTimestamp(150));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(199));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(200));
        assert_eq!(tb.bucket, 2);
    }

    #[test]
    fn tokens_available_at() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));

        // bucket is empty at 0 ms, next token at 100 ms
        tb.drain(100).unwrap();

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket is still empty at 40 ms, next token at 100 ms
        tb.refill(MillisTimestamp(40));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket has 1 token at 100 ms, next token at 200 ms
        tb.refill(MillisTimestamp(100));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket is empty at 100 ms, next token at 200 ms
        tb.drain(1).unwrap();

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        // bucket is empty at 140 ms, next token at 200 ms
        tb.refill(MillisTimestamp(140));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        // bucket has 1 token at 210 ms, next token at 300 ms
        tb.refill(MillisTimestamp(210));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        use NeverEnoughTokensError as NETE;

        assert_eq!(tb.tokens_available_at(100), Ok(MillisTimestamp(10_100)));
        assert_eq!(tb.tokens_available_at(101), Err(NETE::ExceedsMaxTokens));
        assert_eq!(
            tb.tokens_available_at(u64::MAX),
            Err(NETE::ExceedsMaxTokens),
        );

        // set the refill rate to 0; note that adjusting the rate also resets `added_tokens_at`
        tb.adjust(
            MillisTimestamp(210),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(2), Err(NETE::ZeroRate));
    }

    #[test]
    fn test_duration_token_round_trip() {
        let tokens_to_duration = TokenBucket::<Instant>::tokens_to_duration;
        let duration_to_tokens = TokenBucket::<Instant>::duration_to_tokens;

        // start with some hand-picked cases
        let mut duration_rate_pairs = vec![
            (Duration::from_nanos(0), 1),
            (Duration::from_nanos(1), 1),
            (Duration::from_micros(2), 1),
            (Duration::MAX, 1),
            (Duration::from_nanos(0), 3),
            (Duration::from_nanos(1), 3),
            (Duration::from_micros(2), 3),
            (Duration::MAX, 3),
            (Duration::from_nanos(0), 1000),
            (Duration::from_nanos(1), 1000),
            (Duration::from_micros(2), 1000),
            (Duration::MAX, 1000),
            (Duration::from_nanos(0), u64::MAX),
            (Duration::from_nanos(1), u64::MAX),
            (Duration::from_micros(2), u64::MAX),
            (Duration::MAX, u64::MAX),
        ];

        let mut rng = rand::rng();

        // add some fuzzing
        for _ in 0..10_000 {
            let secs: u64 = rng.random();
            let nanos: u32 = rng.random();
            // `Duration::new(secs, nanos)` panics if carrying `nanos` into `secs` overflows.
            // Build the same value with checked arithmetic and skip the overflowing cases,
            // rather than catching the panic (which would still run the panic hook).
            let Some(random_duration) =
                Duration::from_secs(secs).checked_add(Duration::from_nanos(u64::from(nanos)))
            else {
                continue;
            };
            let random_rate: u64 = rng.random();
            duration_rate_pairs.push((random_duration, random_rate));
        }

        // for various combinations of durations and rates, we ensure that after an initial
        // `duration_to_tokens` calculation which may truncate, a round-trip between
        // `tokens_to_duration` and `duration_to_tokens` isn't lossy
        for (original_duration, rate) in duration_rate_pairs {
            // this may give a smaller number of tokens than expected (see docs on
            // `TokenBucket::duration_to_tokens`)
            let tokens = duration_to_tokens(original_duration, rate);

            // we want to ensure that converting these `tokens` to a duration and then back to
            // tokens is not lossy, which implies that `tokens_to_duration` is returning the
            // expected value and not a truncated value due to saturating arithmetic
            let duration = tokens_to_duration(tokens, rate).unwrap();
            assert_eq!(tokens, duration_to_tokens(duration, rate));
        }
    }
}