use core::time::Duration;
use http_rate::Quota;
use crate::service::Service;
/// Middleware builder that applies an `http_rate::Quota` to the service it wraps.
///
/// Only the quota is stored here. Each time this builder is invoked through its
/// `Service<Result<S, E>>` impl, a new `http_rate::RateLimit` is created from a copy of
/// this quota.
pub struct RateLimit(Quota);
// Generates a `RateLimit` constructor named `$method`. The constructor forwards its
// `max_burst` argument to the `Quota` constructor of the same name.
macro_rules! constructor {
    ($method: ident) => {
        #[doc = concat!(
            "Construct a [`RateLimit`] whose quota is built by [`Quota::",
            stringify!($method),
            "`]."
        )]
        ///
        /// `max_burst` is the number of cells granted for the period. It is passed
        /// unchanged to the quota constructor of the same name.
        pub fn $method(max_burst: u32) -> Self {
            Self(Quota::$method(max_burst))
        }
    };
}
impl RateLimit {
    constructor!(per_second);
    constructor!(per_minute);
    constructor!(per_hour);

    /// Construct a [`RateLimit`] whose quota is built by [`Quota::with_period`].
    ///
    /// `replenish_1_per` is the interval after which one cell is replenished.
    ///
    /// # Panics
    ///
    /// Panics if [`Quota::with_period`] rejects `replenish_1_per`.
    /// NOTE(review): presumably this happens for a zero `Duration`, as in `governor`;
    /// confirm against `http_rate`.
    pub fn with_period(replenish_1_per: Duration) -> Self {
        // This runs at configuration time, so an invalid period is a programmer error.
        // Fail loudly with a message that names the culprit.
        let quota = Quota::with_period(replenish_1_per)
            .expect("RateLimit::with_period: Quota::with_period rejected replenish_1_per");
        Self(quota)
    }
}
/// Builder step: wraps a successfully built inner service in a
/// [`service::RateLimitService`] backed by a new limiter created from this quota.
/// A build error from the previous step is forwarded untouched.
impl<S, E> Service<Result<S, E>> for RateLimit {
    type Response = service::RateLimitService<S>;
    type Error = E;

    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
        // Bail out before allocating a limiter if the inner service failed to build.
        let service = res?;
        let rate_limit = http_rate::RateLimit::new(self.0);
        Ok(service::RateLimitService { service, rate_limit })
    }
}
mod service {
    use core::convert::Infallible;

    use crate::{
        body::ResponseBody,
        error::Error,
        http::WebResponse,
        service::{ready::ReadyService, Service},
        WebContext,
    };

    /// Service produced by `super::RateLimit`.
    ///
    /// Each request is checked against `rate_limit` before it is delegated to `service`.
    pub struct RateLimitService<S> {
        pub(super) service: S,
        pub(super) rate_limit: http_rate::RateLimit,
    }

    impl<'r, C, B, S, ResB> Service<WebContext<'r, C, B>> for RateLimitService<S>
    where
        S: for<'r2> Service<WebContext<'r2, C, B>, Response = WebResponse<ResB>, Error = Error<C>>,
    {
        type Response = WebResponse<ResB>;
        type Error = Error<C>;

        async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
            // The limiter sees the request headers and socket address. A rejection is
            // converted into `Error` and returned before the inner service runs.
            let req = ctx.req();
            let snapshot = self
                .rate_limit
                .rate_limit(req.headers(), req.body().socket_addr())
                .map_err(Error::from_service)?;

            // On success, the limiter's snapshot decorates the inner service's response.
            let mut response = self.service.call(ctx).await?;
            snapshot.extend_response(&mut response);
            Ok(response)
        }
    }

    /// Renders a `TooManyRequests` rejection as a response with an empty body, then lets
    /// the rejection decorate it via `extend_response`.
    ///
    /// NOTE(review): presumably this is what makes `Error::from_service` able to render the
    /// rejection above; confirm against `crate::error`.
    impl<'r, C, B> Service<WebContext<'r, C, B>> for http_rate::TooManyRequests {
        type Response = WebResponse;
        type Error = Infallible;

        async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
            let mut response = ctx.into_response(ResponseBody::empty());
            self.extend_response(&mut response);
            Ok(response)
        }
    }

    /// Readiness is delegated entirely to the wrapped service.
    impl<S> ReadyService for RateLimitService<S>
    where
        S: ReadyService,
    {
        type Ready = S::Ready;

        #[inline]
        async fn ready(&self) -> Self::Ready {
            let inner = &self.service;
            inner.ready().await
        }
    }
}