1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
89#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
90#![cfg_attr(docsrs, feature(doc_cfg))]
91
92use std::borrow::Borrow;
93use std::error::Error as StdError;
94use std::fmt::{self, Debug, Formatter};
95use std::hash::Hash;
96use std::net::IpAddr;
97
98use salvo_core::handler::{Skipper, none_skipper};
99use salvo_core::http::{HeaderValue, Request, Response, StatusCode, StatusError};
100use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
101
102mod quota;
103pub use quota::{BasicQuota, CelledQuota, QuotaGetter};
104#[macro_use]
105mod cfg;
106
107cfg_feature! {
108 #![feature = "moka-store"]
109
110 mod moka_store;
111 pub use moka_store::MokaStore;
112}
113
114cfg_feature! {
115 #![feature = "fixed-guard"]
116
117 mod fixed_guard;
118 pub use fixed_guard::FixedGuard;
119}
120
121cfg_feature! {
122 #![feature = "sliding-guard"]
123
124 mod sliding_guard;
125 pub use sliding_guard::SlidingGuard;
126}
127
128pub trait RateIssuer: Send + Sync + 'static {
130 type Key: Hash + Eq + Send + Sync + 'static;
132 fn issue(
134 &self,
135 req: &mut Request,
136 depot: &Depot,
137 ) -> impl Future<Output = Option<Self::Key>> + Send;
138}
139impl<F, K> RateIssuer for F
140where
141 F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
142 K: Hash + Eq + Send + Sync + 'static,
143{
144 type Key = K;
145 async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
146 (self)(req, depot)
147 }
148}
149
150#[derive(Debug)]
152pub struct RemoteIpIssuer;
153impl RateIssuer for RemoteIpIssuer {
154 type Key = IpAddr;
155 async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
156 req.remote_addr().ip()
157 }
158}
159
/// Per-key state that decides whether a request still fits within its quota.
///
/// [`RateLimiter`] loads one guard per key from the store, calls [`RateGuard::verify`],
/// optionally reports the other values as `X-RateLimit-*` headers, and saves the guard
/// back afterwards.
pub trait RateGuard: Clone + Send + Sync + 'static {
    /// The quota description this guard enforces.
    type Quota: Clone + Send + Sync + 'static;

    /// Checks the current request against `quota`.
    ///
    /// `true` lets the request through; `false` makes [`RateLimiter`] respond with
    /// `429 Too Many Requests`. Takes `&mut self` so the guard can record the request.
    fn verify(&mut self, quota: &Self::Quota) -> impl Future<Output = bool> + Send;

    /// Remaining allowance under `quota`; sent as `X-RateLimit-Remaining`.
    fn remaining(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;

    /// Value sent as `X-RateLimit-Reset` (unit is implementation-defined — presumably a
    /// timestamp; confirm against the concrete guards).
    fn reset(&self, quota: &Self::Quota) -> impl Future<Output = i64> + Send;

    /// Limit implied by `quota`; sent as `X-RateLimit-Limit`.
    fn limit(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;
}
176
/// Backing storage that keeps one guard per rate-limiting key between requests.
pub trait RateStore: Send + Sync + 'static {
    /// Error returned when loading or saving a guard fails.
    type Error: StdError;
    /// Owned key type the guards are stored under.
    type Key: Hash + Eq + Send + Clone + 'static;
    /// Guard type kept in the store.
    type Guard;

    /// Loads the guard for `key`.
    ///
    /// `refer` is the limiter's template guard; presumably implementations fall back to
    /// a copy of it when `key` has no stored guard yet — confirm per implementation.
    fn load_guard<Q>(
        &self,
        key: &Q,
        refer: &Self::Guard,
    ) -> impl Future<Output = Result<Self::Guard, Self::Error>> + Send
    where
        Self::Key: Borrow<Q>,
        Q: Hash + Eq + Sync;

    /// Persists `guard` under `key` after a request has been checked.
    fn save_guard(
        &self,
        key: Self::Key,
        guard: Self::Guard,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
201
202pub struct RateLimiter<G, S, I, Q> {
204 guard: G,
205 store: S,
206 issuer: I,
207 quota_getter: Q,
208 add_headers: bool,
209 skipper: Box<dyn Skipper>,
210}
211impl<G, S, I, Q> Debug for RateLimiter<G, S, I, Q>
212where
213 G: Debug,
214 S: Debug,
215 I: Debug,
216 Q: Debug,
217{
218 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
219 f.debug_struct("RateLimiter")
220 .field("guard", &self.guard)
221 .field("store", &self.store)
222 .field("issuer", &self.issuer)
223 .field("quota_getter", &self.quota_getter)
224 .field("add_headers", &self.add_headers)
225 .finish()
226 }
227}
228
229impl<G: RateGuard, S: RateStore, I: RateIssuer, P: QuotaGetter<I::Key>> RateLimiter<G, S, I, P> {
230 #[inline]
232 #[must_use]
233 pub fn new(guard: G, store: S, issuer: I, quota_getter: P) -> Self {
234 Self {
235 guard,
236 store,
237 issuer,
238 quota_getter,
239 add_headers: false,
240 skipper: Box::new(none_skipper),
241 }
242 }
243
244 #[inline]
246 #[must_use]
247 pub fn with_skipper(mut self, skipper: impl Skipper) -> Self {
248 self.skipper = Box::new(skipper);
249 self
250 }
251
252 #[inline]
255 #[must_use]
256 pub fn add_headers(mut self, add_headers: bool) -> Self {
257 self.add_headers = add_headers;
258 self
259 }
260}
261
262#[async_trait]
263impl<G, S, I, P> Handler for RateLimiter<G, S, I, P>
264where
265 G: RateGuard<Quota = P::Quota>,
266 S: RateStore<Key = I::Key, Guard = G>,
267 P: QuotaGetter<I::Key>,
268 I: RateIssuer,
269{
270 async fn handle(
271 &self,
272 req: &mut Request,
273 depot: &mut Depot,
274 res: &mut Response,
275 ctrl: &mut FlowCtrl,
276 ) {
277 if self.skipper.skipped(req, depot) {
278 return;
279 }
280 let Some(key) = self.issuer.issue(req, depot).await else {
281 res.render(StatusError::bad_request().brief("Invalid identifier."));
282 ctrl.skip_rest();
283 return;
284 };
285 let quota = match self.quota_getter.get(&key).await {
286 Ok(quota) => quota,
287 Err(e) => {
288 tracing::error!(error = ?e, "RateLimiter error: {}", e);
289 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
290 ctrl.skip_rest();
291 return;
292 }
293 };
294 let mut guard = match self.store.load_guard(&key, &self.guard).await {
295 Ok(guard) => guard,
296 Err(e) => {
297 tracing::error!(error = ?e, "RateLimiter error: {}", e);
298 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
299 ctrl.skip_rest();
300 return;
301 }
302 };
303 let verified = guard.verify("a).await;
304
305 if self.add_headers {
306 res.headers_mut().insert(
307 "X-RateLimit-Limit",
308 HeaderValue::from_str(&guard.limit("a).await.to_string())
309 .expect("Invalid header value"),
310 );
311 res.headers_mut().insert(
312 "X-RateLimit-Remaining",
313 HeaderValue::from_str(&(guard.remaining("a).await).to_string())
314 .expect("Invalid header value"),
315 );
316 res.headers_mut().insert(
317 "X-RateLimit-Reset",
318 HeaderValue::from_str(&guard.reset("a).await.to_string())
319 .expect("Invalid header value"),
320 );
321 }
322 if !verified {
323 res.status_code(StatusCode::TOO_MANY_REQUESTS);
324 ctrl.skip_rest();
325 }
326 if let Err(e) = self.store.save_guard(key, guard).await {
327 tracing::error!(error = ?e, "RateLimiter save guard failed");
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use salvo_core::Error;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// Keys requests by their `user` query parameter.
    struct UserIssuer;
    impl RateIssuer for UserIssuer {
        type Key = String;
        async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
            req.query::<Self::Key>("user")
        }
    }

    #[handler]
    async fn limited() -> &'static str {
        "Limited page"
    }

    /// Sends `GET /limited?user=<user>` through `service`.
    async fn request_as(service: &Service, user: &str) -> Response {
        TestClient::get(format!("http://127.0.0.1:8698/limited?user={user}"))
            .send(service)
            .await
    }

    /// Asserts that `user` passes the limiter and receives the handler's body.
    async fn expect_allowed(service: &Service, user: &str) {
        let mut response = request_as(service, user).await;
        assert_eq!(response.status_code, Some(StatusCode::OK));
        assert_eq!(response.take_string().await.unwrap(), "Limited page");
    }

    /// Asserts that `user` is rejected with `429 Too Many Requests`.
    async fn expect_throttled(service: &Service, user: &str) {
        let response = request_as(service, user).await;
        assert_eq!(response.status_code, Some(StatusCode::TOO_MANY_REQUESTS));
    }

    /// Sleeps for `secs` whole seconds.
    async fn pause(secs: u64) {
        tokio::time::sleep(tokio::time::Duration::from_secs(secs)).await;
    }

    /// Scenario shared by both guard kinds: user1 may retry after 1s, while user2 is
    /// still throttled after 1s and admitted again only after a further 6s.
    async fn run_dynamic_quota_scenario(service: &Service) {
        expect_allowed(service, "user1").await;
        expect_throttled(service, "user1").await;
        pause(1).await;
        expect_allowed(service, "user1").await;

        expect_allowed(service, "user2").await;
        expect_throttled(service, "user2").await;
        pause(1).await;
        expect_throttled(service, "user2").await;
        pause(6).await;
        expect_allowed(service, "user2").await;
    }

    #[tokio::test]
    async fn test_fixed_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, BasicQuota>> = LazyLock::new(|| {
            HashMap::from([
                (String::from("user1"), BasicQuota::per_second(1)),
                (String::from("user2"), BasicQuota::set_seconds(1, 5)),
            ])
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = BasicQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                USER_QUOTAS
                    .get(key)
                    .cloned()
                    .ok_or_else(|| Error::other("user not found"))
            }
        }

        let limiter = RateLimiter::new(
            FixedGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        );
        let router = Router::new().push(Router::with_path("limited").hoop(limiter).get(limited));
        let service = Service::new(router);

        run_dynamic_quota_scenario(&service).await;
    }

    #[tokio::test]
    async fn test_sliding_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, CelledQuota>> = LazyLock::new(|| {
            HashMap::from([
                (String::from("user1"), CelledQuota::per_second(1, 1)),
                (String::from("user2"), CelledQuota::set_seconds(1, 1, 5)),
            ])
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = CelledQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                USER_QUOTAS
                    .get(key)
                    .cloned()
                    .ok_or_else(|| Error::other("user not found"))
            }
        }

        let limiter = RateLimiter::new(
            SlidingGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        );
        let router = Router::new().push(Router::with_path("limited").hoop(limiter).get(limited));
        let service = Service::new(router);

        run_dynamic_quota_scenario(&service).await;
    }
}