// warpdrive_proxy/middleware/concurrency.rs
use async_trait::async_trait;
13use pingora::prelude::*;
14use std::sync::Arc;
15use tokio::sync::Semaphore;
16use tracing::{debug, warn};
17
18use super::{Middleware, MiddlewareContext};
19
20pub struct ConcurrencyMiddleware {
25 semaphore: Arc<Semaphore>,
27 enabled: bool,
29 max_concurrent: usize,
31}
32
33impl ConcurrencyMiddleware {
34 pub fn new(enabled: bool, max_concurrent: usize) -> Self {
41 let semaphore = if max_concurrent > 0 {
42 Arc::new(Semaphore::new(max_concurrent))
43 } else {
44 Arc::new(Semaphore::new(1_000_000))
47 };
48
49 debug!(
50 "Concurrency limiter initialized: enabled={}, max={}",
51 enabled, max_concurrent
52 );
53
54 Self {
55 semaphore,
56 enabled,
57 max_concurrent,
58 }
59 }
60}
61
62#[async_trait]
63impl Middleware for ConcurrencyMiddleware {
64 async fn request_filter(
66 &self,
67 session: &mut Session,
68 ctx: &mut MiddlewareContext,
69 ) -> Result<()> {
70 if !self.enabled || self.max_concurrent == 0 {
71 return Ok(());
72 }
73
74 if ctx.trusted_source {
76 debug!("Skipping concurrency limit for trusted source");
77 return Ok(());
78 }
79
80 match self.semaphore.clone().try_acquire_owned() {
82 Ok(permit) => {
83 debug!("Concurrency permit acquired");
84
85 ctx.concurrency_permit = Some(permit);
88 Ok(())
89 }
90 Err(_) => {
91 warn!(
92 "Concurrency limit reached (max: {}), rejecting request",
93 self.max_concurrent
94 );
95
96 session.respond_error(503).await?;
98
99 Err(Error::explain(
100 ErrorType::HTTPStatus(503),
101 format!("Concurrency limit reached: {}", self.max_concurrent),
102 ))
103 }
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction stores the configuration and sizes the semaphore to the limit.
    #[test]
    fn test_concurrency_middleware_creation() {
        let mw = ConcurrencyMiddleware::new(true, 100);
        assert!(mw.enabled);
        assert_eq!(mw.max_concurrent, 100);
        assert_eq!(mw.semaphore.available_permits(), 100);
    }

    /// A disabled limiter keeps `enabled == false`.
    #[test]
    fn test_concurrency_middleware_disabled() {
        let mw = ConcurrencyMiddleware::new(false, 100);
        assert!(!mw.enabled);
    }

    /// `max_concurrent == 0` is accepted and stored as-is ("unlimited").
    #[test]
    fn test_concurrency_middleware_unlimited() {
        let mw = ConcurrencyMiddleware::new(true, 0);
        assert!(mw.enabled);
        assert_eq!(mw.max_concurrent, 0);
    }

    /// Only `max_concurrent` permits can be held at the same time.
    // `try_acquire_owned` is synchronous, so no async runtime is needed here.
    #[test]
    fn test_concurrency_permit_acquire() {
        let mw = ConcurrencyMiddleware::new(true, 2);

        // Both bindings must stay alive so the third attempt sees an exhausted pool.
        let permit1 = mw.semaphore.clone().try_acquire_owned();
        assert!(permit1.is_ok());

        let permit2 = mw.semaphore.clone().try_acquire_owned();
        assert!(permit2.is_ok());

        let permit3 = mw.semaphore.clone().try_acquire_owned();
        assert!(permit3.is_err());
    }

    /// Dropping a held permit frees its slot for the next request.
    #[test]
    fn test_concurrency_permit_released_on_drop() {
        let mw = ConcurrencyMiddleware::new(true, 1);

        let held = mw
            .semaphore
            .clone()
            .try_acquire_owned()
            .expect("a fresh limiter of size 1 must have one free permit");
        assert!(mw.semaphore.clone().try_acquire_owned().is_err());

        drop(held);
        assert!(mw.semaphore.clone().try_acquire_owned().is_ok());
    }
}