Skip to main content

tor_basic_utils/
token_bucket.rs

1//! A token bucket implementation.
2
3use std::fmt::Debug;
4use web_time_compat::{Duration, Instant};
5
/// A token bucket.
///
/// Calculations are performed at microsecond resolution.
/// You likely want to call [`refill()`](Self::refill) each time you want to access or perform an
/// operation on the token bucket.
///
/// This is partially inspired by tor's `token_bucket_ctr_t`,
/// but the implementation is quite a bit different.
/// We use larger values here (for example `u64`),
/// and we aim to avoid drift when refills occur at times that aren't exactly in period with the
/// refill rate.
///
/// It's possible that we could relax these requirements to reduce memory usage and computation
/// complexity, but that optimization should probably only be made if/when needed since it would
/// make the code more difficult to reason about, and possibly more complex.
///
/// `I` is the timestamp type used for refill calculations; the methods of the bucket require
/// `I: TokenBucketInstant`.
#[derive(Debug)]
pub struct TokenBucket<I> {
    /// The refill rate in tokens/second.
    ///
    /// A rate of 0 means that the bucket never gains tokens.
    rate: u64,
    /// The max amount of tokens in the bucket.
    /// Commonly referred to as the "burst".
    bucket_max: u64,
    /// Current amount of tokens in the bucket.
    ///
    /// Construction, refilling and adjusting all keep this at or below `bucket_max`.
    // It's possible that in the future we may want a token bucket to allow negative values. For
    // example we might want to send a few extra bytes over the allowed limit if it would mean that
    // we send a complete TLS record.
    bucket: u64,
    /// Time that the most recent token was added to the bucket.
    ///
    /// While this can be thought of as the last time the bucket was partially refilled, it more
    /// specifically is the time that the most recent token was added. For example if the bucket
    /// refills one token every 100 ms, and the bucket is refilled at time 510 ms, the bucket would
    /// gain 5 tokens and the stored time would be 500 ms.
    added_tokens_at: I,
}
41
42impl<I: TokenBucketInstant> TokenBucket<I> {
43    /// A new [`TokenBucket`] with a given `rate` in tokens/second and a `max` token limit.
44    ///
45    /// The bucket will initially be full.
46    /// The value `max` is commonly referred to as the "burst".
47    pub fn new(config: &TokenBucketConfig, now: I) -> Self {
48        Self {
49            rate: config.rate,
50            bucket_max: config.bucket_max,
51            bucket: config.bucket_max,
52            added_tokens_at: now,
53        }
54    }
55
56    /// Are there no tokens in the bucket?
57    pub fn is_empty(&self) -> bool {
58        self.bucket == 0
59    }
60
61    /// The maximum number of tokens that this bucket can hold.
62    pub fn max(&self) -> u64 {
63        self.bucket_max
64    }
65
66    /// Remove `count` tokens from the bucket.
67    pub fn drain(&mut self, count: u64) -> Result<BecameEmpty, InsufficientTokensError> {
68        Ok(self.claim(count)?.commit())
69    }
70
71    /// Claim a number of tokens.
72    ///
73    /// The claim will be held by the returned [`ClaimedTokens`], and committed when dropped.
74    ///
75    /// **Note:** You probably want to call [`refill()`](Self::refill) before this.
76    // Since the `ClaimedTokens` holds a `&mut` to this `TokenBucket`, we don't need to worry about
77    // other calls accessing the `TokenBucket` before the `ClaimedTokens` are committed.
78    pub fn claim(&mut self, count: u64) -> Result<ClaimedTokens<I>, InsufficientTokensError> {
79        if count > self.bucket {
80            return Err(InsufficientTokensError {
81                available: self.bucket,
82            });
83        }
84
85        Ok(ClaimedTokens::new(self, count))
86    }
87
88    /// Adjust the refill rate and max tokens of the bucket.
89    ///
90    /// The token bucket is refilled up to `now` before changing the rate.
91    ///
92    /// If the new max is smaller than the existing number of tokens,
93    /// the number of tokens will be reduced to the new max.
94    ///
95    /// A rate and/or max of 0 is allowed.
96    pub fn adjust(&mut self, now: I, config: &TokenBucketConfig) {
97        // make sure that the bucket gets the tokens it is owed before we change the rate
98        self.refill(now);
99
100        // If the old rate was small (or 0), the `refill()` might not have updated
101        // `added_tokens_at`.
102        //
103        // For example if the bucket has a rate of 0 and was last refilled 10 seconds ago, it will
104        // not have gained any tokens in the last 10 seconds. If we were to only update the rate to
105        // 100 tokens/second now, the bucket would immediately become eligible to refill 1000
106        // tokens. We only want the rate change to become effective now, not in the past, so we
107        // ensure this by resetting `added_tokens_at`.
108        self.added_tokens_at = std::cmp::max(self.added_tokens_at, now);
109
110        self.rate = config.rate;
111        self.bucket_max = config.bucket_max;
112        self.bucket = std::cmp::min(self.bucket, self.bucket_max);
113    }
114
115    /// An estimated time at which the bucket will have `tokens` available.
116    ///
117    /// It is not guaranteed that `tokens` will be available at the returned time.
118    ///
119    /// If there are already enough tokens available, a time in the past may be returned.
120    ///
121    /// A value of `None` implies "never",
122    /// for example if the refill rate is 0,
123    /// the bucket max is too small,
124    /// or the time is too large to be represented as an `I`.
125    pub fn tokens_available_at(&self, tokens: u64) -> Result<I, NeverEnoughTokensError> {
126        let tokens_needed = tokens.saturating_sub(self.bucket);
127
128        // check if we currently have enough tokens before considering refilling
129        if tokens_needed == 0 {
130            return Ok(self.added_tokens_at);
131        }
132
133        // if the rate is 0, we'll never get more tokens
134        if self.rate == 0 {
135            return Err(NeverEnoughTokensError::ZeroRate);
136        }
137
138        // if more tokens are wanted than the capacity of the bucket, we'll never get enough
139        if tokens > self.bucket_max {
140            return Err(NeverEnoughTokensError::ExceedsMaxTokens);
141        }
142
143        // this may underestimate the time if either argument is very large
144        let time_needed = Self::tokens_to_duration(tokens_needed, self.rate)
145            .ok_or(NeverEnoughTokensError::ZeroRate)?;
146
147        // Always return at least 1 microsecond since:
148        // 1. We don't want to return `Duration::ZERO` if the tokens aren't ready,
149        //    which may occur if the rate is very large (<1 ns/token).
150        // 2. Clocks generally don't operate at <1 us resolution.
151        let time_needed = std::cmp::max(time_needed, Duration::from_micros(1));
152
153        self.added_tokens_at
154            .checked_add(time_needed)
155            .ok_or(NeverEnoughTokensError::InstantNotRepresentable)
156    }
157
158    /// Refill the bucket.
159    pub fn refill(&mut self, now: I) -> BecameNonEmpty {
160        // time since we last added tokens
161        let elapsed = now.saturating_duration_since(self.added_tokens_at);
162
163        // If we exceeded the threshold, update the timestamp and return.
164        // This is taken from tor, which has the comment below:
165        //
166        // > Skip over updates that include an overflow or a very large jump. This can happen for
167        // > platform specific reasons, such as the old ~48 day windows timer.
168        //
169        // It's unclear if this type of OS bug is still common enough that this check is useful,
170        // but it shouldn't hurt.
171        if elapsed > I::IGNORE_THRESHOLD {
172            tracing::debug!(
173                "Time jump of {elapsed:?} is larger than {:?}; not refilling token bucket",
174                I::IGNORE_THRESHOLD,
175            );
176            self.added_tokens_at = now;
177            return BecameNonEmpty::No;
178        }
179
180        let old_bucket = self.bucket;
181
182        // Compute how much we should increment the bucket by.
183        // This may be underestimated in some cases.
184        let bucket_inc = Self::duration_to_tokens(elapsed, self.rate);
185
186        self.bucket = std::cmp::min(self.bucket_max, self.bucket.saturating_add(bucket_inc));
187
188        // Compute how much we should increment the `last_added_tokens` time by. This avoids
189        // drifting if the `bucket_inc` was underestimated, and avoids rounding errors which could
190        // cause the token bucket to effectively use a lower rate. For example if the rate was
191        // "1 token / sec" and the elapsed time was "1.2 sec", we only want to refill 1 token and
192        // increment the time by 1 second.
193        //
194        // While the docs for `tokens_to_duration` say that a smaller than expected duration may be
195        // returned, we have a test `test_duration_token_round_trip` which ensures that
196        // `tokens_to_duration` returns the expected value when used with the result from
197        // `duration_to_tokens`.
198        let added_tokens_at_inc =
199            Self::tokens_to_duration(bucket_inc, self.rate).unwrap_or(Duration::ZERO);
200
201        self.added_tokens_at = self
202            .added_tokens_at
203            .checked_add(added_tokens_at_inc)
204            .expect("overflowed time");
205        debug_assert!(self.added_tokens_at <= now);
206
207        if old_bucket == 0 && self.bucket != 0 {
208            BecameNonEmpty::Yes
209        } else {
210            BecameNonEmpty::No
211        }
212    }
213
214    /// How long would it take to refill `tokens` at `rate`?
215    ///
216    /// The result is rounded up to the nearest microsecond.
217    /// If the number of `tokens` is large,
218    /// the result may be much lower than the expected duration due to saturating 64-bit arithmetic.
219    ///
220    /// `None` will be returned if the `rate` is 0.
221    fn tokens_to_duration(tokens: u64, rate: u64) -> Option<Duration> {
222        // Perform the calculation in microseconds rather than nanoseconds since timers typically
223        // have microsecond granularity, and it lowers the chance that the calculation overflows the
224        // `u64::MAX` limit compared to nanoseconds. In the case that the calculation saturates, the
225        // returned duration will be shorter than the real value.
226        //
227        // For example with `tokens = u64::MAX` and `rate = u64::MAX` we'd expect a result of 1
228        // second, but:
229        // u64::MAX.saturating_mul(1000 * 1000).div_ceil(u64::MAX) = 1 microsecond
230        //
231        // The `div_ceil` ensures we always round up to the nearest microsecond.
232        //
233        // dimensional analysis:
234        // (tokens) * (microseconds / second) / (tokens / second) = microseconds
235        if rate == 0 {
236            return None;
237        }
238        let micros = tokens.saturating_mul(1000 * 1000).div_ceil(rate);
239        Some(Duration::from_micros(micros))
240    }
241
242    /// How many tokens would be refilled within `time` at `rate`?
243    ///
244    /// The `time` is truncated to microsecond granularity.
245    /// If the `time` or `rate` is large,
246    /// the result may be much lower than the expected number of tokens due to saturating 64-bit
247    /// arithmetic.
248    fn duration_to_tokens(time: Duration, rate: u64) -> u64 {
249        let micros = u64::try_from(time.as_micros()).unwrap_or(u64::MAX);
250        // dimensional analysis:
251        // (tokens / second) * (microseconds) / (microseconds / second) = tokens
252        rate.saturating_mul(micros) / (1000 * 1000)
253    }
254}
255
/// The refill rate and token max for a [`TokenBucket`].
///
/// Used both when creating a bucket ([`TokenBucket::new`]) and when reconfiguring an existing
/// one ([`TokenBucket::adjust`]).
#[derive(Clone, Debug)]
#[allow(clippy::exhaustive_structs)] // constructed directly by callers configuring the bucket
pub struct TokenBucketConfig {
    /// The refill rate in tokens/second.
    ///
    /// A rate of 0 is allowed; the bucket then never gains tokens.
    pub rate: u64,
    /// The max amount of tokens in the bucket.
    /// Commonly referred to as the "burst".
    ///
    /// A max of 0 is allowed; the bucket then always stays empty.
    pub bucket_max: u64,
}
266
/// A handle to a number of claimed tokens.
///
/// Dropping this handle will commit the claim.
///
/// The handle mutably borrows its [`TokenBucket`], so the bucket cannot be used in any other way
/// until the handle is consumed or dropped.
#[derive(Debug)]
pub struct ClaimedTokens<'a, I> {
    /// The bucket that the claim is for.
    bucket: &'a mut TokenBucket<I>,
    /// How many tokens to remove from the bucket.
    ///
    /// Set to 0 once the claim has been committed or discarded, so that the commit performed on
    /// drop does not remove tokens a second time.
    count: u64,
}
277
278impl<'a, I> ClaimedTokens<'a, I> {
279    /// Create a new [`ClaimedTokens`] that will remove `count` tokens from the token `bucket` when
280    /// dropped.
281    fn new(bucket: &'a mut TokenBucket<I>, count: u64) -> Self {
282        Self { bucket, count }
283    }
284
285    /// Commit the claimed tokens.
286    ///
287    /// This is equivalent to just dropping the [`ClaimedTokens`], but also returns whether the
288    /// token bucket became empty or not.
289    pub fn commit(mut self) -> BecameEmpty {
290        self.commit_impl()
291    }
292
293    /// Reduce the claim to a fewer number of tokens than the original claim.
294    ///
295    /// If `count` is larger than the original claim, an error will be returned containing the
296    /// current number of claimed tokens.
297    pub fn reduce(&mut self, count: u64) -> Result<(), InsufficientTokensError> {
298        if count > self.count {
299            return Err(InsufficientTokensError {
300                available: self.count,
301            });
302        }
303
304        self.count = count;
305        Ok(())
306    }
307
308    /// Discard the claim.
309    ///
310    /// This does not remove any tokens from the token bucket.
311    pub fn discard(mut self) {
312        self.count = 0;
313    }
314
315    /// The commit implementation.
316    ///
317    /// After calling [`commit_impl()`](Self::commit_impl),
318    /// the [`ClaimedTokens`] should no longer be used and should be dropped immediately.
319    fn commit_impl(&mut self) -> BecameEmpty {
320        // when the `ClaimedTokens` was created by the `TokenBucket`, it should have ensured that
321        // there were enough tokens
322        self.bucket.bucket = self
323            .bucket
324            .bucket
325            .checked_sub(self.count)
326            .unwrap_or_else(|| {
327                panic!(
328                    "claim commit failed: {}, {}",
329                    self.count, self.bucket.bucket,
330                )
331            });
332
333        // when `self` is dropped some time after this function ends,
334        // we don't want to subtract again
335        self.count = 0;
336
337        if self.bucket.bucket > 0 {
338            BecameEmpty::No
339        } else {
340            BecameEmpty::Yes
341        }
342    }
343}
344
345impl<'a, I> std::ops::Drop for ClaimedTokens<'a, I> {
346    fn drop(&mut self) {
347        self.commit_impl();
348    }
349}
350
/// An operation was attempted to reduce the number of tokens,
/// but the token bucket did not have enough tokens.
///
/// Returned by [`TokenBucket::claim`] and [`TokenBucket::drain`] when the bucket holds too few
/// tokens, and by [`ClaimedTokens::reduce`] when asked for more tokens than are currently claimed.
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[error("insufficient tokens for operation")]
pub struct InsufficientTokensError {
    /// The number of tokens that are available to drain/commit.
    ///
    /// For a bucket this is its current token count; for a claim it is the number of tokens
    /// currently claimed.
    available: u64,
}
359
360impl InsufficientTokensError {
361    /// Get the number of tokens that are available to drain/commit.
362    pub fn available_tokens(&self) -> u64 {
363        self.available
364    }
365}
366
/// The token bucket will never have the requested number of tokens.
///
/// Returned by [`TokenBucket::tokens_available_at`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[allow(clippy::exhaustive_enums)] // callers exhaustively match on these variants
#[error("there will never be enough tokens for this operation")]
pub enum NeverEnoughTokensError {
    /// The request exceeds the bucket's maximum number of tokens.
    ExceedsMaxTokens,
    /// The refill rate is 0.
    ///
    /// The bucket does not currently hold enough tokens and will never gain more.
    ZeroRate,
    /// The time is not representable.
    ///
    /// For example, if the rate is low and a large number of tokens were requested, the time at
    /// which they would be available may be so far in the future that it cannot be represented
    /// as a time value.
    InstantNotRepresentable,
}
382
/// The token bucket transitioned from "empty" to "non-empty".
///
/// Returned by [`TokenBucket::refill`]: `Yes` only when the bucket held zero tokens before the
/// refill and holds at least one token afterwards.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[allow(clippy::exhaustive_enums)] // a simple yes/no status that callers match on
pub enum BecameNonEmpty {
    /// Token bucket became non-empty.
    Yes,
    /// Token bucket remains empty.
    No,
}
392
/// The token bucket transitioned from "non-empty" to "empty".
///
/// Returned by [`TokenBucket::drain`] and [`ClaimedTokens::commit`].
///
/// Note that `Yes` is reported whenever the bucket holds zero tokens after the commit. It is
/// therefore also returned if the bucket was already empty, for example when committing a claim
/// of 0 tokens against an empty bucket.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[allow(clippy::exhaustive_enums)] // a simple yes/no status that callers match on
pub enum BecameEmpty {
    /// Token bucket became empty.
    Yes,
    /// Token bucket remains non-empty.
    No,
}
402
/// A timestamp type that a [`TokenBucket`] can use for its refill calculations.
///
/// Any type implementing this must be represented as a measurement of a monotonically
/// nondecreasing clock.
pub trait TokenBucketInstant: Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord {
    /// An unrealistically large time jump.
    ///
    /// Any time change larger than this is assumed to come from a broken monotonic clock,
    /// and the bucket will not be refilled for it.
    const IGNORE_THRESHOLD: Duration;

