skp_ratelimit/algorithm/
concurrent.rs

//! Concurrent request limiter.
//!
//! Unlike rate limiters that limit requests over time, this limits
//! the number of simultaneous in-flight requests.
5
6use std::sync::Arc;
7use std::time::Duration;
8
9use dashmap::DashMap;
10use tokio::sync::Semaphore;
11
12/// Concurrent request limiter.
13///
14/// Limits the number of simultaneous in-flight requests per key.
15/// Unlike rate limiting, this tracks active requests that haven't completed yet.
16///
17/// # Example
18///
19/// ```ignore
20/// use oc_ratelimit_advanced::ConcurrentLimiter;
21///
22/// let limiter = ConcurrentLimiter::new(10); // Max 10 concurrent requests
23///
24/// // Acquire a permit
25/// if let Some(permit) = limiter.try_acquire("user:123") {
26///     // Process request...
27///     // Permit is automatically released when dropped
28/// }
29/// ```
30pub struct ConcurrentLimiter {
31    max_concurrent: u32,
32    semaphores: Arc<DashMap<String, Arc<Semaphore>>>,
33    counts: Arc<DashMap<String, u32>>,
34}
35
36impl std::fmt::Debug for ConcurrentLimiter {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("ConcurrentLimiter")
39            .field("max_concurrent", &self.max_concurrent)
40            .field("active_keys", &self.semaphores.len())
41            .finish()
42    }
43}
44
45impl Clone for ConcurrentLimiter {
46    fn clone(&self) -> Self {
47        Self {
48            max_concurrent: self.max_concurrent,
49            semaphores: self.semaphores.clone(),
50            counts: self.counts.clone(),
51        }
52    }
53}
54
55impl ConcurrentLimiter {
56    /// Create a new concurrent limiter.
57    pub fn new(max_concurrent: u32) -> Self {
58        Self {
59            max_concurrent,
60            semaphores: Arc::new(DashMap::new()),
61            counts: Arc::new(DashMap::new()),
62        }
63    }
64
65    /// Try to acquire a permit for the given key.
66    ///
67    /// Returns `Some(ConcurrentPermit)` if successful, `None` if at limit.
68    /// The permit automatically releases when dropped.
69    pub fn try_acquire(&self, key: &str) -> Option<ConcurrentPermit> {
70        let semaphore = self
71            .semaphores
72            .entry(key.to_string())
73            .or_insert_with(|| Arc::new(Semaphore::new(self.max_concurrent as usize)))
74            .clone();
75
76        // Try to acquire without blocking
77        match semaphore.clone().try_acquire_owned() {
78            Ok(permit) => {
79                // Increment count
80                *self.counts.entry(key.to_string()).or_insert(0) += 1;
81
82                Some(ConcurrentPermit {
83                    _permit: permit,
84                    key: key.to_string(),
85                    counts: self.counts.clone(),
86                })
87            }
88            Err(_) => None,
89        }
90    }
91
92    /// Acquire a permit, waiting if necessary.
93    pub async fn acquire(&self, key: &str) -> ConcurrentPermit {
94        let semaphore = self
95            .semaphores
96            .entry(key.to_string())
97            .or_insert_with(|| Arc::new(Semaphore::new(self.max_concurrent as usize)))
98            .clone();
99
100        let permit = semaphore.acquire_owned().await.expect("Semaphore closed");
101
102        *self.counts.entry(key.to_string()).or_insert(0) += 1;
103
104        ConcurrentPermit {
105            _permit: permit,
106            key: key.to_string(),
107            counts: self.counts.clone(),
108        }
109    }
110
111    /// Acquire a permit with a timeout.
112    pub async fn acquire_timeout(
113        &self,
114        key: &str,
115        timeout: Duration,
116    ) -> Option<ConcurrentPermit> {
117        tokio::time::timeout(timeout, self.acquire(key))
118            .await
119            .ok()
120    }
121
122    /// Get the current count of active requests for a key.
123    pub fn current_count(&self, key: &str) -> u32 {
124        self.counts.get(key).map(|c| *c).unwrap_or(0)
125    }
126
127    /// Get the maximum concurrent requests allowed.
128    pub fn max_concurrent(&self) -> u32 {
129        self.max_concurrent
130    }
131
132    /// Get remaining slots for a key.
133    pub fn remaining(&self, key: &str) -> u32 {
134        self.max_concurrent.saturating_sub(self.current_count(key))
135    }
136}
137
138/// A permit for a concurrent request.
139///
140/// While held, this counts against the concurrent limit.
141/// Automatically releases when dropped.
142pub struct ConcurrentPermit {
143    _permit: tokio::sync::OwnedSemaphorePermit,
144    key: String,
145    counts: Arc<DashMap<String, u32>>,
146}
147
148impl Drop for ConcurrentPermit {
149    fn drop(&mut self) {
150        if let Some(mut count) = self.counts.get_mut(&self.key) {
151            *count = count.saturating_sub(1);
152        }
153    }
154}
155
156impl std::fmt::Debug for ConcurrentPermit {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("ConcurrentPermit")
159            .field("key", &self.key)
160            .finish()
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_concurrent_basic() {
        let limiter = ConcurrentLimiter::new(2);
        assert_eq!(limiter.max_concurrent(), 2);

        let permit1 = limiter.try_acquire("user:1");
        assert!(permit1.is_some());
        assert_eq!(limiter.current_count("user:1"), 1);
        assert_eq!(limiter.remaining("user:1"), 1);

        let permit2 = limiter.try_acquire("user:1");
        assert!(permit2.is_some());
        assert_eq!(limiter.current_count("user:1"), 2);
        assert_eq!(limiter.remaining("user:1"), 0);

        // Third should fail, and a rejected attempt must not change the count.
        assert!(limiter.try_acquire("user:1").is_none());
        assert_eq!(limiter.current_count("user:1"), 2);

        // Different key should work
        let permit_other = limiter.try_acquire("user:2");
        assert!(permit_other.is_some());
        assert_eq!(limiter.current_count("user:2"), 1);
    }

    #[tokio::test]
    async fn test_unknown_key_is_empty() {
        let limiter = ConcurrentLimiter::new(3);
        assert_eq!(limiter.current_count("nobody"), 0);
        assert_eq!(limiter.remaining("nobody"), 3);
    }

    #[tokio::test]
    async fn test_concurrent_release() {
        let limiter = ConcurrentLimiter::new(1);

        {
            let _permit = limiter.try_acquire("user:1");
            assert!(limiter.try_acquire("user:1").is_none());
        }

        // After drop, the slot is free again.
        assert_eq!(limiter.current_count("user:1"), 0);
        assert_eq!(limiter.remaining("user:1"), 1);

        let permit = limiter.try_acquire("user:1");
        assert!(permit.is_some());
    }

    #[tokio::test]
    async fn test_concurrent_async_acquire() {
        let limiter = Arc::new(ConcurrentLimiter::new(1));

        let permit = limiter.try_acquire("user:1").unwrap();

        let limiter_clone = limiter.clone();
        let handle = tokio::spawn(async move { limiter_clone.acquire("user:1").await });

        // Short delay then release
        tokio::time::sleep(Duration::from_millis(10)).await;
        drop(permit);

        // Waiting acquire should complete and hold the only slot.
        let _permit2 = handle.await.unwrap();
        assert_eq!(limiter.current_count("user:1"), 1);
    }

    #[tokio::test]
    async fn test_acquire_timeout() {
        let limiter = ConcurrentLimiter::new(1);
        let held = limiter.try_acquire("user:1").unwrap();

        // While the only slot is held, a bounded wait gives up.
        assert!(limiter
            .acquire_timeout("user:1", Duration::from_millis(10))
            .await
            .is_none());
        assert_eq!(limiter.current_count("user:1"), 1);

        drop(held);

        // Once released, it succeeds immediately.
        assert!(limiter
            .acquire_timeout("user:1", Duration::from_millis(10))
            .await
            .is_some());
    }

    #[tokio::test]
    async fn test_clone_shares_limits() {
        let limiter = ConcurrentLimiter::new(1);
        let other = limiter.clone();

        let _held = limiter.try_acquire("user:1").unwrap();
        assert!(other.try_acquire("user:1").is_none());
        assert_eq!(other.current_count("user:1"), 1);
    }

    #[tokio::test]
    async fn test_zero_limit_rejects_everything() {
        let limiter = ConcurrentLimiter::new(0);

        assert!(limiter.try_acquire("user:1").is_none());
        assert_eq!(limiter.current_count("user:1"), 0);
        assert_eq!(limiter.remaining("user:1"), 0);
        assert!(limiter
            .acquire_timeout("user:1", Duration::from_millis(5))
            .await
            .is_none());
    }
}