use crate::time::{
timer::{self, Provider as _},
Timer, Timestamp,
};
use core::time::Duration;
#[derive(Debug)]
pub struct TokenBucket {
current: u64,
refill_amount: u64,
refill_interval: Duration,
refill_timer: Timer,
max: u64,
}
impl Default for TokenBucket {
#[inline]
fn default() -> Self {
Self::builder().build()
}
}
impl TokenBucket {
#[inline]
pub fn builder() -> Builder {
Builder::default()
}
#[inline]
pub fn take(&mut self, amount: u64, now: Timestamp) -> u64 {
if amount == 0 {
self.on_timeout(now);
return 0;
}
if self.current < amount {
self.on_timeout(now);
}
let credits = amount.min(self.current);
self.current -= credits;
self.on_timeout(now);
credits
}
#[inline]
pub fn set_refill_interval(&mut self, new_interval: Duration) {
if self.refill_interval == new_interval {
return;
}
let prev_interval = core::mem::replace(&mut self.refill_interval, new_interval);
if let Some(target) = self.refill_timer.next_expiration() {
if let Some(now) = target.checked_sub(prev_interval) {
self.refill_timer.set(now + new_interval);
}
}
}
#[inline]
pub fn on_timeout(&mut self, now: Timestamp) {
while self.current < self.max {
if let Some(target) = self.refill_timer.next_expiration() {
if !target.has_elapsed(now) {
break;
}
self.current = self
.max
.min(self.current.saturating_add(self.refill_amount));
if self.current == self.max {
self.refill_timer.cancel();
break;
}
self.refill_timer.set(target + self.refill_interval);
} else {
self.refill_timer.set(now + self.refill_interval);
break;
}
}
self.invariants();
}
#[inline]
pub fn cancel(&mut self) {
self.refill_timer.cancel();
}
#[inline]
fn invariants(&self) {
if cfg!(debug_assertions) {
assert!(self.current <= self.max);
assert_eq!(
self.refill_timer.is_armed(),
self.current < self.max,
"timer should be armed ({}) if current ({}) is less than max ({})",
self.refill_timer.is_armed(),
self.current,
self.max,
);
}
}
}
impl timer::Provider for TokenBucket {
#[inline]
fn timers<Q: timer::Query>(&self, query: &mut Q) -> timer::Result {
self.refill_timer.timers(query)?;
Ok(())
}
}
pub struct Builder {
max: u64,
refill_interval: Duration,
refill_amount: u64,
}
impl Default for Builder {
fn default() -> Self {
Self {
max: 100,
refill_amount: 5,
refill_interval: Duration::from_secs(1),
}
}
}
impl Builder {
#[inline]
pub fn with_max(mut self, max: u64) -> Self {
self.max = max;
self
}
#[inline]
pub fn with_refill_amount(mut self, amount: u64) -> Self {
self.refill_amount = amount;
self
}
#[inline]
pub fn with_refill_interval(mut self, interval: Duration) -> Self {
self.refill_interval = interval;
self
}
#[inline]
pub fn build(self) -> TokenBucket {
let Self {
max,
refill_interval,
refill_amount,
} = self;
TokenBucket {
current: max,
max,
refill_amount,
refill_interval,
refill_timer: Default::default(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::time::{testing::Clock, Clock as _};
#[test]
fn example_test() {
let mut bucket = TokenBucket::default();
let mut clock = Clock::default();
assert_eq!(bucket.take(1, clock.get_time()), 1);
assert!(bucket.refill_timer.is_armed());
assert_eq!(bucket.take(100, clock.get_time()), 99);
assert_eq!(bucket.take(1, clock.get_time()), 0);
clock.inc_by(Duration::from_secs(1));
assert_eq!(bucket.take(100, clock.get_time()), 5);
assert!(bucket.refill_timer.is_armed());
clock.inc_by(Duration::from_secs(3));
assert_eq!(bucket.take(100, clock.get_time()), 15);
assert!(bucket.refill_timer.is_armed());
}
}