1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
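//! A [tide] middleware that implements rate-limiting using [governor].
//!
//! Each client is rate-limited independently, keyed by the IP address of the request's remote
//! address (as reported by tide's `Request::remote`). Once a client exceeds its quota, it
//! receives a `429 Too Many Requests` response carrying a `Retry-After` header.
//!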
//! # Example
//! ```rust
//! use tide_governor::GovernorMiddleware;
//! use std::env;
//!
//! #[async_std::main]
//! async fn main() -> tide::Result<()> {
//!     let mut app = tide::new();
//!     app.at("/")
//!         .with(GovernorMiddleware::per_minute(4)?)
//!         .get(|_| async move { todo!() });
//!     app.at("/foo/:bar")
//!         .with(GovernorMiddleware::per_hour(360)?)
//!         .put(|_| async move { todo!() });
//!
//!     app.listen(format!("http://localhost:{}", env::var("PORT")?))
//!         .await?;
//!     Ok(())
//! }
//! ```
//! [tide]: https://github.com/http-rs/tide
//! [governor]: https://github.com/antifuchs/governor

// TODO: figure out how to add jitter support using `governor::Jitter`.
// TODO: add usage examples (both in the docs and in an examples directory).
// TODO: add unit tests.
use governor::{
    clock::{Clock, DefaultClock},
    state::keyed::DefaultKeyedStateStore,
    Quota, RateLimiter,
};
use lazy_static::lazy_static;
use std::{
    convert::TryInto,
    error::Error,
    net::{IpAddr, SocketAddr},
    num::NonZeroU32,
    sync::Arc,
    time::Duration,
};
use tide::{
    http::StatusCode,
    log::{debug, trace},
    utils::async_trait,
    Middleware, Next, Request, Response, Result,
};

lazy_static! {
    /// Process-wide clock used to compute how long a rate-limited client has to wait
    /// before its next request would be accepted.
    static ref CLOCK: DefaultClock = Default::default();
}

/// Tide middleware that rate-limits requests per client IP address.
///
/// Every remote IP address gets its own quota, tracked by a keyed `governor` rate limiter.
/// Once the rate limit has been reached, the middleware responds with status code 429
/// (too many requests) and a `Retry-After` header holding the number of seconds that need
/// to pass before another request will be allowed.
///
/// Cloning is cheap: clones share the same underlying limiter state through an [`Arc`].
#[derive(Debug, Clone)]
pub struct GovernorMiddleware {
    /// Keyed limiter holding one rate-limiting state per remote IP address.
    limiter: Arc<RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>>,
}

impl GovernorMiddleware {
    /// Constructs a rate-limiting middleware from a [`Duration`] that allows one request in the given time interval.
    ///
    /// If the time interval is zero, returns `None`.
    #[must_use]
    pub fn with_period(duration: Duration) -> Option<Self> {
        Some(Self {
            limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::with_period(
                duration,
            )?)),
        })
    }

    /// Constructs a rate-limiting middleware that allows a specified number of requests every second.
    ///
    /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
    pub fn per_second<T>(times: T) -> Result<Self>
    where
        T: TryInto<NonZeroU32>,
        T::Error: Error + Send + Sync + 'static,
    {
        Ok(Self {
            limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::per_second(
                times.try_into()?,
            ))),
        })
    }

    /// Constructs a rate-limiting middleware that allows a specified number of requests every minute.
    ///
    /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
    pub fn per_minute<T>(times: T) -> Result<Self>
    where
        T: TryInto<NonZeroU32>,
        T::Error: Error + Send + Sync + 'static,
    {
        Ok(Self {
            limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::per_minute(
                times.try_into()?,
            ))),
        })
    }

    /// Constructs a rate-limiting middleware that allows a specified number of requests every hour.
    ///
    /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
    pub fn per_hour<T>(times: T) -> Result<Self>
    where
        T: TryInto<NonZeroU32>,
        T::Error: Error + Send + Sync + 'static,
    {
        Ok(Self {
            limiter: Arc::new(RateLimiter::<IpAddr, _, _>::keyed(Quota::per_hour(
                times.try_into()?,
            ))),
        })
    }
}

#[async_trait]
impl<State: Clone + Send + Sync + 'static> Middleware<State> for GovernorMiddleware {
    async fn handle(&self, req: Request<State>, next: Next<'_, State>) -> tide::Result {
        let remote: SocketAddr = req
            .remote()
            .ok_or_else(|| {
                tide::Error::from_str(
                    StatusCode::InternalServerError,
                    "failed to get request remote address",
                )
            })?
            .parse()?;
        trace!("remote: {}", remote);
        match self.limiter.check_key(&remote.ip()) {
            Ok(_) => {
                debug!("allowing remote {}", remote);
                Ok(next.run(req).await)
            }
            Err(negative) => {
                let wait_time = negative.wait_time_from(CLOCK.now());
                let res = Response::builder(StatusCode::TooManyRequests)
                    .header(
                        tide::http::headers::RETRY_AFTER,
                        wait_time.as_secs().to_string(),
                    )
                    .build();
                debug!(
                    "blocking address {} for {} seconds",
                    remote.ip(),
                    wait_time.as_secs()
                );
                Ok(res)
            }
        }
    }
}