Skip to main content

salvo_extra/
concurrency_limiter.rs

//! Middleware for limiting concurrency.
//!
//! This middleware limits the maximum number of requests being processed concurrently,
//! which helps prevent server overload during traffic spikes. Requests that arrive while
//! the limit is reached are not queued: they are rejected immediately with
//! `429 Too Many Requests`.
//!
//! # Example
//!
//! ```no_run
//! use std::fs::create_dir_all;
//! use std::path::Path;
//!
//! use salvo_core::prelude::*;
//! use salvo_extra::concurrency_limiter::*;
//!
//! #[handler]
//! async fn index(res: &mut Response) {
//!     res.render(Text::Html(INDEX_HTML));
//! }
//! #[handler]
//! async fn upload(req: &mut Request, res: &mut Response) {
//!     let file = req.file("file").await;
//!     // Keep the request in flight for a while so the concurrency limit is observable.
//!     tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
//!     if let Some(file) = file {
//!         let dest = format!("temp/{}", file.name().unwrap_or("file"));
//!         tracing::debug!(dest = %dest, "upload file");
//!         if let Err(e) = std::fs::copy(file.path(), Path::new(&dest)) {
//!             res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
//!             res.render(Text::Plain(format!("failed to save uploaded file: {e}")));
//!         } else {
//!             res.render(Text::Plain(format!("File uploaded to {dest}")));
//!         }
//!     } else {
//!         res.status_code(StatusCode::BAD_REQUEST);
//!         res.render(Text::Plain("file not found in request"));
//!     }
//! }
//!
//! #[tokio::main]
//! async fn main() {
//!     create_dir_all("temp").unwrap();
//!     let router = Router::new()
//!         .get(index)
//!         .push(Router::new().hoop(max_concurrency(1)).path("limited").post(upload))
//!         .push(Router::with_path("unlimit").post(upload));
//!
//!     let acceptor = TcpListener::new("0.0.0.0:8698").bind().await;
//!     Server::new(acceptor).serve(router).await;
//! }
//!
//! static INDEX_HTML: &str = r#"<!DOCTYPE html>
//! <html>
//!     <head>
//!         <title>Upload file</title>
//!     </head>
//!     <body>
//!         <h1>Upload file</h1>
//!         <form action="/unlimit" method="post" enctype="multipart/form-data">
//!             <h3>Unlimit</h3>
//!             <input type="file" name="file" />
//!             <input type="submit" value="upload" />
//!         </form>
//!         <form action="/limited" method="post" enctype="multipart/form-data">
//!             <h3>Limited</h3>
//!             <input type="file" name="file" />
//!             <input type="submit" value="upload" />
//!         </form>
//!     </body>
//! </html>
//! "#;
//! ```
71
72use tokio::sync::{Semaphore, TryAcquireError};
73
74use salvo_core::http::StatusError;
75use salvo_core::http::{Request, Response};
76use salvo_core::{async_trait, Depot, FlowCtrl, Handler};
77
78/// MaxConcurrency
79#[derive(Debug)]
80pub struct MaxConcurrency {
81    semaphore: Semaphore,
82}
83#[async_trait]
84impl Handler for MaxConcurrency {
85    #[inline]
86    async fn handle(
87        &self,
88        req: &mut Request,
89        depot: &mut Depot,
90        res: &mut Response,
91        ctrl: &mut FlowCtrl,
92    ) {
93        match self.semaphore.try_acquire() {
94            Ok(_) => {
95                ctrl.call_next(req, depot, res).await;
96            }
97            Err(e) => match e {
98                TryAcquireError::Closed => {
99                    tracing::error!(
100                        "Max concurrency semaphore is never closed, acquire should never fail: {}",
101                        e
102                    );
103                    res.render(StatusError::payload_too_large().brief("Max concurrency reached."));
104                }
105                TryAcquireError::NoPermits => {
106                    tracing::error!("NoPermits : {}", e);
107                    res.render(StatusError::too_many_requests().brief("Max concurrency reached."));
108                }
109            },
110        }
111    }
112}
113/// Create a new `MaxConcurrency`.
114#[inline]
115#[must_use] pub fn max_concurrency(size: usize) -> MaxConcurrency {
116    MaxConcurrency {
117        semaphore: Semaphore::new(size),
118    }
119}