Skip to main content

salvo_rate_limiter/
lib.rs

1//! Rate limiting middleware for Salvo.
2//!
3//! This middleware protects your server from abuse by limiting the number of
4//! requests a client can make within a specified time period. It's essential
5//! for preventing denial-of-service attacks and ensuring fair resource usage.
6//!
7//! # Key Components
8//!
9//! | Component | Purpose |
10//! |-----------|---------|
11//! | [`RateIssuer`] | Identifies clients (by IP, user ID, API key, etc.) |
12//! | [`QuotaGetter`] | Defines rate limits for each client |
13//! | [`RateGuard`] | Implements the limiting algorithm |
14//! | [`RateStore`] | Stores rate limit state |
15//!
16//! # Built-in Implementations
17//!
18//! ## Issuers
19//! - [`RemoteIpIssuer`]: Identifies clients by IP address
20//!
21//! ## Guards (Algorithms)
22//! - `FixedGuard`: Fixed window algorithm (requires `fixed-guard` feature)
23//! - `SlidingGuard`: Sliding window algorithm (requires `sliding-guard` feature)
24//!
25//! ## Stores
26//! - [`MokaStore`]: In-memory store backed by moka (requires `moka-store` feature)
27//!
28//! # Example
29//!
30//! Basic rate limiting by IP address:
31//!
32//! ```ignore
33//! use salvo_rate_limiter::{RateLimiter, RemoteIpIssuer, BasicQuota, FixedGuard, MokaStore};
34//! use salvo_core::prelude::*;
35//!
36//! let limiter = RateLimiter::new(
37//!     FixedGuard::default(),
38//!     MokaStore::default(),
39//!     RemoteIpIssuer,
40//!     BasicQuota::per_minute(100),  // 100 requests per minute
41//! );
42//!
43//! let router = Router::new()
44//!     .hoop(limiter)
45//!     .get(my_handler);
46//! ```
47//!
48//! # Custom Quotas Per User
49//!
50//! Different users can have different rate limits:
51//!
52//! ```ignore
53//! use salvo_rate_limiter::{QuotaGetter, BasicQuota};
54//!
55//! struct TieredQuota;
56//! impl QuotaGetter<String> for TieredQuota {
57//!     type Quota = BasicQuota;
58//!     type Error = salvo_core::Error;
59//!
60//!     async fn get<Q>(&self, user_id: &Q) -> Result<Self::Quota, Self::Error>
61//!     where
62//!         String: std::borrow::Borrow<Q>,
63//!         Q: std::hash::Hash + Eq + Sync,
64//!     {
65//!         // Premium users get higher limits
66//!         if is_premium_user(user_id) {
67//!             Ok(BasicQuota::per_minute(1000))
68//!         } else {
69//!             Ok(BasicQuota::per_minute(60))
70//!         }
71//!     }
72//! }
73//! ```
74//!
75//! # Response Headers
76//!
77//! Enable rate limit headers in responses with `.add_headers(true)`:
78//!
79//! - `X-RateLimit-Limit`: Maximum requests allowed
80//! - `X-RateLimit-Remaining`: Requests remaining in current window
81//! - `X-RateLimit-Reset`: Unix timestamp when the limit resets
82//!
83//! # HTTP Status
84//!
85//! When the limit is exceeded, returns `429 Too Many Requests`.
86//!
87//! Read more: <https://salvo.rs>
88#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
89#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
90#![cfg_attr(docsrs, feature(doc_cfg))]
91
92use std::borrow::Borrow;
93use std::error::Error as StdError;
94use std::fmt::{self, Debug, Formatter};
95use std::hash::Hash;
96use std::net::IpAddr;
97
98use salvo_core::handler::{Skipper, none_skipper};
99use salvo_core::http::{HeaderValue, Request, Response, StatusCode, StatusError};
100use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
101
102mod quota;
103pub use quota::{BasicQuota, CelledQuota, QuotaGetter};
104#[macro_use]
105mod cfg;
106
107cfg_feature! {
108    #![feature = "moka-store"]
109
110    mod moka_store;
111    pub use moka_store::MokaStore;
112}
113
114cfg_feature! {
115    #![feature = "fixed-guard"]
116
117    mod fixed_guard;
118    pub use fixed_guard::FixedGuard;
119}
120
121cfg_feature! {
122    #![feature = "sliding-guard"]
123
124    mod sliding_guard;
125    pub use sliding_guard::SlidingGuard;
126}
127
/// Identifies the caller a request should be rate limited as.
///
/// [`RateLimiter`] calls [`issue`](RateIssuer::issue) once for every request it does not
/// skip. The returned key selects both the quota (through the [`QuotaGetter`]) and the
/// stored guard state (through the [`RateStore`]). Returning `None` makes the limiter
/// reject the request with `400 Bad Request`.
///
/// Any `Fn(&mut Request, &Depot) -> Option<K>` function or closure also implements this trait.
pub trait RateIssuer: Send + Sync + 'static {
    /// Key identifying one rate-limited caller (for example an IP address or a user id).
    type Key: Hash + Eq + Send + Sync + 'static;
    /// Derives the key for `req`, or returns `None` if the request cannot be identified.
    fn issue(
        &self,
        req: &mut Request,
        depot: &Depot,
    ) -> impl Future<Output = Option<Self::Key>> + Send;
}
139impl<F, K> RateIssuer for F
140where
141    F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
142    K: Hash + Eq + Send + Sync + 'static,
143{
144    type Key = K;
145    async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
146        (self)(req, depot)
147    }
148}
149
150/// Identify user by IP address.
151#[derive(Debug)]
152pub struct RemoteIpIssuer;
153impl RateIssuer for RemoteIpIssuer {
154    type Key = IpAddr;
155    async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
156        req.remote_addr().ip()
157    }
158}
159
/// Rate-limiting algorithm that decides whether a request fits within its quota.
///
/// For each request the limiter loads a guard value for the request's key from the
/// [`RateStore`] and consults [`verify`](RateGuard::verify). It then saves the guard back,
/// so the guard carries the per-key state. `remaining`, `reset` and `limit` provide the
/// values for the optional `X-RateLimit-*` response headers.
pub trait RateGuard: Clone + Send + Sync + 'static {
    /// Quota type this guard enforces; the limiter requires it to equal the [`QuotaGetter`]'s quota.
    type Quota: Clone + Send + Sync + 'static;
    /// Checks whether the current request is within `quota`.
    ///
    /// Returns `true` if the request is allowed. Returns `false` if it exceeds the quota;
    /// the limiter then answers `429 Too Many Requests`. This takes `&mut self`, so
    /// implementations may record the request in the guard's state.
    fn verify(&mut self, quota: &Self::Quota) -> impl Future<Output = bool> + Send;

    /// Returns how many requests remain under `quota` (sent as `X-RateLimit-Remaining`).
    fn remaining(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;

    /// Returns when `quota` resets (sent as `X-RateLimit-Reset`; the crate-level docs
    /// describe this value as a Unix timestamp).
    fn reset(&self, quota: &Self::Quota) -> impl Future<Output = i64> + Send;

    /// Returns the maximum number of requests `quota` allows (sent as `X-RateLimit-Limit`).
    fn limit(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;
}
176
/// Persists per-key [`RateGuard`] state between requests.
pub trait RateStore: Send + Sync + 'static {
    /// Error returned by store operations.
    ///
    /// The limiter logs it using both `Debug` and `Display`. A failed
    /// [`load_guard`](RateStore::load_guard) is answered with `500 Internal Server Error`;
    /// a failed [`save_guard`](RateStore::save_guard) is only logged.
    type Error: StdError;
    /// Key under which guard state is stored; the limiter uses its issuer's key type here.
    type Key: Hash + Eq + Send + Clone + 'static;
    /// Guard type being stored.
    type Guard;
    /// Loads the guard saved for `key`.
    ///
    /// `refer` is the limiter's own prototype guard. Presumably stores use it as the
    /// starting state when nothing is saved for `key` yet; confirm against the store impl.
    fn load_guard<Q>(
        &self,
        key: &Q,
        refer: &Self::Guard,
    ) -> impl Future<Output = Result<Self::Guard, Self::Error>> + Send
    where
        Self::Key: Borrow<Q>,
        Q: Hash + Eq + Sync;
    /// Saves `guard` into the store as the current state for `key`.
    fn save_guard(
        &self,
        key: Self::Key,
        guard: Self::Guard,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
201
/// Middleware that limits how many requests each caller may make.
///
/// A limiter combines four parts:
/// - a [`RateGuard`] (the algorithm);
/// - a [`RateStore`] (per-caller state);
/// - a [`RateIssuer`] (who the caller is);
/// - a [`QuotaGetter`] (the caller's quota).
///
/// Build one with [`RateLimiter::new`] and attach it to a router with `hoop`.
pub struct RateLimiter<G, S, I, Q> {
    // Prototype guard passed to the store as the `refer` argument of `load_guard`.
    guard: G,
    // Where per-key guard state is loaded from and saved to.
    store: S,
    // Derives the rate-limit key from each request.
    issuer: I,
    // Looks up the quota that applies to a key.
    quota_getter: Q,
    // Whether to add `X-RateLimit-*` headers to responses.
    add_headers: bool,
    // Requests this skipper reports as skipped bypass rate limiting entirely.
    skipper: Box<dyn Skipper>,
}
211impl<G, S, I, Q> Debug for RateLimiter<G, S, I, Q>
212where
213    G: Debug,
214    S: Debug,
215    I: Debug,
216    Q: Debug,
217{
218    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
219        f.debug_struct("RateLimiter")
220            .field("guard", &self.guard)
221            .field("store", &self.store)
222            .field("issuer", &self.issuer)
223            .field("quota_getter", &self.quota_getter)
224            .field("add_headers", &self.add_headers)
225            .finish()
226    }
227}
228
229impl<G: RateGuard, S: RateStore, I: RateIssuer, P: QuotaGetter<I::Key>> RateLimiter<G, S, I, P> {
230    /// Create a new `RateLimiter`
231    #[inline]
232    #[must_use]
233    pub fn new(guard: G, store: S, issuer: I, quota_getter: P) -> Self {
234        Self {
235            guard,
236            store,
237            issuer,
238            quota_getter,
239            add_headers: false,
240            skipper: Box::new(none_skipper),
241        }
242    }
243
244    /// Sets skipper and returns new `RateLimiter`.
245    #[inline]
246    #[must_use]
247    pub fn with_skipper(mut self, skipper: impl Skipper) -> Self {
248        self.skipper = Box::new(skipper);
249        self
250    }
251
252    /// Sets `add_headers` and returns new `RateLimiter`.
253    /// If `add_headers` is true, the rate limit headers will be added to the response.
254    #[inline]
255    #[must_use]
256    pub fn add_headers(mut self, add_headers: bool) -> Self {
257        self.add_headers = add_headers;
258        self
259    }
260}
261
262#[async_trait]
263impl<G, S, I, P> Handler for RateLimiter<G, S, I, P>
264where
265    G: RateGuard<Quota = P::Quota>,
266    S: RateStore<Key = I::Key, Guard = G>,
267    P: QuotaGetter<I::Key>,
268    I: RateIssuer,
269{
270    async fn handle(
271        &self,
272        req: &mut Request,
273        depot: &mut Depot,
274        res: &mut Response,
275        ctrl: &mut FlowCtrl,
276    ) {
277        if self.skipper.skipped(req, depot) {
278            return;
279        }
280        let Some(key) = self.issuer.issue(req, depot).await else {
281            res.render(StatusError::bad_request().brief("Invalid identifier."));
282            ctrl.skip_rest();
283            return;
284        };
285        let quota = match self.quota_getter.get(&key).await {
286            Ok(quota) => quota,
287            Err(e) => {
288                tracing::error!(error = ?e, "RateLimiter error: {}", e);
289                res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
290                ctrl.skip_rest();
291                return;
292            }
293        };
294        let mut guard = match self.store.load_guard(&key, &self.guard).await {
295            Ok(guard) => guard,
296            Err(e) => {
297                tracing::error!(error = ?e, "RateLimiter error: {}", e);
298                res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
299                ctrl.skip_rest();
300                return;
301            }
302        };
303        let verified = guard.verify(&quota).await;
304
305        if self.add_headers {
306            res.headers_mut().insert(
307                "X-RateLimit-Limit",
308                HeaderValue::from_str(&guard.limit(&quota).await.to_string())
309                    .expect("Invalid header value"),
310            );
311            res.headers_mut().insert(
312                "X-RateLimit-Remaining",
313                HeaderValue::from_str(&(guard.remaining(&quota).await).to_string())
314                    .expect("Invalid header value"),
315            );
316            res.headers_mut().insert(
317                "X-RateLimit-Reset",
318                HeaderValue::from_str(&guard.reset(&quota).await.to_string())
319                    .expect("Invalid header value"),
320            );
321        }
322        if !verified {
323            res.status_code(StatusCode::TOO_MANY_REQUESTS);
324            ctrl.skip_rest();
325        }
326        if let Err(e) = self.store.save_guard(key, guard).await {
327            tracing::error!(error = ?e, "RateLimiter save guard failed");
328        }
329    }
330}
331
// The tests depend on `MokaStore` and on at least one feature-gated guard. Gate on those
// features so that builds which leave them off still compile (and produce no unused-item
// warnings).
#[cfg(all(
    test,
    feature = "moka-store",
    any(feature = "fixed-guard", feature = "sliding-guard")
))]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use salvo_core::Error;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// Identifies callers by the `user` query parameter.
    struct UserIssuer;
    impl RateIssuer for UserIssuer {
        type Key = String;
        async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
            req.query::<Self::Key>("user")
        }
    }

    #[handler]
    async fn limited() -> &'static str {
        "Limited page"
    }

    #[cfg(feature = "fixed-guard")]
    #[tokio::test]
    async fn test_fixed_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, BasicQuota>> = LazyLock::new(|| {
            let mut map = HashMap::new();
            map.insert("user1".into(), BasicQuota::per_second(1));
            map.insert("user2".into(), BasicQuota::set_seconds(1, 5));
            map
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = BasicQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                USER_QUOTAS
                    .get(key)
                    .cloned()
                    .ok_or_else(|| Error::other("user not found"))
            }
        }
        let limiter = RateLimiter::new(
            FixedGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        );
        let router = Router::new().push(Router::with_path("limited").hoop(limiter).get(limited));
        let service = Service::new(router);

        let mut response = TestClient::get("http://127.0.0.1:8698/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = TestClient::get("http://127.0.0.1:8698/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let mut response = TestClient::get("http://127.0.0.1:8698/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let mut response = TestClient::get("http://127.0.0.1:8698/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = TestClient::get("http://127.0.0.1:8698/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let response = TestClient::get("http://127.0.0.1:8698/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;

        let mut response = TestClient::get("http://127.0.0.1:8698/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");
    }

    #[cfg(feature = "sliding-guard")]
    #[tokio::test]
    async fn test_sliding_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, CelledQuota>> = LazyLock::new(|| {
            let mut map = HashMap::new();
            map.insert("user1".into(), CelledQuota::per_second(1, 1));
            map.insert("user2".into(), CelledQuota::set_seconds(1, 1, 5));
            map
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = CelledQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                USER_QUOTAS
                    .get(key)
                    .cloned()
                    .ok_or_else(|| Error::other("user not found"))
            }
        }
        let limiter = RateLimiter::new(
            SlidingGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        );
        let router = Router::new().push(Router::with_path("limited").hoop(limiter).get(limited));
        let service = Service::new(router);

        let mut response = TestClient::get("http://127.0.0.1:8698/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = TestClient::get("http://127.0.0.1:8698/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let mut response = TestClient::get("http://127.0.0.1:8698/limited?user=user1")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let mut response = TestClient::get("http://127.0.0.1:8698/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");

        let response = TestClient::get("http://127.0.0.1:8698/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let response = TestClient::get("http://127.0.0.1:8698/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));

        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;

        let mut response = TestClient::get("http://127.0.0.1:8698/limited?user=user2")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");
    }
}