surf_governor/lib.rs
1#![forbid(unsafe_code, future_incompatible)]
2#![deny(
3 missing_docs,
4 missing_debug_implementations,
5 missing_copy_implementations,
6 nonstandard_style,
7 unused_qualifications,
8 unused_import_braces,
9 unused_extern_crates,
10 trivial_casts,
11 trivial_numeric_casts
12)]
13#![cfg_attr(docsrs, feature(doc_cfg))]
14//! A [surf] middleware that implements rate-limiting using [governor].
15//! The majority of this has been copied from [tide-governor](https://github.com/ohmree/tide-governor)
16//! # Example
17//! ```no_run
18//! use surf_governor::GovernorMiddleware;
19//! use surf::{Client, Request, http::Method};
20//! use url::Url;
21//!
22//! #[async_std::main]
23//! async fn main() -> surf::Result<()> {
24//! let req = Request::new(Method::Get, Url::parse("https://example.api")?);
25//! // Construct Surf client with a governor
26//! let client = Client::new().with(GovernorMiddleware::per_second(30)?);
27//! let res = client.send(req).await?;
28//! Ok(())
29//! }
30//! ```
31//! [surf]: https://github.com/http-rs/surf
32//! [governor]: https://github.com/antifuchs/governor
33
34// TODO: figure out how to add jitter support using `governor::Jitter`.
35// TODO: add more unit tests.
36use governor::{
37 clock::{Clock, DefaultClock},
38 state::keyed::DefaultKeyedStateStore,
39 Quota, RateLimiter,
40};
41use http_types::{headers, Response, StatusCode};
42use lazy_static::lazy_static;
43use std::{convert::TryInto, error::Error, num::NonZeroU32, sync::Arc, time::Duration};
44use surf::{middleware::Next, Client, Request, Result};
45
46lazy_static! {
47 static ref CLOCK: DefaultClock = DefaultClock::default();
48}
49
50/// Once the rate limit has been reached, the middleware will respond with
51/// status code 429 (too many requests) and a `Retry-After` header with the amount
52/// of time that needs to pass before another request will be allowed.
53#[derive(Debug, Clone)]
54pub struct GovernorMiddleware {
55 limiter: Arc<RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>>,
56}
57
58impl GovernorMiddleware {
59 /// Constructs a rate-limiting middleware from a [`Duration`] that allows one request in the given time interval.
60 ///
61 /// If the time interval is zero, returns `None`.
62 /// # Example
63 /// This constructs a client with a governor set to 1 requests every 5 nanoseconds
64 /// ```no_run
65 /// use surf_governor::GovernorMiddleware;
66 /// use surf::{Client, Request, http::Method};
67 /// use url::Url;
68 ///
69 /// use std::time::Duration;
70 ///
71 /// #[async_std::main]
72 /// async fn main() -> surf::Result<()> {
73 /// let req = Request::new(Method::Get, Url::parse("https://example.api")?);
74 /// // Construct Surf client with a governor
75 /// let client = Client::new().with(GovernorMiddleware::with_period(Duration::new(0, 5)).unwrap());
76 /// let res = client.send(req).await?;
77 /// Ok(())
78 /// }
79 /// ```
80 #[must_use]
81 pub fn with_period(duration: Duration) -> Option<Self> {
82 Some(Self {
83 limiter: Arc::new(RateLimiter::<String, _, _>::keyed(Quota::with_period(
84 duration,
85 )?)),
86 })
87 }
88
89 /// Constructs a rate-limiting middleware that allows a specified number of requests every second.
90 ///
91 /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
92 ///
93 /// # Example
94 /// This constructs a client with a governor set to 30 requests per second limit
95 /// ```no_run
96 /// use surf_governor::GovernorMiddleware;
97 /// use surf::{Client, Request, http::Method};
98 /// use url::Url;
99 ///
100 /// #[async_std::main]
101 /// async fn main() -> surf::Result<()> {
102 /// let req = Request::new(Method::Get, Url::parse("https://example.api")?);
103 /// // Construct Surf client with a governor
104 /// let client = Client::new().with(GovernorMiddleware::per_second(30)?);
105 /// let res = client.send(req).await?;
106 /// Ok(())
107 /// }
108 /// ```
109 pub fn per_second<T>(times: T) -> Result<Self>
110 where
111 T: TryInto<NonZeroU32>,
112 T::Error: Error + Send + Sync + 'static,
113 {
114 Ok(Self {
115 limiter: Arc::new(RateLimiter::<String, _, _>::keyed(Quota::per_second(
116 times.try_into()?,
117 ))),
118 })
119 }
120
121 /// Constructs a rate-limiting middleware that allows a specified number of requests every minute.
122 ///
123 /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
124 ///
125 /// # Example
126 /// This constructs a client with a governor set to 300 requests per minute limit
127 /// ```no_run
128 /// use surf_governor::GovernorMiddleware;
129 /// use surf::{Client, Request, http::Method};
130 /// use url::Url;
131 ///
132 /// #[async_std::main]
133 /// async fn main() -> surf::Result<()> {
134 /// let req = Request::new(Method::Get, Url::parse("https://example.api")?);
135 /// // Construct Surf client with a governor
136 /// let client = Client::new().with(GovernorMiddleware::per_minute(300)?);
137 /// let res = client.send(req).await?;
138 /// Ok(())
139 /// }
140 /// ```
141 pub fn per_minute<T>(times: T) -> Result<Self>
142 where
143 T: TryInto<NonZeroU32>,
144 T::Error: Error + Send + Sync + 'static,
145 {
146 Ok(Self {
147 limiter: Arc::new(RateLimiter::<String, _, _>::keyed(Quota::per_minute(
148 times.try_into()?,
149 ))),
150 })
151 }
152
153 /// Constructs a rate-limiting middleware that allows a specified number of requests every hour.
154 ///
155 /// Returns an error if `times` can't be converted into a [`NonZeroU32`].
156 ///
157 /// # Example
158 /// This constructs a client with a governor set to 3000 requests per hour limit
159 /// ```no_run
160 /// use surf_governor::GovernorMiddleware;
161 /// use surf::{Client, Request, http::Method};
162 /// use url::Url;
163 ///
164 /// #[async_std::main]
165 /// async fn main() -> surf::Result<()> {
166 /// let req = Request::new(Method::Get, Url::parse("https://example.api")?);
167 /// // Construct Surf client with a governor
168 /// let client = Client::new().with(GovernorMiddleware::per_hour(3000)?);
169 /// let res = client.send(req).await?;
170 /// Ok(())
171 /// }
172 /// ```
173 pub fn per_hour<T>(times: T) -> Result<Self>
174 where
175 T: TryInto<NonZeroU32>,
176 T::Error: Error + Send + Sync + 'static,
177 {
178 Ok(Self {
179 limiter: Arc::new(RateLimiter::<String, _, _>::keyed(Quota::per_hour(
180 times.try_into()?,
181 ))),
182 })
183 }
184}
185
186#[surf::utils::async_trait]
187impl surf::middleware::Middleware for GovernorMiddleware {
188 async fn handle(
189 &self,
190 req: Request,
191 client: Client,
192 next: Next<'_>,
193 ) -> std::result::Result<surf::Response, http_types::Error> {
194 match self
195 .limiter
196 .check_key(&req.url().host_str().unwrap().to_string())
197 {
198 Ok(_) => Ok(next.run(req, client).await?),
199 Err(negative) => {
200 let wait_time = negative.wait_time_from(CLOCK.now());
201 let mut res = Response::new(StatusCode::TooManyRequests);
202 res.insert_header(headers::RETRY_AFTER, wait_time.as_secs().to_string());
203 Ok(res.try_into()?)
204 }
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use crate::GovernorMiddleware;
    use surf::{http::Method, Client, Request};
    use url::Url;
    use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};

    /// With a limit of one request per second, the first request reaches the
    /// mock server and the immediate second one is answered with a 429.
    #[async_std::test]
    async fn limits_requests() -> surf::Result<()> {
        let server = MockServer::start().await;
        // The mock expects exactly one GET; the scoped guard is kept alive
        // until the end of the test.
        let _guard = server
            .register_as_scoped(
                Mock::given(method("GET"))
                    .respond_with(ResponseTemplate::new(200).set_body_string("Hello!".to_string()))
                    .expect(1),
            )
            .await;

        let target = Url::parse(&format!("{}/", &server.uri())).unwrap();
        let request = Request::new(Method::Get, target);
        let client = Client::new().with(GovernorMiddleware::per_second(1)?);

        let first = client.send(request.clone()).await?;
        assert_eq!(first.status(), 200);

        let second = client.send(request).await?;
        assert_eq!(second.status(), 429);

        Ok(())
    }
}