xitca_web/middleware/
rate_limit.rs1use core::time::Duration;
4
5use http_rate::Quota;
6
7use crate::service::Service;
8
9pub struct RateLimit(Quota);
21
/// Generates a `RateLimit` constructor that forwards `max_burst` to the `http_rate::Quota`
/// constructor of the same name (`per_second`, `per_minute`, ...).
macro_rules! constructor {
    ($method: ident) => {
        #[doc = concat!(
            "Construct a `RateLimit` from `Quota::",
            stringify!($method),
            "(max_burst)`."
        )]
        ///
        /// `max_burst` is passed through unchanged to `http_rate`. It presumably acts both as
        /// the maximum burst size and as the number of cells replenished each period — confirm
        /// against the `http_rate::Quota` documentation.
        ///
        /// # Panics
        /// Invalid input (e.g. a zero `max_burst`) is not handled here; whatever
        /// `http_rate::Quota` does with it applies — TODO confirm whether it panics.
        pub fn $method(max_burst: u32) -> Self {
            Self(Quota::$method(max_burst))
        }
    };
}
35impl RateLimit {
36 constructor!(per_second);
37 constructor!(per_minute);
38 constructor!(per_hour);
39
40 pub fn with_period(replenish_1_per: Duration) -> Self {
46 Self(Quota::with_period(replenish_1_per).unwrap())
47 }
48}
49
50impl<S, E> Service<Result<S, E>> for RateLimit {
51 type Response = service::RateLimitService<S>;
52 type Error = E;
53
54 async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
55 res.map(|service| service::RateLimitService {
56 service,
57 rate_limit: http_rate::RateLimit::new(self.0),
58 })
59 }
60}
61
62mod service {
63 use core::convert::Infallible;
64
65 use crate::{
66 WebContext,
67 body::ResponseBody,
68 error::Error,
69 http::WebResponse,
70 service::{Service, ready::ReadyService},
71 };
72
73 pub struct RateLimitService<S> {
74 pub(super) service: S,
75 pub(super) rate_limit: http_rate::RateLimit,
76 }
77
78 impl<'r, C, B, S, ResB> Service<WebContext<'r, C, B>> for RateLimitService<S>
79 where
80 S: for<'r2> Service<WebContext<'r2, C, B>, Response = WebResponse<ResB>, Error = Error>,
81 {
82 type Response = WebResponse<ResB>;
83 type Error = Error;
84
85 async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
86 let headers = ctx.req().headers();
87 let addr = ctx.req().body().socket_addr();
88 let snap = self.rate_limit.rate_limit(headers, addr).map_err(Error::from_service)?;
89 self.service.call(ctx).await.map(|mut res| {
90 snap.extend_response(&mut res);
91 res
92 })
93 }
94 }
95
96 impl<'r, C, B> Service<WebContext<'r, C, B>> for http_rate::TooManyRequests {
97 type Response = WebResponse;
98 type Error = Infallible;
99
100 async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
101 let mut res = ctx.into_response(ResponseBody::empty());
102 self.extend_response(&mut res);
103 Ok(res)
104 }
105 }
106
107 impl<S> ReadyService for RateLimitService<S>
108 where
109 S: ReadyService,
110 {
111 type Ready = S::Ready;
112
113 #[inline]
114 async fn ready(&self) -> Self::Ready {
115 self.service.ready().await
116 }
117 }
118}