tide_governor/
lib.rs

1//! A [tide] middleware that implements rate-limiting using [governor].
2//! # Example
3//! ```rust
4//! use tide_governor::GovernorMiddleware;
5//! use std::env;
6//!
7//! #[async_std::main]
8//! async fn main() -> tide::Result<()> {
9//!     let mut app = tide::new();
10//!     app.at("/")
11//!         .with(GovernorMiddleware::per_minute(4)?)
12//!         .get(|_| async move { todo!() });
13//!     app.at("/foo/:bar")
14//!         .with(GovernorMiddleware::per_hour(360)?)
15//!         .put(|_| async move { todo!() });
16//!
17//!     app.listen(format!("http://localhost:{}", env::var("PORT")?))
18//!         .await?;
19//!     Ok(())
20//! }
21//! ```
22//! [tide]: https://github.com/http-rs/tide
23//! [governor]: https://github.com/antifuchs/governor
24
25// TODO: figure out how to add jitter support using `governor::Jitter`.
26// TODO: add usage examples (both in the docs and in an examples directory).
27// TODO: add unit tests.
28use governor::{
29    clock::{Clock, DefaultClock},
30    state::keyed::DefaultKeyedStateStore,
31    Quota, RateLimiter,
32};
33use lazy_static::lazy_static;
34use std::{
35    convert::TryInto,
36    error::Error,
37    net::{IpAddr, SocketAddr},
38    num::NonZeroU32,
39    sync::Arc,
40    time::Duration,
41};
42use tide::{
43    http::StatusCode,
44    log::{debug, trace},
45    utils::async_trait,
46    Middleware, Next, Request, Response, Result,
47};
48
49lazy_static! {
50    static ref CLOCK: DefaultClock = DefaultClock::default();
51}
52
53/// Once the rate limit has been reached, the middleware will respond with
54/// status code 429 (too many requests) and a `Retry-After` header with the amount
55/// of time that needs to pass before another request will be allowed.
56#[derive(Debug, Clone)]
57pub struct GovernorMiddleware {
58    limiter: Arc<RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>>,
59}
60
61impl GovernorMiddleware {
62    /// Constructs a rate-limiting middleware from a [`Duration`] that allows one request in the given time interval.
63    ///
64    /// If the time interval is zero, returns `None`.
65    #[must_use]
66    pub fn with_period(duration: Duration) -> Option<Self> {
67        Some(Self {
68            limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::with_period(
69                duration,
70            )?)),
71        })
72    }
73
74    /// Constructs a rate-limiting middleware that allows a specified number of requests every second.
75    ///
76    /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
77    pub fn per_second<T>(times: T) -> Result<Self>
78    where
79        T: TryInto<NonZeroU32>,
80        T::Error: Error + Send + Sync + 'static,
81    {
82        Ok(Self {
83            limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::per_second(
84                times.try_into()?,
85            ))),
86        })
87    }
88
89    /// Constructs a rate-limiting middleware that allows a specified number of requests every minute.
90    ///
91    /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
92    pub fn per_minute<T>(times: T) -> Result<Self>
93    where
94        T: TryInto<NonZeroU32>,
95        T::Error: Error + Send + Sync + 'static,
96    {
97        Ok(Self {
98            limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::per_minute(
99                times.try_into()?,
100            ))),
101        })
102    }
103
104    /// Constructs a rate-limiting middleware that allows a specified number of requests every hour.
105    ///
106    /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
107    pub fn per_hour<T>(times: T) -> Result<Self>
108    where
109        T: TryInto<NonZeroU32>,
110        T::Error: Error + Send + Sync + 'static,
111    {
112        Ok(Self {
113            limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::per_hour(
114                times.try_into()?,
115            ))),
116        })
117    }
118}
119
120#[async_trait]
121impl<State: Clone + Send + Sync + 'static> Middleware<State> for GovernorMiddleware {
122    async fn handle(&self, req: Request<State>, next: Next<'_, State>) -> tide::Result {
123        let remote = req.remote().ok_or_else(|| {
124            tide::Error::from_str(
125                StatusCode::InternalServerError,
126                "failed to get request remote address",
127            )
128        })?;
129        let remote: IpAddr = match remote.parse::<SocketAddr>() {
130            Ok(r) => r.ip(),
131            Err(_) => remote.parse()?,
132        };
133        trace!("remote: {}", remote);
134
135        match self.limiter.check_key(&remote) {
136            Ok(_) => {
137                debug!("allowing remote {}", remote);
138                Ok(next.run(req).await)
139            }
140            Err(negative) => {
141                let wait_time = negative.wait_time_from(CLOCK.now());
142                let res = Response::builder(StatusCode::TooManyRequests)
143                    .header(
144                        tide::http::headers::RETRY_AFTER,
145                        wait_time.as_secs().to_string(),
146                    )
147                    .build();
148                debug!(
149                    "blocking address {} for {} seconds",
150                    remote,
151                    wait_time.as_secs()
152                );
153                Ok(res)
154            }
155        }
156    }
157}