salvo_extra/concurrency_limiter.rs
1//! Middleware for limiting concurrency.
2//!
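//! This middleware limits the maximum number of requests being processed concurrently,
//! which helps prevent server overload during traffic spikes. Requests that arrive while
//! the limit is reached are rejected immediately with `429 Too Many Requests` instead of
//! being queued.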
5//!
6//! # Example
7//!
8//! ```no_run
9//! use std::fs::create_dir_all;
10//! use std::path::Path;
11//!
12//! use salvo_core::prelude::*;
13//! use salvo_extra::concurrency_limiter::*;
14//!
15//! #[handler]
16//! async fn index(res: &mut Response) {
17//! res.render(Text::Html(INDEX_HTML));
18//! }
19//! #[handler]
20//! async fn upload(req: &mut Request, res: &mut Response) {
21//! let file = req.file("file").await;
22//! tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
23//! if let Some(file) = file {
24//! let dest = format!("temp/{}", file.name().unwrap_or("file"));
25//! tracing::debug!(dest = %dest, "upload file");
26//! if let Err(e) = std::fs::copy(file.path(), Path::new(&dest)) {
27//! res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
//! res.render(Text::Plain(format!("failed to copy uploaded file: {e}")));
29//! } else {
30//! res.render(Text::Plain(format!("File uploaded to {dest}")));
31//! }
32//! } else {
33//! res.status_code(StatusCode::BAD_REQUEST);
34//! res.render(Text::Plain("file not found in request"));
35//! }
36//! }
37//!
38//! #[tokio::main]
39//! async fn main() {
40//! create_dir_all("temp").unwrap();
41//! let router = Router::new()
42//! .get(index)
43//! .push(Router::new().hoop(max_concurrency(1)).path("limited").post(upload))
44//! .push(Router::with_path("unlimit").post(upload));
45//!
46//! let acceptor = TcpListener::new("0.0.0.0:8698").bind().await;
47//! Server::new(acceptor).serve(router).await;
48//! }
49//!
50//! static INDEX_HTML: &str = r#"<!DOCTYPE html>
51//! <html>
52//! <head>
53//! <title>Upload file</title>
54//! </head>
55//! <body>
56//! <h1>Upload file</h1>
57//! <form action="/unlimit" method="post" enctype="multipart/form-data">
58//! <h3>Unlimit</h3>
59//! <input type="file" name="file" />
60//! <input type="submit" value="upload" />
61//! </form>
62//! <form action="/limited" method="post" enctype="multipart/form-data">
63//! <h3>Limited</h3>
64//! <input type="file" name="file" />
65//! <input type="submit" value="upload" />
66//! </form>
67//! </body>
68//! </html>
69//! "#;
70//! ```
71
72use tokio::sync::{Semaphore, TryAcquireError};
73
74use salvo_core::http::StatusError;
75use salvo_core::http::{Request, Response};
76use salvo_core::{async_trait, Depot, FlowCtrl, Handler};
77
78/// MaxConcurrency
79#[derive(Debug)]
80pub struct MaxConcurrency {
81 semaphore: Semaphore,
82}
83#[async_trait]
84impl Handler for MaxConcurrency {
85 #[inline]
86 async fn handle(
87 &self,
88 req: &mut Request,
89 depot: &mut Depot,
90 res: &mut Response,
91 ctrl: &mut FlowCtrl,
92 ) {
93 match self.semaphore.try_acquire() {
94 Ok(_) => {
95 ctrl.call_next(req, depot, res).await;
96 }
97 Err(e) => match e {
98 TryAcquireError::Closed => {
99 tracing::error!(
100 "Max concurrency semaphore is never closed, acquire should never fail: {}",
101 e
102 );
103 res.render(StatusError::payload_too_large().brief("Max concurrency reached."));
104 }
105 TryAcquireError::NoPermits => {
106 tracing::error!("NoPermits : {}", e);
107 res.render(StatusError::too_many_requests().brief("Max concurrency reached."));
108 }
109 },
110 }
111 }
112}
113/// Create a new `MaxConcurrency`.
114#[inline]
115#[must_use] pub fn max_concurrency(size: usize) -> MaxConcurrency {
116 MaxConcurrency {
117 semaphore: Semaphore::new(size),
118 }
119}