surf_governor/
lib.rs

1#![forbid(unsafe_code, future_incompatible)]
2#![deny(
3    missing_docs,
4    missing_debug_implementations,
5    missing_copy_implementations,
6    nonstandard_style,
7    unused_qualifications,
8    unused_import_braces,
9    unused_extern_crates,
10    trivial_casts,
11    trivial_numeric_casts
12)]
13#![cfg_attr(docsrs, feature(doc_cfg))]
//! A [surf] middleware that implements client-side rate-limiting using [governor].
//! Most of this crate is adapted from [tide-governor](https://github.com/ohmree/tide-governor).
//!
16//! # Example
17//! ```no_run
18//! use surf_governor::GovernorMiddleware;
19//! use surf::{Client, Request, http::Method};
20//! use url::Url;
21//!
22//! #[async_std::main]
23//! async fn main() -> surf::Result<()> {
24//!     let req = Request::new(Method::Get, Url::parse("https://example.api")?);
25//!     // Construct Surf client with a governor
26//!     let client = Client::new().with(GovernorMiddleware::per_second(30)?);
27//!     let res = client.send(req).await?;
28//!     Ok(())
29//! }
30//! ```
31//! [surf]: https://github.com/http-rs/surf
32//! [governor]: https://github.com/antifuchs/governor
33
34// TODO: figure out how to add jitter support using `governor::Jitter`.
35// TODO: add more unit tests.
36use governor::{
37    clock::{Clock, DefaultClock},
38    state::keyed::DefaultKeyedStateStore,
39    Quota, RateLimiter,
40};
41use http_types::{headers, Response, StatusCode};
42use lazy_static::lazy_static;
43use std::{convert::TryInto, error::Error, num::NonZeroU32, sync::Arc, time::Duration};
44use surf::{middleware::Next, Client, Request, Result};
45
46lazy_static! {
47    static ref CLOCK: DefaultClock = DefaultClock::default();
48}
49
50/// Once the rate limit has been reached, the middleware will respond with
51/// status code 429 (too many requests) and a `Retry-After` header with the amount
52/// of time that needs to pass before another request will be allowed.
53#[derive(Debug, Clone)]
54pub struct GovernorMiddleware {
55    limiter: Arc<RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>>,
56}
57
58impl GovernorMiddleware {
59    /// Constructs a rate-limiting middleware from a [`Duration`] that allows one request in the given time interval.
60    ///
61    /// If the time interval is zero, returns `None`.
62    /// # Example
63    /// This constructs a client with a governor set to 1 requests every 5 nanoseconds
64    /// ```no_run
65    /// use surf_governor::GovernorMiddleware;
66    /// use surf::{Client, Request, http::Method};
67    /// use url::Url;
68    ///
69    /// use std::time::Duration;
70    ///
71    /// #[async_std::main]
72    /// async fn main() -> surf::Result<()> {
73    ///     let req = Request::new(Method::Get, Url::parse("https://example.api")?);
74    ///     // Construct Surf client with a governor
75    ///     let client = Client::new().with(GovernorMiddleware::with_period(Duration::new(0, 5)).unwrap());
76    ///     let res = client.send(req).await?;
77    ///     Ok(())
78    /// }
79    /// ```
80    #[must_use]
81    pub fn with_period(duration: Duration) -> Option<Self> {
82        Some(Self {
83            limiter: Arc::new(RateLimiter::<String, _, _>::keyed(Quota::with_period(
84                duration,
85            )?)),
86        })
87    }
88
89    /// Constructs a rate-limiting middleware that allows a specified number of requests every second.
90    ///
91    /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
92    ///
93    /// # Example
94    /// This constructs a client with a governor set to 30 requests per second limit
95    /// ```no_run
96    /// use surf_governor::GovernorMiddleware;
97    /// use surf::{Client, Request, http::Method};
98    /// use url::Url;
99    ///
100    /// #[async_std::main]
101    /// async fn main() -> surf::Result<()> {
102    ///     let req = Request::new(Method::Get, Url::parse("https://example.api")?);
103    ///     // Construct Surf client with a governor
104    ///     let client = Client::new().with(GovernorMiddleware::per_second(30)?);
105    ///     let res = client.send(req).await?;
106    ///     Ok(())
107    /// }
108    /// ```
109    pub fn per_second<T>(times: T) -> Result<Self>
110    where
111        T: TryInto<NonZeroU32>,
112        T::Error: Error + Send + Sync + 'static,
113    {
114        Ok(Self {
115            limiter: Arc::new(RateLimiter::<String, _, _>::keyed(Quota::per_second(
116                times.try_into()?,
117            ))),
118        })
119    }
120
121    /// Constructs a rate-limiting middleware that allows a specified number of requests every minute.
122    ///
123    /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
124    ///
125    /// # Example
126    /// This constructs a client with a governor set to 300 requests per minute limit
127    /// ```no_run
128    /// use surf_governor::GovernorMiddleware;
129    /// use surf::{Client, Request, http::Method};
130    /// use url::Url;
131    ///
132    /// #[async_std::main]
133    /// async fn main() -> surf::Result<()> {
134    ///     let req = Request::new(Method::Get, Url::parse("https://example.api")?);
135    ///     // Construct Surf client with a governor
136    ///     let client = Client::new().with(GovernorMiddleware::per_minute(300)?);
137    ///     let res = client.send(req).await?;
138    ///     Ok(())
139    /// }
140    /// ```
141    pub fn per_minute<T>(times: T) -> Result<Self>
142    where
143        T: TryInto<NonZeroU32>,
144        T::Error: Error + Send + Sync + 'static,
145    {
146        Ok(Self {
147            limiter: Arc::new(RateLimiter::<String, _, _>::keyed(Quota::per_minute(
148                times.try_into()?,
149            ))),
150        })
151    }
152
153    /// Constructs a rate-limiting middleware that allows a specified number of requests every hour.
154    ///
155    /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
156    ///
157    /// # Example
158    /// This constructs a client with a governor set to 3000 requests per hour limit
159    /// ```no_run
160    /// use surf_governor::GovernorMiddleware;
161    /// use surf::{Client, Request, http::Method};
162    /// use url::Url;
163    ///
164    /// #[async_std::main]
165    /// async fn main() -> surf::Result<()> {
166    ///     let req = Request::new(Method::Get, Url::parse("https://example.api")?);
167    ///     // Construct Surf client with a governor
168    ///     let client = Client::new().with(GovernorMiddleware::per_hour(3000)?);
169    ///     let res = client.send(req).await?;
170    ///     Ok(())
171    /// }
172    /// ```
173    pub fn per_hour<T>(times: T) -> Result<Self>
174    where
175        T: TryInto<NonZeroU32>,
176        T::Error: Error + Send + Sync + 'static,
177    {
178        Ok(Self {
179            limiter: Arc::new(RateLimiter::<String, _, _>::keyed(Quota::per_hour(
180                times.try_into()?,
181            ))),
182        })
183    }
184}
185
186#[surf::utils::async_trait]
187impl surf::middleware::Middleware for GovernorMiddleware {
188    async fn handle(
189        &self,
190        req: Request,
191        client: Client,
192        next: Next<'_>,
193    ) -> std::result::Result<surf::Response, http_types::Error> {
194        match self
195            .limiter
196            .check_key(&req.url().host_str().unwrap().to_string())
197        {
198            Ok(_) => Ok(next.run(req, client).await?),
199            Err(negative) => {
200                let wait_time = negative.wait_time_from(CLOCK.now());
201                let mut res = Response::new(StatusCode::TooManyRequests);
202                res.insert_header(headers::RETRY_AFTER, wait_time.as_secs().to_string());
203                Ok(res.try_into()?)
204            }
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use crate::GovernorMiddleware;
    use std::time::Duration;
    use surf::{http::Method, Client, Request};
    use url::Url;
    use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};

    /// The second request within the same second to the same host is rejected
    /// locally with 429, and only the first one reaches the server.
    #[async_std::test]
    async fn limits_requests() -> surf::Result<()> {
        let mock_server = MockServer::start().await;
        let m = Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_string("Hello!".to_string()))
            .expect(1);
        let _mock_guard = mock_server.register_as_scoped(m).await;
        let url = format!("{}/", &mock_server.uri());
        let req = Request::new(Method::Get, Url::parse(&url).unwrap());
        let client = Client::new().with(GovernorMiddleware::per_second(1)?);
        let good_res = client.send(req.clone()).await?;
        assert_eq!(good_res.status(), 200);
        let wait_res = client.send(req).await?;
        assert_eq!(wait_res.status(), 429);
        Ok(())
    }

    /// A zero-length period cannot express a quota and must be rejected.
    #[test]
    fn with_period_rejects_zero_duration() {
        assert!(GovernorMiddleware::with_period(Duration::from_secs(0)).is_none());
        assert!(GovernorMiddleware::with_period(Duration::from_secs(1)).is_some());
    }

    /// A quota of zero requests is not a valid `NonZeroU32` and must be rejected.
    #[test]
    fn per_unit_constructors_reject_zero() {
        assert!(GovernorMiddleware::per_second(0u32).is_err());
        assert!(GovernorMiddleware::per_minute(0u32).is_err());
        assert!(GovernorMiddleware::per_hour(0u32).is_err());
        assert!(GovernorMiddleware::per_second(1u32).is_ok());
    }
}