tide_governor/lib.rs
//! A [tide] middleware that implements rate-limiting using [governor].
//!
//! Limits are tracked separately for each client IP address.
//!
//! # Example
//! ```rust,no_run
//! use tide_governor::GovernorMiddleware;
//! use std::env;
//!
//! #[async_std::main]
//! async fn main() -> tide::Result<()> {
//!     let mut app = tide::new();
//!     app.at("/")
//!         .with(GovernorMiddleware::per_minute(4)?)
//!         .get(|_| async move { todo!() });
//!     app.at("/foo/:bar")
//!         .with(GovernorMiddleware::per_hour(360)?)
//!         .put(|_| async move { todo!() });
//!
//!     app.listen(format!("http://localhost:{}", env::var("PORT")?))
//!         .await?;
//!     Ok(())
//! }
//! ```
//!
//! [tide]: https://github.com/http-rs/tide
//! [governor]: https://github.com/antifuchs/governor
// TODO: figure out how to add jitter support using `governor::Jitter`.
// TODO: add more usage examples (the docs have one; also add an `examples` directory).
// TODO: add unit tests.
28use governor::{
29 clock::{Clock, DefaultClock},
30 state::keyed::DefaultKeyedStateStore,
31 Quota, RateLimiter,
32};
33use lazy_static::lazy_static;
34use std::{
35 convert::TryInto,
36 error::Error,
37 net::{IpAddr, SocketAddr},
38 num::NonZeroU32,
39 sync::Arc,
40 time::Duration,
41};
42use tide::{
43 http::StatusCode,
44 log::{debug, trace},
45 utils::async_trait,
46 Middleware, Next, Request, Response, Result,
47};
48
49lazy_static! {
50 static ref CLOCK: DefaultClock = DefaultClock::default();
51}
52
53/// Once the rate limit has been reached, the middleware will respond with
54/// status code 429 (too many requests) and a `Retry-After` header with the amount
55/// of time that needs to pass before another request will be allowed.
56#[derive(Debug, Clone)]
57pub struct GovernorMiddleware {
58 limiter: Arc<RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>>,
59}
60
61impl GovernorMiddleware {
62 /// Constructs a rate-limiting middleware from a [`Duration`] that allows one request in the given time interval.
63 ///
64 /// If the time interval is zero, returns `None`.
65 #[must_use]
66 pub fn with_period(duration: Duration) -> Option<Self> {
67 Some(Self {
68 limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::with_period(
69 duration,
70 )?)),
71 })
72 }
73
74 /// Constructs a rate-limiting middleware that allows a specified number of requests every second.
75 ///
76 /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
77 pub fn per_second<T>(times: T) -> Result<Self>
78 where
79 T: TryInto<NonZeroU32>,
80 T::Error: Error + Send + Sync + 'static,
81 {
82 Ok(Self {
83 limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::per_second(
84 times.try_into()?,
85 ))),
86 })
87 }
88
89 /// Constructs a rate-limiting middleware that allows a specified number of requests every minute.
90 ///
91 /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
92 pub fn per_minute<T>(times: T) -> Result<Self>
93 where
94 T: TryInto<NonZeroU32>,
95 T::Error: Error + Send + Sync + 'static,
96 {
97 Ok(Self {
98 limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::per_minute(
99 times.try_into()?,
100 ))),
101 })
102 }
103
104 /// Constructs a rate-limiting middleware that allows a specified number of requests every hour.
105 ///
106 /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
107 pub fn per_hour<T>(times: T) -> Result<Self>
108 where
109 T: TryInto<NonZeroU32>,
110 T::Error: Error + Send + Sync + 'static,
111 {
112 Ok(Self {
113 limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::per_hour(
114 times.try_into()?,
115 ))),
116 })
117 }
118}
119
120#[async_trait]
121impl<State: Clone + Send + Sync + 'static> Middleware<State> for GovernorMiddleware {
122 async fn handle(&self, req: Request<State>, next: Next<'_, State>) -> tide::Result {
123 let remote = req.remote().ok_or_else(|| {
124 tide::Error::from_str(
125 StatusCode::InternalServerError,
126 "failed to get request remote address",
127 )
128 })?;
129 let remote: IpAddr = match remote.parse::<SocketAddr>() {
130 Ok(r) => r.ip(),
131 Err(_) => remote.parse()?,
132 };
133 trace!("remote: {}", remote);
134
135 match self.limiter.check_key(&remote) {
136 Ok(_) => {
137 debug!("allowing remote {}", remote);
138 Ok(next.run(req).await)
139 }
140 Err(negative) => {
141 let wait_time = negative.wait_time_from(CLOCK.now());
142 let res = Response::builder(StatusCode::TooManyRequests)
143 .header(
144 tide::http::headers::RETRY_AFTER,
145 wait_time.as_secs().to_string(),
146 )
147 .build();
148 debug!(
149 "blocking address {} for {} seconds",
150 remote,
151 wait_time.as_secs()
152 );
153 Ok(res)
154 }
155 }
156 }
157}