xitca_web/middleware/
rate_limit.rs

1//! client ip address based rate limiting.
2
3use core::time::Duration;
4
5use http_rate::Quota;
6
7use crate::service::Service;
8
9/// builder for client ip address based rate limiting middleware.
10///
11/// # Examples
12/// ```rust
13/// # use xitca_web::{handler::handler_service, middleware::rate_limit::RateLimit, route::get, App, WebContext};
14/// App::new()
15///     .at("/", get(handler_service(|| async { "hello,world!" })))
16///     # .at("/infer", handler_service(|_: &WebContext<'_>| async{ "infer type" }))
17///     // rate limit to 60 rps for one ip address.
18///     .enclosed(RateLimit::per_minute(60));
19/// ```
20pub struct RateLimit(Quota);
21
// Generates one `RateLimit` constructor that forwards to the `Quota` constructor of the same
// name. The macro argument is used three times: as the generated method's name, as the name of
// the `Quota` associated function it calls, and (stringified) inside the generated rustdoc.
macro_rules! constructor {
    ($ctor: tt) => {
        #[doc = concat!(
            "Construct a RateLimit for a number of cells ",
            stringify!($ctor),
            " period. The given number of cells is"
        )]
        /// also assumed to be the maximum burst size.
        ///
        /// # Panics
        /// - When max_burst is zero.
        pub fn $ctor(max_burst: u32) -> Self {
            let quota = Quota::$ctor(max_burst);
            Self(quota)
        }
    };
}
34
35impl RateLimit {
36    constructor!(per_second);
37    constructor!(per_minute);
38    constructor!(per_hour);
39
40    /// Construct a RateLimit that replenishes one cell in a given
41    /// interval.
42    ///
43    /// # Panics
44    /// - When the Duration is zero.
45    pub fn with_period(replenish_1_per: Duration) -> Self {
46        Self(Quota::with_period(replenish_1_per).unwrap())
47    }
48}
49
50impl<S, E> Service<Result<S, E>> for RateLimit {
51    type Response = service::RateLimitService<S>;
52    type Error = E;
53
54    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
55        res.map(|service| service::RateLimitService {
56            service,
57            rate_limit: http_rate::RateLimit::new(self.0),
58        })
59    }
60}
61
62mod service {
63    use core::convert::Infallible;
64
65    use crate::{
66        WebContext,
67        body::ResponseBody,
68        error::Error,
69        http::WebResponse,
70        service::{Service, ready::ReadyService},
71    };
72
73    pub struct RateLimitService<S> {
74        pub(super) service: S,
75        pub(super) rate_limit: http_rate::RateLimit,
76    }
77
78    impl<'r, C, B, S, ResB> Service<WebContext<'r, C, B>> for RateLimitService<S>
79    where
80        S: for<'r2> Service<WebContext<'r2, C, B>, Response = WebResponse<ResB>, Error = Error>,
81    {
82        type Response = WebResponse<ResB>;
83        type Error = Error;
84
85        async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
86            let headers = ctx.req().headers();
87            let addr = ctx.req().body().socket_addr();
88            let snap = self.rate_limit.rate_limit(headers, addr).map_err(Error::from_service)?;
89            self.service.call(ctx).await.map(|mut res| {
90                snap.extend_response(&mut res);
91                res
92            })
93        }
94    }
95
96    impl<'r, C, B> Service<WebContext<'r, C, B>> for http_rate::TooManyRequests {
97        type Response = WebResponse;
98        type Error = Infallible;
99
100        async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
101            let mut res = ctx.into_response(ResponseBody::empty());
102            self.extend_response(&mut res);
103            Ok(res)
104        }
105    }
106
107    impl<S> ReadyService for RateLimitService<S>
108    where
109        S: ReadyService,
110    {
111        type Ready = S::Ready;
112
113        #[inline]
114        async fn ready(&self) -> Self::Ready {
115            self.service.ready().await
116        }
117    }
118}