skp_ratelimit/algorithm/
concurrent.rs1use std::sync::Arc;
7use std::time::Duration;
8
9use dashmap::DashMap;
10use tokio::sync::Semaphore;
11
12pub struct ConcurrentLimiter {
31 max_concurrent: u32,
32 semaphores: Arc<DashMap<String, Arc<Semaphore>>>,
33 counts: Arc<DashMap<String, u32>>,
34}
35
36impl std::fmt::Debug for ConcurrentLimiter {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("ConcurrentLimiter")
39 .field("max_concurrent", &self.max_concurrent)
40 .field("active_keys", &self.semaphores.len())
41 .finish()
42 }
43}
44
45impl Clone for ConcurrentLimiter {
46 fn clone(&self) -> Self {
47 Self {
48 max_concurrent: self.max_concurrent,
49 semaphores: self.semaphores.clone(),
50 counts: self.counts.clone(),
51 }
52 }
53}
54
55impl ConcurrentLimiter {
56 pub fn new(max_concurrent: u32) -> Self {
58 Self {
59 max_concurrent,
60 semaphores: Arc::new(DashMap::new()),
61 counts: Arc::new(DashMap::new()),
62 }
63 }
64
65 pub fn try_acquire(&self, key: &str) -> Option<ConcurrentPermit> {
70 let semaphore = self
71 .semaphores
72 .entry(key.to_string())
73 .or_insert_with(|| Arc::new(Semaphore::new(self.max_concurrent as usize)))
74 .clone();
75
76 match semaphore.clone().try_acquire_owned() {
78 Ok(permit) => {
79 *self.counts.entry(key.to_string()).or_insert(0) += 1;
81
82 Some(ConcurrentPermit {
83 _permit: permit,
84 key: key.to_string(),
85 counts: self.counts.clone(),
86 })
87 }
88 Err(_) => None,
89 }
90 }
91
92 pub async fn acquire(&self, key: &str) -> ConcurrentPermit {
94 let semaphore = self
95 .semaphores
96 .entry(key.to_string())
97 .or_insert_with(|| Arc::new(Semaphore::new(self.max_concurrent as usize)))
98 .clone();
99
100 let permit = semaphore.acquire_owned().await.expect("Semaphore closed");
101
102 *self.counts.entry(key.to_string()).or_insert(0) += 1;
103
104 ConcurrentPermit {
105 _permit: permit,
106 key: key.to_string(),
107 counts: self.counts.clone(),
108 }
109 }
110
111 pub async fn acquire_timeout(
113 &self,
114 key: &str,
115 timeout: Duration,
116 ) -> Option<ConcurrentPermit> {
117 tokio::time::timeout(timeout, self.acquire(key))
118 .await
119 .ok()
120 }
121
122 pub fn current_count(&self, key: &str) -> u32 {
124 self.counts.get(key).map(|c| *c).unwrap_or(0)
125 }
126
127 pub fn max_concurrent(&self) -> u32 {
129 self.max_concurrent
130 }
131
132 pub fn remaining(&self, key: &str) -> u32 {
134 self.max_concurrent.saturating_sub(self.current_count(key))
135 }
136}
137
138pub struct ConcurrentPermit {
143 _permit: tokio::sync::OwnedSemaphorePermit,
144 key: String,
145 counts: Arc<DashMap<String, u32>>,
146}
147
148impl Drop for ConcurrentPermit {
149 fn drop(&mut self) {
150 if let Some(mut count) = self.counts.get_mut(&self.key) {
151 *count = count.saturating_sub(1);
152 }
153 }
154}
155
156impl std::fmt::Debug for ConcurrentPermit {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 f.debug_struct("ConcurrentPermit")
159 .field("key", &self.key)
160 .finish()
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_concurrent_basic() {
        let limiter = ConcurrentLimiter::new(2);

        let permit1 = limiter.try_acquire("user:1");
        assert!(permit1.is_some());
        assert_eq!(limiter.current_count("user:1"), 1);

        let permit2 = limiter.try_acquire("user:1");
        assert!(permit2.is_some());
        assert_eq!(limiter.current_count("user:1"), 2);

        // The limit for this key is reached...
        let permit3 = limiter.try_acquire("user:1");
        assert!(permit3.is_none());

        // ...but other keys are limited independently.
        let permit_other = limiter.try_acquire("user:2");
        assert!(permit_other.is_some());
    }

    #[tokio::test]
    async fn test_concurrent_release() {
        let limiter = ConcurrentLimiter::new(1);

        {
            let permit = limiter.try_acquire("user:1");
            assert!(permit.is_some());
            assert!(limiter.try_acquire("user:1").is_none());
        }

        // Dropping the permit frees the slot and resets the count.
        assert_eq!(limiter.current_count("user:1"), 0);
        let permit = limiter.try_acquire("user:1");
        assert!(permit.is_some());
    }

    #[tokio::test]
    async fn test_concurrent_async_acquire() {
        let limiter = Arc::new(ConcurrentLimiter::new(1));

        let permit = limiter.try_acquire("user:1").unwrap();

        let limiter_clone = limiter.clone();
        let handle = tokio::spawn(async move { limiter_clone.acquire("user:1").await });

        // Give the spawned task a chance to start waiting before the slot is released.
        tokio::time::sleep(Duration::from_millis(10)).await;
        drop(permit);

        // Bounded so a regression fails the test instead of hanging the suite.
        let _permit2 = tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("acquire should complete once the slot is released")
            .unwrap();
        assert_eq!(limiter.current_count("user:1"), 1);
    }

    #[tokio::test]
    async fn test_acquire_timeout_expires_when_saturated() {
        let limiter = ConcurrentLimiter::new(1);
        let _held = limiter.try_acquire("user:1").unwrap();

        let waited = limiter
            .acquire_timeout("user:1", Duration::from_millis(5))
            .await;
        assert!(waited.is_none());
        // The timed-out waiter must not disturb the existing holder's slot.
        assert_eq!(limiter.current_count("user:1"), 1);
    }

    #[tokio::test]
    async fn test_remaining_tracks_permits() {
        let limiter = ConcurrentLimiter::new(3);
        assert_eq!(limiter.max_concurrent(), 3);
        assert_eq!(limiter.remaining("user:1"), 3);

        let first = limiter.try_acquire("user:1").unwrap();
        let _second = limiter.try_acquire("user:1").unwrap();
        assert_eq!(limiter.remaining("user:1"), 1);

        drop(first);
        assert_eq!(limiter.current_count("user:1"), 1);
        assert_eq!(limiter.remaining("user:1"), 2);
    }

    #[tokio::test]
    async fn test_clones_share_state() {
        let limiter = ConcurrentLimiter::new(1);
        let clone = limiter.clone();

        let _held = limiter.try_acquire("user:1").unwrap();
        assert!(clone.try_acquire("user:1").is_none());
        assert_eq!(clone.current_count("user:1"), 1);
    }
}