    /// See [`Instant::checked_add`].
    fn checked_add(&self, duration: Duration) -> Option<Self>;

    /// See [`Instant::checked_duration_since`].
    fn checked_duration_since(&self, earlier: Self) -> Option<Duration>;

    /// See [`Instant::saturating_duration_since`].
    ///
    /// By default this is [`checked_duration_since()`](Self::checked_duration_since), with a
    /// `None` result (such as `earlier` actually being later) mapped to [`Duration::ZERO`].
    fn saturating_duration_since(&self, earlier: Self) -> Duration {
        self.checked_duration_since(earlier)
            .unwrap_or(Duration::ZERO)
    }
}
423
424impl TokenBucketInstant for Instant {
425    // This value is taken from tor (see `elapsed_ticks <= UINT32_MAX/4` in
426    // `src/lib/evloop/token_bucket.c`).
427    const IGNORE_THRESHOLD: Duration = Duration::from_secs((u32::MAX / 4) as u64);
428
429    #[inline]
430    fn checked_add(&self, duration: Duration) -> Option<Self> {
431        self.checked_add(duration)
432    }
433
434    #[inline]
435    fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
436        self.checked_duration_since(earlier)
437    }
438
439    #[inline]
440    fn saturating_duration_since(&self, earlier: Self) -> Duration {
441        self.saturating_duration_since(earlier)
442    }
443}
444
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use rand::RngExt;

    /// A fake clock reading in milliseconds, so that tests can choose exact times.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct MillisTimestamp(u64);

    impl TokenBucketInstant for MillisTimestamp {
        // 1_000_000_000 ms is roughly 11.6 days, far beyond any jump used by these tests.
        const IGNORE_THRESHOLD: Duration = Duration::from_millis(1_000_000_000);

        // Any sub-millisecond part of `duration` is truncated by `as_millis()`.
        fn checked_add(&self, duration: Duration) -> Option<Self> {
            let duration = u64::try_from(duration.as_millis()).ok()?;
            self.0.checked_add(duration).map(Self)
        }

        fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
            Some(Duration::from_millis(self.0.checked_sub(earlier.0)?))
        }
    }

    /// `adjust()` at the bucket's current time changes the config without adding tokens.
    #[test]
    fn adjust_now() {
        let time = MillisTimestamp(100);

        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, time);
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);

        // lowering the max below the current token count discards the excess tokens
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 40,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 40);

        // raising the max again does not give back the discarded tokens
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 200,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 200);
    }

    /// `adjust()` at later times refills at the old rate before applying the new config.
    #[test]
    fn adjust_future() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);

        // at 300 ms: increase rate and max; bucket was already full, so doesn't gain any tokens
        tb.adjust(
            MillisTimestamp(300),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);

        // at 500 ms: no changes; bucket is refilled during `adjust()`, so gains 4 tokens
        tb.adjust(
            MillisTimestamp(500),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 104);
        assert_eq!(tb.bucket_max, 200);

        // at 700 ms: lower rate and max; bucket is lowered to new max, so loses 4 tokens
        tb.adjust(
            MillisTimestamp(700),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);

        // at 900 ms: raise rate and max; rate was previously 0 so doesn't gain any tokens
        tb.adjust(
            MillisTimestamp(900),
            &TokenBucketConfig {
                rate: 100,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
    }

    /// A rate and/or max of 0 is accepted, and the bucket then never gains tokens.
    #[test]
    fn adjust_zero() {
        let time = MillisTimestamp(100);

        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
        assert_eq!(tb.rate, 0);
        // bucket should not increase
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 100);

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 10);
        // bucket should stay empty
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 0);
        // bucket should stay empty
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);
    }

    /// `is_empty()` follows draining and refilling.
    #[test]
    fn is_empty() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert!(!tb.is_empty());

        tb.drain(99).unwrap();
        assert!(!tb.is_empty());

        tb.drain(1).unwrap();
        assert!(tb.is_empty());

        // the next token is only due at 200 ms
        tb.refill(MillisTimestamp(199));
        assert!(tb.is_empty());

        tb.refill(MillisTimestamp(200));
        assert!(!tb.is_empty());
    }

    /// Draining and refilling produce the expected token counts.
    #[test]
    fn correctness() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 50);

        tb.refill(MillisTimestamp(1100));
        assert_eq!(tb.bucket, 60);

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 10);

        tb.refill(MillisTimestamp(2100));
        assert_eq!(tb.bucket, 20);

        tb.refill(MillisTimestamp(2101));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2199));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2200));
        assert_eq!(tb.bucket, 21);
    }

    /// Refilling between token boundaries must not delay later refills.
    #[test]
    fn rounding() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));
        tb.drain(100).unwrap();

        // ensure that refilling at 150 ms does not change the `added_tokens_at` time to 150 ms,
        // otherwise the next refill wouldn't occur until 250 ms instead of 200 ms
        tb.refill(MillisTimestamp(99));
        assert_eq!(tb.bucket, 0);
        tb.refill(MillisTimestamp(150));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(199));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(200));
        assert_eq!(tb.bucket, 2);
    }

    /// `tokens_available_at()` predictions, including each error case.
    #[test]
    fn tokens_available_at() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));

        // bucket is empty at 0 ms, next token at 100 ms
        tb.drain(100).unwrap();

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket is still empty at 40 ms, next token at 100 ms
        tb.refill(MillisTimestamp(40));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket has 1 token at 100 ms, next token at 200 ms
        tb.refill(MillisTimestamp(100));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket is empty at 100 ms, next token at 200 ms
        tb.drain(1).unwrap();

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        // bucket is empty at 140 ms, next token at 200 ms
        tb.refill(MillisTimestamp(140));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        // bucket has 1 token at 210 ms, next token at 300 ms
        tb.refill(MillisTimestamp(210));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        use NeverEnoughTokensError as NETE;

        // requests up to the max are satisfiable; anything larger never is
        assert_eq!(tb.tokens_available_at(100), Ok(MillisTimestamp(10_100)));
        assert_eq!(tb.tokens_available_at(101), Err(NETE::ExceedsMaxTokens));
        assert_eq!(
            tb.tokens_available_at(u64::MAX),
            Err(NETE::ExceedsMaxTokens),
        );

        // set the refill rate to 0; note that adjusting the rate also resets `added_tokens_at`
        tb.adjust(
            MillisTimestamp(210),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(2), Err(NETE::ZeroRate));
    }

    /// `tokens_to_duration()` must exactly invert the truncated result of `duration_to_tokens()`.
    #[test]
    fn test_duration_token_round_trip() {
        let tokens_to_duration = TokenBucket::<Instant>::tokens_to_duration;
        let duration_to_tokens = TokenBucket::<Instant>::duration_to_tokens;

        // start with some hand-picked cases
        let mut duration_rate_pairs = vec![
            (Duration::from_nanos(0), 1),
            (Duration::from_nanos(1), 1),
            (Duration::from_micros(2), 1),
            (Duration::MAX, 1),
            (Duration::from_nanos(0), 3),
            (Duration::from_nanos(1), 3),
            (Duration::from_micros(2), 3),
            (Duration::MAX, 3),
            (Duration::from_nanos(0), 1000),
            (Duration::from_nanos(1), 1000),
            (Duration::from_micros(2), 1000),
            (Duration::MAX, 1000),
            (Duration::from_nanos(0), u64::MAX),
            (Duration::from_nanos(1), u64::MAX),
            (Duration::from_micros(2), u64::MAX),
            (Duration::MAX, u64::MAX),
        ];

        let mut rng = rand::rng();

        // add some fuzzing
        for _ in 0..10_000 {
            let secs = rng.random();
            let nanos = rng.random();
            // Duration::new() may panic, so just skip if there's a panic rather than trying to
            // write our own logic to avoid the panic in the first place
            let Ok(random_duration) = std::panic::catch_unwind(|| Duration::new(secs, nanos))
            else {
                continue;
            };
            let random_rate = rng.random();
            duration_rate_pairs.push((random_duration, random_rate));
        }

        // for various combinations of durations and rates, we ensure that after an initial
        // `duration_to_tokens` calculation which may truncate, a round-trip between
        // `tokens_to_duration` and `duration_to_tokens` isn't lossy
        for (original_duration, rate) in duration_rate_pairs {
            // this may give a smaller number of tokens than expected (see docs on
            // `TokenBucket::duration_to_tokens`)
            let tokens = duration_to_tokens(original_duration, rate);

            // we want to ensure that converting these `tokens` to a duration and then back to
            // tokens is not lossy, which implies that `tokens_to_duration` is returning the
            // expected value and not a truncated value due to saturating arithmetic
            let duration = tokens_to_duration(tokens, rate).unwrap();
            assert_eq!(tokens, duration_to_tokens(duration, rate));
        }
    }
}