vectorless/concurrency/
controller.rs1use std::sync::Arc;
7use tokio::sync::{Semaphore, SemaphorePermit};
8use tracing::{debug, trace};
9
10use super::config::ConcurrencyConfig;
11use super::rate_limiter::RateLimiter;
12
13#[derive(Clone)]
37pub struct ConcurrencyController {
38 semaphore: Arc<Semaphore>,
40 rate_limiter: Option<Arc<RateLimiter>>,
42 config: ConcurrencyConfig,
44}
45
46impl ConcurrencyController {
47 pub fn new(config: ConcurrencyConfig) -> Self {
49 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_requests));
50 let rate_limiter = if config.enabled {
51 Some(Arc::new(RateLimiter::new(config.requests_per_minute)))
52 } else {
53 None
54 };
55
56 Self {
57 semaphore,
58 rate_limiter,
59 config,
60 }
61 }
62
63 pub fn with_defaults() -> Self {
65 Self::new(ConcurrencyConfig::default())
66 }
67
68 pub fn high_throughput() -> Self {
70 Self::new(ConcurrencyConfig::high_throughput())
71 }
72
73 pub fn conservative() -> Self {
75 Self::new(ConcurrencyConfig::conservative())
76 }
77
78 pub fn unlimited() -> Self {
80 Self::new(ConcurrencyConfig::unlimited())
81 }
82
83 pub async fn acquire(&self) -> Option<SemaphorePermit<'_>> {
91 if let Some(ref limiter) = self.rate_limiter {
93 trace!("Waiting for rate limiter");
94 limiter.acquire().await;
95 debug!("Rate limiter: token acquired");
96 }
97
98 if self.config.semaphore_enabled {
100 trace!("Waiting for semaphore permit");
101 let permit = self.semaphore.acquire().await.unwrap();
102 debug!("Semaphore: permit acquired (available: {})", self.semaphore.available_permits());
103 Some(permit)
104 } else {
105 None
106 }
107 }
108
109 pub fn try_acquire(&self) -> Option<SemaphorePermit<'_>> {
113 if let Some(ref limiter) = self.rate_limiter {
115 if !limiter.try_acquire() {
116 return None;
117 }
118 }
119
120 if self.config.semaphore_enabled {
122 self.semaphore.try_acquire().ok()
123 } else {
124 None
125 }
126 }
127
128 pub fn available_permits(&self) -> usize {
130 self.semaphore.available_permits()
131 }
132
133 pub fn config(&self) -> &ConcurrencyConfig {
135 &self.config
136 }
137
138 pub fn rate_limiter(&self) -> Option<&RateLimiter> {
140 self.rate_limiter.as_deref()
141 }
142}
143
144impl std::fmt::Debug for ConcurrencyController {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 f.debug_struct("ConcurrencyController")
147 .field("max_concurrent_requests", &self.config.max_concurrent_requests)
148 .field("requests_per_minute", &self.config.requests_per_minute)
149 .field("rate_limiting_enabled", &self.config.enabled)
150 .field("semaphore_enabled", &self.config.semaphore_enabled)
151 .field("available_permits", &self.semaphore.available_permits())
152 .finish()
153 }
154}
155
156impl Default for ConcurrencyController {
157 fn default() -> Self {
158 Self::with_defaults()
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with rate limiting off, so only the semaphore is exercised.
    fn semaphore_only(max_concurrent_requests: usize) -> ConcurrencyConfig {
        ConcurrencyConfig {
            max_concurrent_requests,
            requests_per_minute: 100,
            enabled: false,
            semaphore_enabled: true,
        }
    }

    #[tokio::test]
    async fn test_controller_acquire() {
        let controller = ConcurrencyController::new(semaphore_only(2));

        let permit1 = controller.acquire().await;
        assert!(permit1.is_some());
        assert_eq!(controller.available_permits(), 1);

        let permit2 = controller.acquire().await;
        assert!(permit2.is_some());
        assert_eq!(controller.available_permits(), 0);

        // Dropping a permit hands it back to the semaphore.
        drop(permit1);
        assert_eq!(controller.available_permits(), 1);
    }

    #[test]
    fn test_try_acquire_fails_when_exhausted() {
        let controller = ConcurrencyController::new(semaphore_only(1));

        let permit = controller.try_acquire();
        assert!(permit.is_some());
        assert_eq!(controller.available_permits(), 0);
        assert!(controller.try_acquire().is_none());

        drop(permit);
        assert_eq!(controller.available_permits(), 1);
        assert!(controller.try_acquire().is_some());
    }

    #[tokio::test]
    async fn test_semaphore_disabled_hands_out_no_permits() {
        let controller = ConcurrencyController::new(ConcurrencyConfig {
            max_concurrent_requests: 1,
            requests_per_minute: 100,
            enabled: false,
            semaphore_enabled: false,
        });

        assert!(controller.acquire().await.is_none());
        assert!(controller.try_acquire().is_none());
        // Nothing was taken from the semaphore.
        assert_eq!(controller.available_permits(), 1);
    }

    #[test]
    fn test_controller_creation() {
        let controller = ConcurrencyController::with_defaults();
        assert!(controller.available_permits() > 0);
    }
}