salvo_rate_limiter/
lib.rs

1//! Rate limiter middleware for Salvo.
2//!
//! Rate limiter middleware limits the number of requests that a particular IP address
//! or identifier can make to the server within a time period.
5//!
//! [`RateIssuer`] issues the key that identifies each request; you can define your own
//! `RateIssuer`. If you just want to identify users by IP address, use [`RemoteIpIssuer`].
8//!
//! [`QuotaGetter`] returns the quota for each key, and [`RateStore`] persists the
//! per-key guard state between requests.
10//!
//! [`RateGuard`] is the strategy that verifies whether a request has exceeded its quota.
12//!
13//! Read more: <https://salvo.rs>
14#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
15#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
16#![cfg_attr(docsrs, feature(doc_cfg))]
17
18use std::borrow::Borrow;
19use std::error::Error as StdError;
20use std::fmt::{self, Debug, Formatter};
21use std::hash::Hash;
22
23use salvo_core::conn::SocketAddr;
24use salvo_core::handler::{Skipper, none_skipper};
25use salvo_core::http::{HeaderValue, Request, Response, StatusCode, StatusError};
26use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
27
28mod quota;
29pub use quota::{BasicQuota, CelledQuota, QuotaGetter};
30#[macro_use]
31mod cfg;
32
33cfg_feature! {
34    #![feature = "moka-store"]
35
36    mod moka_store;
37    pub use moka_store::MokaStore;
38}
39
40cfg_feature! {
41    #![feature = "fixed-guard"]
42
43    mod fixed_guard;
44    pub use fixed_guard::FixedGuard;
45}
46
47cfg_feature! {
48    #![feature = "sliding-guard"]
49
50    mod sliding_guard;
51    pub use sliding_guard::SlidingGuard;
52}
53
54/// Issuer is used to identify every request.
55pub trait RateIssuer: Send + Sync + 'static {
56    /// The key is used to identify the rate limit.
57    type Key: Hash + Eq + Send + Sync + 'static;
58    /// Issue a new key for the request.
59    fn issue(
60        &self,
61        req: &mut Request,
62        depot: &Depot,
63    ) -> impl Future<Output = Option<Self::Key>> + Send;
64}
65impl<F, K> RateIssuer for F
66where
67    F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
68    K: Hash + Eq + Send + Sync + 'static,
69{
70    type Key = K;
71    async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
72        (self)(req, depot)
73    }
74}
75
76/// Identify user by IP address.
77#[derive(Debug)]
78pub struct RemoteIpIssuer;
79impl RateIssuer for RemoteIpIssuer {
80    type Key = String;
81    async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
82        match req.remote_addr() {
83            SocketAddr::IPv4(addr) => Some(addr.ip().to_string()),
84            SocketAddr::IPv6(addr) => Some(addr.ip().to_string()),
85            _ => None,
86        }
87    }
88}
89
/// Strategy that decides whether a request still fits within its quota.
///
/// [`RateLimiter`] loads one guard per key from its [`RateStore`], calls
/// [`RateGuard::verify`] for the current request, and saves the guard back afterwards.
pub trait RateGuard: Clone + Send + Sync + 'static {
    /// Quota type this guard checks requests against.
    type Quota: Clone + Send + Sync + 'static;

    /// Checks the current request against `quota`.
    ///
    /// Returns `true` when the request is allowed and `false` when it exceeds the quota
    /// ([`RateLimiter`] answers `false` with `429 Too Many Requests`). Takes `&mut self`
    /// so implementations can update their counting state.
    fn verify(&mut self, quota: &Self::Quota) -> impl Future<Output = bool> + Send;

    /// Remaining quota; reported in the `X-RateLimit-Remaining` header.
    fn remaining(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;

    /// Reset time of the quota; reported in the `X-RateLimit-Reset` header.
    ///
    /// NOTE(review): the unit (e.g. seconds remaining vs. a timestamp) is defined by each
    /// implementation — confirm before relying on it.
    fn reset(&self, quota: &Self::Quota) -> impl Future<Output = i64> + Send;

    /// The quota limit; reported in the `X-RateLimit-Limit` header.
    fn limit(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;
}
106
/// Persistence layer for per-key guard state.
///
/// [`RateLimiter`] loads the guard for a key before verifying a request and saves the
/// (possibly updated) guard back afterwards.
pub trait RateStore: Send + Sync + 'static {
    /// Error reported when loading or saving fails.
    type Error: StdError;
    /// Key under which guards are stored.
    type Key: Hash + Eq + Send + Clone + 'static;
    /// Guard type kept by this store.
    type Guard;

    /// Loads the guard stored for `key`.
    ///
    /// `refer` is a reference guard supplied by the caller; [`RateLimiter`] passes its own
    /// configured guard. NOTE(review): presumably used as the starting state when nothing
    /// is stored for `key` yet — confirm against the store implementations.
    fn load_guard<Q>(
        &self,
        key: &Q,
        refer: &Self::Guard,
    ) -> impl Future<Output = Result<Self::Guard, Self::Error>> + Send
    where
        Self::Key: Borrow<Q>,
        Q: Hash + Eq + Sync;

    /// Stores `guard` under `key`.
    fn save_guard(
        &self,
        key: Self::Key,
        guard: Self::Guard,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
131
132/// `RateLimiter` is the main struct to used limit user request.
133pub struct RateLimiter<G, S, I, Q> {
134    guard: G,
135    store: S,
136    issuer: I,
137    quota_getter: Q,
138    add_headers: bool,
139    skipper: Box<dyn Skipper>,
140}
141impl<G, S, I, Q> Debug for RateLimiter<G, S, I, Q>
142where
143    G: Debug,
144    S: Debug,
145    I: Debug,
146    Q: Debug,
147{
148    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
149        f.debug_struct("RateLimiter")
150            .field("guard", &self.guard)
151            .field("store", &self.store)
152            .field("issuer", &self.issuer)
153            .field("quota_getter", &self.quota_getter)
154            .field("add_headers", &self.add_headers)
155            .finish()
156    }
157}
158
159impl<G: RateGuard, S: RateStore, I: RateIssuer, P: QuotaGetter<I::Key>> RateLimiter<G, S, I, P> {
160    /// Create a new `RateLimiter`
161    #[inline]
162    #[must_use]
163    pub fn new(guard: G, store: S, issuer: I, quota_getter: P) -> Self {
164        Self {
165            guard,
166            store,
167            issuer,
168            quota_getter,
169            add_headers: false,
170            skipper: Box::new(none_skipper),
171        }
172    }
173
174    /// Sets skipper and returns new `RateLimiter`.
175    #[inline]
176    #[must_use]
177    pub fn with_skipper(mut self, skipper: impl Skipper) -> Self {
178        self.skipper = Box::new(skipper);
179        self
180    }
181
182    /// Sets `add_headers` and returns new `RateLimiter`.
183    /// If `add_headers` is true, the rate limit headers will be added to the response.
184    #[inline]
185    #[must_use]
186    pub fn add_headers(mut self, add_headers: bool) -> Self {
187        self.add_headers = add_headers;
188        self
189    }
190}
191
192#[async_trait]
193impl<G, S, I, P> Handler for RateLimiter<G, S, I, P>
194where
195    G: RateGuard<Quota = P::Quota>,
196    S: RateStore<Key = I::Key, Guard = G>,
197    P: QuotaGetter<I::Key>,
198    I: RateIssuer,
199{
200    async fn handle(
201        &self,
202        req: &mut Request,
203        depot: &mut Depot,
204        res: &mut Response,
205        ctrl: &mut FlowCtrl,
206    ) {
207        if self.skipper.skipped(req, depot) {
208            return;
209        }
210        let Some(key) = self.issuer.issue(req, depot).await else {
211            res.render(StatusError::bad_request().brief("Invalid identifier."));
212            ctrl.skip_rest();
213            return;
214        };
215        let quota = match self.quota_getter.get(&key).await {
216            Ok(quota) => quota,
217            Err(e) => {
218                tracing::error!(error = ?e, "RateLimiter error: {}", e);
219                res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
220                ctrl.skip_rest();
221                return;
222            }
223        };
224        let mut guard = match self.store.load_guard(&key, &self.guard).await {
225            Ok(guard) => guard,
226            Err(e) => {
227                tracing::error!(error = ?e, "RateLimiter error: {}", e);
228                res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
229                ctrl.skip_rest();
230                return;
231            }
232        };
233        let verified = guard.verify(&quota).await;
234
235        if self.add_headers {
236            res.headers_mut().insert(
237                "X-RateLimit-Limit",
238                HeaderValue::from_str(&guard.limit(&quota).await.to_string())
239                    .expect("Invalid header value"),
240            );
241            res.headers_mut().insert(
242                "X-RateLimit-Remaining",
243                HeaderValue::from_str(&(guard.remaining(&quota).await).to_string())
244                    .expect("Invalid header value"),
245            );
246            res.headers_mut().insert(
247                "X-RateLimit-Reset",
248                HeaderValue::from_str(&guard.reset(&quota).await.to_string())
249                    .expect("Invalid header value"),
250            );
251        }
252        if !verified {
253            res.status_code(StatusCode::TOO_MANY_REQUESTS);
254            ctrl.skip_rest();
255        }
256        if let Err(e) = self.store.save_guard(key, guard).await {
257            tracing::error!(error = ?e, "RateLimiter save guard failed");
258        }
259    }
260}
261
// These tests build limiters from the optional `MokaStore`, `FixedGuard` and `SlidingGuard`
// types, which only exist when their cargo features are enabled. Gating the module (and each
// test) on those features keeps `cargo test` building under any feature selection.
#[cfg(all(
    test,
    feature = "moka-store",
    any(feature = "fixed-guard", feature = "sliding-guard")
))]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use salvo_core::Error;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// Test issuer that keys each request by its `user` query parameter.
    struct UserIssuer;
    impl RateIssuer for UserIssuer {
        type Key = String;
        async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
            req.query::<Self::Key>("user")
        }
    }

    #[handler]
    async fn limited() -> &'static str {
        "Limited page"
    }

    /// Sends `GET /limited?user={user}` to `service` and returns the response.
    async fn get_as(service: &Service, user: &str) -> Response {
        TestClient::get(format!("http://127.0.0.1:8698/limited?user={user}"))
            .send(service)
            .await
    }

    #[cfg(feature = "fixed-guard")]
    #[tokio::test]
    async fn test_fixed_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, BasicQuota>> = LazyLock::new(|| {
            let mut map = HashMap::new();
            map.insert("user1".into(), BasicQuota::per_second(1));
            map.insert("user2".into(), BasicQuota::set_seconds(1, 5));
            map
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = BasicQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                USER_QUOTAS
                    .get(key)
                    .cloned()
                    .ok_or_else(|| Error::other("user not found"))
            }
        }
        let limiter = RateLimiter::new(
            FixedGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        );
        let router = Router::new().push(Router::with_path("limited").hoop(limiter).get(limited));
        let service = Service::new(router);

        // user1: one request per second.
        let mut response = get_as(&service, "user1").await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = get_as(&service, "user1").await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let mut response = get_as(&service, "user1").await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        // user2: one request per 5-second window.
        let mut response = get_as(&service, "user2").await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = get_as(&service, "user2").await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let response = get_as(&service, "user2").await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;

        let mut response = get_as(&service, "user2").await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");
    }

    #[cfg(feature = "sliding-guard")]
    #[tokio::test]
    async fn test_sliding_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, CelledQuota>> = LazyLock::new(|| {
            let mut map = HashMap::new();
            map.insert("user1".into(), CelledQuota::per_second(1, 1));
            map.insert("user2".into(), CelledQuota::set_seconds(1, 1, 5));
            map
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = CelledQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                USER_QUOTAS
                    .get(key)
                    .cloned()
                    .ok_or_else(|| Error::other("user not found"))
            }
        }
        let limiter = RateLimiter::new(
            SlidingGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        );
        let router = Router::new().push(Router::with_path("limited").hoop(limiter).get(limited));
        let service = Service::new(router);

        // user1: one request per second.
        let mut response = get_as(&service, "user1").await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = get_as(&service, "user1").await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let mut response = get_as(&service, "user1").await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        // user2: one request per 5-second window.
        let mut response = get_as(&service, "user2").await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = get_as(&service, "user2").await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let response = get_as(&service, "user2").await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;

        let mut response = get_as(&service, "user2").await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");
    }
}