use std::{
future::{poll_fn, Future},
pin::Pin,
task::{Context, Poll},
};
use tokio::time::{sleep, Duration, Instant, Sleep};
/// Total number of commands permitted per `PERIOD`, heartbeats included.
const COMMANDS_PER_PERIOD: u8 = 120;
/// Length of the window a consumed permit stays in use before it is returned.
const PERIOD: Duration = Duration::from_secs(60);
/// Permit-based ratelimiter for commands.
///
/// Part of each period's `COMMANDS_PER_PERIOD` budget is held back for
/// heartbeats (see `nonreserved_commands_per_reset`). The remaining permits
/// are tracked as a queue of return instants.
#[derive(Debug)]
pub struct CommandRatelimiter {
    /// Timer used to wait for the oldest permit to be returned once every
    /// permit is in use.
    delay: Pin<Box<Sleep>>,
    /// Instants at which consumed permits are returned, oldest first.
    ///
    /// The vector's capacity (not its length) is the permit count.
    instants: Vec<Instant>,
}
impl CommandRatelimiter {
    /// Creates a ratelimiter allowing
    /// `nonreserved_commands_per_reset(heartbeat_interval)` commands per
    /// `PERIOD`.
    ///
    /// The internal delay is awaited once here so that it starts out
    /// elapsed. It only gets a real deadline once every permit is in use.
    pub(crate) async fn new(heartbeat_interval: Duration) -> Self {
        let allotted = nonreserved_commands_per_reset(heartbeat_interval);

        let mut delay = Box::pin(sleep(Duration::ZERO));
        // Complete the zero-length sleep up front so `delay` begins elapsed.
        (&mut delay).await;

        Self {
            delay,
            // The permit count is encoded as the queue's capacity; see `max`.
            instants: Vec::with_capacity(allotted.into()),
        }
    }

    /// Number of commands that can be sent right now without waiting.
    ///
    /// Permits whose return instant is at or before now count as available,
    /// even if `poll_available` has not yet pruned them from the queue.
    // `used_permits` is at most the queue length, which never exceeds the
    // capacity reported by `max`. The subtraction therefore cannot underflow
    // as long as that capacity fits in a `u8`.
    #[allow(clippy::cast_possible_truncation)]
    pub fn available(&self) -> u8 {
        let now = Instant::now();
        // The queue is sorted ascending, so returned permits form a prefix.
        let elapsed_permits = self.instants.partition_point(|&elapsed| elapsed <= now);
        let used_permits = self.instants.len() - elapsed_permits;
        self.max() - used_permits as u8
    }

    /// Maximum number of commands that may be sent per `PERIOD`.
    // The capacity was requested from a `u8` in `new`.
    // NOTE(review): this relies on `Vec::with_capacity` allocating exactly
    // the requested capacity — confirm this still holds.
    #[allow(clippy::cast_possible_truncation)]
    pub fn max(&self) -> u8 {
        self.instants.capacity() as u8
    }

    /// Duration until the next command may be sent.
    ///
    /// Returns `Duration::ZERO` while any permit is available.
    pub fn next_available(&self) -> Duration {
        // A free slot means `poll_available` is immediately ready.
        if self.instants.len() < self.instants.capacity() {
            return Duration::ZERO;
        }

        // The queue is full, so the oldest permit (the front of the sorted
        // queue) is the next one to be returned.
        self.instants.first().map_or(Duration::ZERO, |elapsed| {
            elapsed.saturating_duration_since(Instant::now())
        })
    }

    /// Waits until a permit is available, then consumes it.
    ///
    /// The consumed permit is returned `PERIOD` after this call completes.
    pub(crate) async fn acquire(&mut self) {
        poll_fn(|cx| self.poll_available(cx)).await;
        // `poll_available` only resolves once there is a free slot. This push
        // therefore never reallocates, and `max` stays unchanged.
        self.instants.push(Instant::now() + PERIOD);
    }

