salvo_rate_limiter/
lib.rs

1//! Rate limiter middleware for Salvo.
2//!
//! Rate limiter middleware limits the number of requests sent to the server
//! from a particular IP or id within a time period.
5//!
//! [`RateIssuer`] is used to issue a key for each request; you can define your own custom `RateIssuer`.
//! If you just want to identify users by IP address, you can use [`RemoteIpIssuer`].
8//!
9//! [`QuotaGetter`] is used to get quota for every key.
10//!
//! [`RateGuard`] is the strategy used to verify whether a request has exceeded its quota.
12//!
13//! Read more: <https://salvo.rs>
14#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
15#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
16#![cfg_attr(docsrs, feature(doc_cfg))]
17
18use std::borrow::Borrow;
19use std::error::Error as StdError;
20use std::hash::Hash;
21
22use salvo_core::conn::SocketAddr;
23use salvo_core::handler::{Skipper, none_skipper};
24use salvo_core::http::{HeaderValue, Request, Response, StatusCode, StatusError};
25use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
26
27mod quota;
28pub use quota::{BasicQuota, CelledQuota, QuotaGetter};
29#[macro_use]
30mod cfg;
31
32cfg_feature! {
33    #![feature = "moka-store"]
34
35    mod moka_store;
36    pub use moka_store::MokaStore;
37}
38
39cfg_feature! {
40    #![feature = "fixed-guard"]
41
42    mod fixed_guard;
43    pub use fixed_guard::FixedGuard;
44}
45
46cfg_feature! {
47    #![feature = "sliding-guard"]
48
49    mod sliding_guard;
50    pub use sliding_guard::SlidingGuard;
51}
52
53/// Issuer is used to identify every request.
54pub trait RateIssuer: Send + Sync + 'static {
55    /// The key is used to identify the rate limit.
56    type Key: Hash + Eq + Send + Sync + 'static;
57    /// Issue a new key for the request.
58    fn issue(
59        &self,
60        req: &mut Request,
61        depot: &Depot,
62    ) -> impl Future<Output = Option<Self::Key>> + Send;
63}
64impl<F, K> RateIssuer for F
65where
66    F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
67    K: Hash + Eq + Send + Sync + 'static,
68{
69    type Key = K;
70    async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
71        (self)(req, depot)
72    }
73}
74
75/// Identify user by IP address.
76pub struct RemoteIpIssuer;
77impl RateIssuer for RemoteIpIssuer {
78    type Key = String;
79    async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
80        match req.remote_addr() {
81            SocketAddr::IPv4(addr) => Some(addr.ip().to_string()),
82            SocketAddr::IPv6(addr) => Some(addr.ip().to_string()),
83            _ => None,
84        }
85    }
86}
87
/// Strategy that decides whether a request still fits within its quota.
///
/// [`RateLimiter`] loads one guard per key from the [`RateStore`], calls
/// [`verify`](RateGuard::verify) on it and then saves it back.
pub trait RateGuard: Clone + Send + Sync + 'static {
    /// Quota description this guard is checked against.
    type Quota: Clone + Send + Sync + 'static;

    /// Checks the current request against `quota`.
    ///
    /// Resolves to `true` when the request may proceed. [`RateLimiter`] responds with
    /// `429 Too Many Requests` when it resolves to `false`.
    fn verify(&mut self, quota: &Self::Quota) -> impl Future<Output = bool> + Send;

    /// Remaining allowance under `quota`; reported as `X-RateLimit-Remaining` when headers
    /// are enabled.
    fn remaining(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;

    /// Reset time under `quota`; reported as `X-RateLimit-Reset` when headers are enabled.
    /// The exact unit is defined by the implementation.
    fn reset(&self, quota: &Self::Quota) -> impl Future<Output = i64> + Send;

    /// Limit under `quota`; reported as `X-RateLimit-Limit` when headers are enabled.
    fn limit(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;
}
104
/// Storage backend that keeps one guard per key between requests.
pub trait RateStore: Send + Sync + 'static {
    /// Error reported when loading or saving a guard fails.
    type Error: StdError;
    /// Key under which a guard is stored.
    type Key: Hash + Eq + Send + Clone + 'static;
    /// Guard type held by the store.
    type Guard;

    /// Loads the guard stored for `key`.
    ///
    /// `refer` is the limiter's own guard instance, passed in as a reference value.
    // NOTE(review): presumably `refer` is the template for keys with no stored guard yet -
    // confirm against the store implementations.
    fn load_guard<Q>(
        &self,
        key: &Q,
        refer: &Self::Guard,
    ) -> impl Future<Output = Result<Self::Guard, Self::Error>> + Send
    where
        Q: Hash + Eq + Sync,
        Self::Key: Borrow<Q>;

    /// Saves `guard` into the store under `key`.
    fn save_guard(
        &self,
        key: Self::Key,
        guard: Self::Guard,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
129
130/// `RateLimiter` is the main struct to used limit user request.
131pub struct RateLimiter<G, S, I, Q> {
132    guard: G,
133    store: S,
134    issuer: I,
135    quota_getter: Q,
136    add_headers: bool,
137    skipper: Box<dyn Skipper>,
138}
139
140impl<G: RateGuard, S: RateStore, I: RateIssuer, P: QuotaGetter<I::Key>> RateLimiter<G, S, I, P> {
141    /// Create a new `RateLimiter`
142    #[inline]
143    pub fn new(guard: G, store: S, issuer: I, quota_getter: P) -> Self {
144        Self {
145            guard,
146            store,
147            issuer,
148            quota_getter,
149            add_headers: false,
150            skipper: Box::new(none_skipper),
151        }
152    }
153
154    /// Sets skipper and returns new `RateLimiter`.
155    #[inline]
156    pub fn with_skipper(mut self, skipper: impl Skipper) -> Self {
157        self.skipper = Box::new(skipper);
158        self
159    }
160
161    /// Sets `add_headers` and returns new `RateLimiter`.
162    /// If `add_headers` is true, the rate limit headers will be added to the response.
163    #[inline]
164    pub fn add_headers(mut self, add_headers: bool) -> Self {
165        self.add_headers = add_headers;
166        self
167    }
168}
169
170#[async_trait]
171impl<G, S, I, P> Handler for RateLimiter<G, S, I, P>
172where
173    G: RateGuard<Quota = P::Quota>,
174    S: RateStore<Key = I::Key, Guard = G>,
175    P: QuotaGetter<I::Key>,
176    I: RateIssuer,
177{
178    async fn handle(
179        &self,
180        req: &mut Request,
181        depot: &mut Depot,
182        res: &mut Response,
183        ctrl: &mut FlowCtrl,
184    ) {
185        if self.skipper.skipped(req, depot) {
186            return;
187        }
188        let key = match self.issuer.issue(req, depot).await {
189            Some(key) => key,
190            None => {
191                res.render(StatusError::bad_request().brief("Invalid identifier."));
192                ctrl.skip_rest();
193                return;
194            }
195        };
196        let quota = match self.quota_getter.get(&key).await {
197            Ok(quota) => quota,
198            Err(e) => {
199                tracing::error!(error = ?e, "RateLimiter error: {}", e);
200                res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
201                ctrl.skip_rest();
202                return;
203            }
204        };
205        let mut guard = match self.store.load_guard(&key, &self.guard).await {
206            Ok(guard) => guard,
207            Err(e) => {
208                tracing::error!(error = ?e, "RateLimiter error: {}", e);
209                res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
210                ctrl.skip_rest();
211                return;
212            }
213        };
214        let verified = guard.verify(&quota).await;
215
216        if self.add_headers {
217            res.headers_mut().insert(
218                "X-RateLimit-Limit",
219                HeaderValue::from_str(&guard.limit(&quota).await.to_string())
220                    .expect("Invalid header value"),
221            );
222            res.headers_mut().insert(
223                "X-RateLimit-Remaining",
224                HeaderValue::from_str(&(guard.remaining(&quota).await).to_string())
225                    .expect("Invalid header value"),
226            );
227            res.headers_mut().insert(
228                "X-RateLimit-Reset",
229                HeaderValue::from_str(&guard.reset(&quota).await.to_string())
230                    .expect("Invalid header value"),
231            );
232        }
233        if !verified {
234            res.status_code(StatusCode::TOO_MANY_REQUESTS);
235            ctrl.skip_rest();
236        }
237        if let Err(e) = self.store.save_guard(key, guard).await {
238            tracing::error!(error = ?e, "RateLimiter save guard failed");
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use salvo_core::Error;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    const USER1_URL: &str = "http://127.0.0.1:5800/limited?user=user1";
    const USER2_URL: &str = "http://127.0.0.1:5800/limited?user=user2";

    /// Issues the value of the `user` query parameter as the rate-limit key.
    struct UserIssuer;
    impl RateIssuer for UserIssuer {
        type Key = String;
        async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
            req.query::<Self::Key>("user")
        }
    }

    #[handler]
    async fn limited() -> &'static str {
        "Limited page"
    }

    /// Mounts `limiter` in front of the `limited` endpoint and wraps it in a service.
    fn service_with(limiter: impl Handler) -> Service {
        let route = Router::with_path("limited").hoop(limiter).get(limited);
        Service::new(Router::new().push(route))
    }

    /// Requests `url` and expects the page to be served normally.
    async fn assert_allowed(service: &Service, url: &str) {
        let mut response = TestClient::get(url).send(service).await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");
    }

    /// Requests `url` and expects the limiter to reject the request.
    async fn assert_rejected(service: &Service, url: &str) {
        let response = TestClient::get(url).send(service).await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));
    }

    /// Sleeps for `secs` seconds so a quota period can elapse.
    async fn pause(secs: u64) {
        tokio::time::sleep(tokio::time::Duration::from_secs(secs)).await;
    }

    /// Shared scenario: user1 may send one request per second, user2 one per five seconds.
    async fn run_quota_scenario(service: &Service) {
        // user1: the second request inside the same second is rejected.
        assert_allowed(service, USER1_URL).await;
        assert_rejected(service, USER1_URL).await;

        pause(1).await;

        // user1: a new period has started.
        assert_allowed(service, USER1_URL).await;

        // user2: quotas are tracked independently per key.
        assert_allowed(service, USER2_URL).await;
        assert_rejected(service, USER2_URL).await;

        pause(1).await;

        // user2: one second is not enough for the five-second period.
        assert_rejected(service, USER2_URL).await;

        pause(6).await;

        assert_allowed(service, USER2_URL).await;
    }

    #[tokio::test]
    async fn test_fixed_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, BasicQuota>> = LazyLock::new(|| {
            HashMap::from([
                ("user1".into(), BasicQuota::per_second(1)),
                ("user2".into(), BasicQuota::set_seconds(1, 5)),
            ])
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = BasicQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                match USER_QUOTAS.get(key) {
                    Some(quota) => Ok(quota.clone()),
                    None => Err(Error::other("user not found")),
                }
            }
        }

        let service = service_with(RateLimiter::new(
            FixedGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        ));
        run_quota_scenario(&service).await;
    }

    #[tokio::test]
    async fn test_sliding_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, CelledQuota>> = LazyLock::new(|| {
            HashMap::from([
                ("user1".into(), CelledQuota::per_second(1, 1)),
                ("user2".into(), CelledQuota::set_seconds(1, 1, 5)),
            ])
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = CelledQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                match USER_QUOTAS.get(key) {
                    Some(quota) => Ok(quota.clone()),
                    None => Err(Error::other("user not found")),
                }
            }
        }

        let service = service_with(RateLimiter::new(
            SlidingGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        ));
        run_quota_scenario(&service).await;
    }
}