    /// Polls for an available permit without consuming it.
    ///
    /// Returns `Poll::Ready(())` when the queue has a free slot. If the queue
    /// is full, returned permits are pruned first. When no permit has been
    /// returned yet, the delay is armed for the oldest permit's return
    /// instant and the waker from `cx` is registered on it.
    pub(crate) fn poll_available(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        if self.instants.len() != self.instants.capacity() {
            return Poll::Ready(());
        }

        // Poll the delay rather than merely inspecting it. That way the
        // waker from *this* `cx` is the one scheduled, even if the caller's
        // task (and therefore its waker) changed since the last poll.
        if self.delay.as_mut().poll(cx).is_pending() {
            return Poll::Pending;
        }

        let now = Instant::now();
        let new_deadline = self.instants[0];
        if new_deadline > now {
            tracing::trace!(?new_deadline, old_deadline = ?self.delay.deadline());
            self.delay.as_mut().reset(new_deadline);
            if self.delay.as_mut().poll(cx).is_ready() {
                // The re-armed delay completed immediately, so no wakeup will
                // be delivered for it. Ask to be polled again instead of
                // stalling.
                cx.waker().wake_by_ref();
            }
            return Poll::Pending;
        }

        // `instants[0] <= now`, so at least one permit is removed here.
        // Draining keeps the allocation, so `max` is preserved.
        let elapsed_permits = self.instants.partition_point(|&elapsed| elapsed <= now);
        self.instants.drain(..elapsed_permits);
        Poll::Ready(())
    }
}
/// Computes how many commands per `PERIOD` are left for regular use once
/// heartbeats have been set aside.
///
/// The number of heartbeats per period is rounded up, and one extra
/// heartbeat is reserved on top. At most 10 commands are ever held back, so
/// the result is never lower than `COMMANDS_PER_PERIOD - 10`, even for a
/// zero `heartbeat_interval`.
fn nonreserved_commands_per_reset(heartbeat_interval: Duration) -> u8 {
    // Lower bound on the result: never hold back more than 10 commands.
    const MIN_NONRESERVED: u8 = COMMANDS_PER_PERIOD - 10;

    let ratio = PERIOD.as_secs_f32() / heartbeat_interval.as_secs_f32();

    // Float-to-int `as` casts saturate, so an infinite ratio (zero interval)
    // becomes `u8::MAX` rather than wrapping.
    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
    let heartbeats = ratio.ceil() as u8;

    // One extra heartbeat of headroom.
    let reserved = heartbeats.saturating_add(1);

    std::cmp::max(COMMANDS_PER_PERIOD.saturating_sub(reserved), MIN_NONRESERVED)
}
#[cfg(test)]
mod tests {
    use super::{nonreserved_commands_per_reset, CommandRatelimiter, PERIOD};
    use static_assertions::assert_impl_all;
    use std::{fmt::Debug, time::Duration};
    use tokio::time;

    // Compile-time check that the ratelimiter implements these traits.
    assert_impl_all!(CommandRatelimiter: Debug, Send, Sync);

    /// Pins the heartbeat reservation: `ceil(PERIOD / interval) + 1`
    /// commands are held back, and the result is floored at 110.
    #[test]
    fn nonreserved_commands() {
        // Effectively no heartbeats: 1 (rounded up) plus 1 extra is reserved.
        assert_eq!(
            118,
            nonreserved_commands_per_reset(Duration::from_secs(u64::MAX))
        );
        assert_eq!(118, nonreserved_commands_per_reset(Duration::from_secs(60)));
        // 60 / 42.5 rounds up to 2 heartbeats, plus 1 extra.
        assert_eq!(
            117,
            nonreserved_commands_per_reset(Duration::from_millis(42_500))
        );
        assert_eq!(117, nonreserved_commands_per_reset(Duration::from_secs(30)));
        // Just under 30s: the ratio exceeds 2 and rounds up to 3.
        assert_eq!(
            116,
            nonreserved_commands_per_reset(Duration::from_millis(29_999))
        );
        // A zero interval would reserve everything; the floor of 110 applies.
        assert_eq!(110, nonreserved_commands_per_reset(Duration::ZERO));
    }

    // Interval used by the ratelimiter tests below. Per `nonreserved_commands`,
    // it yields 118 permits.
    const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(60);

    /// Permits consumed at the same instant are all returned together,
    /// exactly `PERIOD` later.
    #[tokio::test(start_paused = true)]
    async fn full_reset() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL).await;

        assert_eq!(ratelimiter.available(), ratelimiter.max());
        for _ in 0..ratelimiter.max() {
            ratelimiter.acquire().await;
        }
        assert_eq!(ratelimiter.available(), 0);

        // Just short of `PERIOD`: nothing has been returned yet.
        time::advance(PERIOD - Duration::from_millis(100)).await;
        assert_eq!(ratelimiter.available(), 0);
        // At exactly `PERIOD`, expiry is inclusive, so every permit is back.
        time::advance(Duration::from_millis(100)).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }

    /// Two batches acquired `PERIOD / 2` apart are each returned `PERIOD`
    /// after their own acquisition.
    #[tokio::test(start_paused = true)]
    async fn half_reset() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL).await;

        assert_eq!(ratelimiter.available(), ratelimiter.max());
        // First batch, acquired at t = 0.
        for _ in 0..ratelimiter.max() / 2 {
            ratelimiter.acquire().await;
        }
        assert_eq!(ratelimiter.available(), ratelimiter.max() / 2);

        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max() / 2);

        // Second batch, acquired at t = PERIOD / 2.
        for _ in 0..ratelimiter.max() / 2 {
            ratelimiter.acquire().await;
        }
        assert_eq!(ratelimiter.available(), 0);

        // t = PERIOD: only the first batch has been returned.
        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max() / 2);
        // t = 1.5 * PERIOD: the second batch is back as well.
        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }

    /// Acquiring past the limit waits for permits to be returned, and it
    /// must not grow the queue's capacity (which would change `max()`).
    #[tokio::test(start_paused = true)]
    async fn constant_capacity() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL).await;
        let max = ratelimiter.max();

        for _ in 0..max {
            ratelimiter.acquire().await;
        }
        assert_eq!(ratelimiter.available(), 0);

        // This call has to wait on the internal delay. With the paused clock,
        // Tokio auto-advances time to the pending timer while idle.
        ratelimiter.acquire().await;
        assert_eq!(max, ratelimiter.max());
    }
}