rust_expect/util/
backpressure.rs

//! Backpressure and flow control utilities.
//!
//! This module provides utilities for managing data flow between
//! producers and consumers with backpressure.
//!
//! # Overview
//!
//! The [`Backpressure`] controller provides two complementary mechanisms:
//!
//! 1. **Buffer size tracking**: Track how much data is buffered and block when full.
//!    Use [`try_acquire`](Backpressure::try_acquire), [`acquire`](Backpressure::acquire),
//!    and [`release`](Backpressure::release) for this.
//!
//! 2. **Concurrent operation limiting**: Limit the number of concurrent operations
//!    using semaphore-based permits. Use [`try_acquire_permit`](Backpressure::try_acquire_permit)
//!    and [`acquire_permit`](Backpressure::acquire_permit) for this.
//!
//! To limit how *often* something happens rather than how much is in flight,
//! the module also provides [`RateLimiter`] (a fixed number of operations per
//! time window) and [`TokenBucket`] (a steady refill rate with bounded bursts).
17
18use std::sync::Arc;
19use std::sync::atomic::{AtomicUsize, Ordering};
20
21use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
22
/// A backpressure controller for limiting data flow and concurrent operations.
///
/// This controller provides two mechanisms:
///
/// - **Buffer size tracking**: For tracking how much data is in-flight or buffered.
/// - **Permit-based concurrency**: For limiting the number of concurrent operations.
///
/// Both mechanisms use the same `max_size` limit but operate independently:
/// reserving buffer space does not consume permits, and holding permits does
/// not count against the buffer size. This allows flexible backpressure
/// strategies.
#[derive(Debug)]
pub struct Backpressure {
    /// Upper bound shared by the buffer size and the number of permits.
    max_size: usize,
    /// Units of buffer space currently reserved (size-based tracking only).
    current: AtomicUsize,
    /// Signalled by `release` so that callers waiting in `acquire` re-check
    /// for free space.
    space_available: Notify,
    /// Semaphore for permit-based concurrency limiting, created with
    /// `max_size` permits.
    /// Wrapped in `Arc` so callers can receive owned permits
    /// (`OwnedSemaphorePermit`) that are not tied to a borrow of `self`.
    semaphore: Arc<Semaphore>,
}
44
45impl Backpressure {
46    /// Create a new backpressure controller.
47    ///
48    /// The `max_size` parameter controls both:
49    /// - The maximum buffer size for size-based tracking
50    /// - The number of available permits for concurrency limiting
51    #[must_use]
52    pub fn new(max_size: usize) -> Self {
53        Self {
54            max_size,
55            current: AtomicUsize::new(0),
56            space_available: Notify::new(),
57            semaphore: Arc::new(Semaphore::new(max_size)),
58        }
59    }
60
61    // ========================================================================
62    // Buffer Size Tracking Methods
63    // ========================================================================
64
65    /// Try to acquire space for the given amount.
66    ///
67    /// Returns true if space was acquired, false if the buffer is full.
68    ///
69    /// This is part of the buffer size tracking mechanism. Use [`release`](Self::release)
70    /// to return the space when done.
71    pub fn try_acquire(&self, amount: usize) -> bool {
72        let current = self.current.load(Ordering::Acquire);
73        if current + amount <= self.max_size {
74            self.current.fetch_add(amount, Ordering::Release);
75            true
76        } else {
77            false
78        }
79    }
80
81    /// Acquire space for the given amount, waiting if necessary.
82    ///
83    /// This is part of the buffer size tracking mechanism. Use [`release`](Self::release)
84    /// to return the space when done.
85    pub async fn acquire(&self, amount: usize) {
86        loop {
87            if self.try_acquire(amount) {
88                return;
89            }
90            self.space_available.notified().await;
91        }
92    }
93
94    /// Release the given amount of space.
95    ///
96    /// This is part of the buffer size tracking mechanism. Call this after
97    /// data has been processed/consumed to free space for new data.
98    pub fn release(&self, amount: usize) {
99        self.current.fetch_sub(amount, Ordering::Release);
100        self.space_available.notify_one();
101    }
102
103    /// Get the current buffer usage.
104    #[must_use]
105    pub fn current_size(&self) -> usize {
106        self.current.load(Ordering::Acquire)
107    }
108
109    /// Get the maximum buffer size.
110    #[must_use]
111    pub const fn max_size(&self) -> usize {
112        self.max_size
113    }
114
115    /// Check if the buffer is full.
116    #[must_use]
117    pub fn is_full(&self) -> bool {
118        self.current_size() >= self.max_size
119    }
120
121    /// Get the available space.
122    #[must_use]
123    pub fn available(&self) -> usize {
124        self.max_size.saturating_sub(self.current_size())
125    }
126
127    // ========================================================================
128    // Permit-Based Concurrency Limiting Methods
129    // ========================================================================
130
131    /// Try to acquire a permit for a concurrent operation.
132    ///
133    /// Returns `Ok(permit)` if a permit was acquired, or an error if no permits
134    /// are available. The permit is automatically released when dropped.
135    ///
136    /// This is part of the permit-based concurrency mechanism. Use this when
137    /// you want to limit the number of concurrent operations rather than
138    /// tracking buffer sizes.
139    ///
140    /// # Example
141    ///
142    /// ```
143    /// use rust_expect::util::backpressure::Backpressure;
144    ///
145    /// let bp = Backpressure::new(2); // Allow 2 concurrent operations
146    ///
147    /// let permit1 = bp.try_acquire_permit().unwrap();
148    /// let permit2 = bp.try_acquire_permit().unwrap();
149    ///
150    /// // Third attempt fails - at capacity
151    /// assert!(bp.try_acquire_permit().is_err());
152    ///
153    /// // Dropping a permit frees it
154    /// drop(permit1);
155    /// let permit3 = bp.try_acquire_permit().unwrap();
156    /// ```
157    pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TryAcquireError> {
158        self.semaphore.clone().try_acquire_owned()
159    }
160
161    /// Acquire a permit for a concurrent operation, waiting if necessary.
162    ///
163    /// Returns a permit that is automatically released when dropped.
164    ///
165    /// This is part of the permit-based concurrency mechanism. Use this when
166    /// you want to limit the number of concurrent operations.
167    ///
168    /// # Example
169    ///
170    /// ```no_run
171    /// use rust_expect::util::backpressure::Backpressure;
172    ///
173    /// # async fn example() {
174    /// let bp = Backpressure::new(10);
175    ///
176    /// // Acquire a permit - will wait if none available
177    /// let permit = bp.acquire_permit().await;
178    ///
179    /// // Do work while holding the permit
180    /// // ...
181    ///
182    /// // Permit is released when dropped
183    /// drop(permit);
184    /// # }
185    /// ```
186    pub async fn acquire_permit(&self) -> OwnedSemaphorePermit {
187        self.semaphore
188            .clone()
189            .acquire_owned()
190            .await
191            .expect("semaphore should not be closed")
192    }
193
194    /// Get the number of available permits.
195    #[must_use]
196    pub fn available_permits(&self) -> usize {
197        self.semaphore.available_permits()
198    }
199}
200
201impl Default for Backpressure {
202    fn default() -> Self {
203        Self::new(64 * 1024) // 64KB default
204    }
205}
206
207/// A rate limiter for controlling operation frequency.
208#[derive(Debug)]
209pub struct RateLimiter {
210    /// Maximum operations per interval.
211    max_ops: usize,
212    /// Interval duration in milliseconds.
213    interval_ms: u64,
214    /// Current operation count.
215    current: AtomicUsize,
216    /// Last reset time.
217    last_reset: std::sync::Mutex<std::time::Instant>,
218}
219
220impl RateLimiter {
221    /// Create a new rate limiter.
222    #[must_use]
223    pub fn new(max_ops: usize, interval: std::time::Duration) -> Self {
224        Self {
225            max_ops,
226            interval_ms: interval.as_millis() as u64,
227            current: AtomicUsize::new(0),
228            last_reset: std::sync::Mutex::new(std::time::Instant::now()),
229        }
230    }
231
232    /// Try to perform an operation.
233    ///
234    /// Returns true if the operation is allowed, false if rate limited.
235    pub fn try_acquire(&self) -> bool {
236        self.maybe_reset();
237
238        let current = self.current.fetch_add(1, Ordering::AcqRel);
239        if current < self.max_ops {
240            true
241        } else {
242            self.current.fetch_sub(1, Ordering::Release);
243            false
244        }
245    }
246
247    /// Perform an operation, waiting if necessary.
248    pub async fn acquire(&self) {
249        while !self.try_acquire() {
250            let sleep_time = self.time_until_reset();
251            tokio::time::sleep(sleep_time).await;
252        }
253    }
254
255    /// Reset the counter if the interval has elapsed.
256    fn maybe_reset(&self) {
257        let mut last_reset = self
258            .last_reset
259            .lock()
260            .unwrap_or_else(std::sync::PoisonError::into_inner);
261        let elapsed = last_reset.elapsed();
262
263        if elapsed.as_millis() as u64 >= self.interval_ms {
264            self.current.store(0, Ordering::Release);
265            *last_reset = std::time::Instant::now();
266        }
267    }
268
269    /// Get the time until the next reset.
270    #[allow(clippy::significant_drop_tightening)]
271    fn time_until_reset(&self) -> std::time::Duration {
272        let last_reset = self
273            .last_reset
274            .lock()
275            .unwrap_or_else(std::sync::PoisonError::into_inner);
276        let elapsed = last_reset.elapsed();
277        let interval = std::time::Duration::from_millis(self.interval_ms);
278
279        if elapsed >= interval {
280            std::time::Duration::ZERO
281        } else {
282            interval.checked_sub(elapsed).unwrap()
283        }
284    }
285}
286
/// A token bucket for rate limiting with bursts.
///
/// The bucket starts full with `capacity` tokens. Operations remove tokens,
/// and elapsed time adds them back at `refill_rate` tokens per second,
/// never exceeding `capacity`.
#[derive(Debug)]
pub struct TokenBucket {
    /// Maximum tokens in the bucket (also the largest possible burst).
    capacity: usize,
    /// Tokens currently available.
    tokens: AtomicUsize,
    /// Token refill rate (tokens per second).
    refill_rate: f64,
    /// Reference point from which the next refill is computed.
    last_refill: std::sync::Mutex<std::time::Instant>,
}
299
300impl TokenBucket {
301    /// Create a new token bucket.
302    #[must_use]
303    pub fn new(capacity: usize, refill_rate: f64) -> Self {
304        Self {
305            capacity,
306            tokens: AtomicUsize::new(capacity),
307            refill_rate,
308            last_refill: std::sync::Mutex::new(std::time::Instant::now()),
309        }
310    }
311
312    /// Try to consume tokens.
313    pub fn try_consume(&self, count: usize) -> bool {
314        self.refill();
315
316        loop {
317            let current = self.tokens.load(Ordering::Acquire);
318            if current < count {
319                return false;
320            }
321
322            if self
323                .tokens
324                .compare_exchange(
325                    current,
326                    current - count,
327                    Ordering::AcqRel,
328                    Ordering::Acquire,
329                )
330                .is_ok()
331            {
332                return true;
333            }
334        }
335    }
336
337    /// Consume tokens, waiting if necessary.
338    pub async fn consume(&self, count: usize) {
339        while !self.try_consume(count) {
340            let wait_time = self.time_for_tokens(count);
341            tokio::time::sleep(wait_time).await;
342        }
343    }
344
345    /// Refill tokens based on elapsed time.
346    fn refill(&self) {
347        let mut last_refill = self
348            .last_refill
349            .lock()
350            .unwrap_or_else(std::sync::PoisonError::into_inner);
351        let elapsed = last_refill.elapsed().as_secs_f64();
352        let new_tokens = (elapsed * self.refill_rate) as usize;
353
354        if new_tokens > 0 {
355            let current = self.tokens.load(Ordering::Acquire);
356            let new_value = (current + new_tokens).min(self.capacity);
357            self.tokens.store(new_value, Ordering::Release);
358            *last_refill = std::time::Instant::now();
359        }
360    }
361
362    /// Get the time needed to have the specified number of tokens.
363    fn time_for_tokens(&self, count: usize) -> std::time::Duration {
364        let current = self.tokens.load(Ordering::Acquire);
365        if current >= count {
366            return std::time::Duration::ZERO;
367        }
368
369        let needed = count - current;
370        let seconds = needed as f64 / self.refill_rate;
371        std::time::Duration::from_secs_f64(seconds)
372    }
373
374    /// Get the current token count.
375    #[must_use]
376    pub fn tokens(&self) -> usize {
377        self.refill();
378        self.tokens.load(Ordering::Acquire)
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backpressure_acquire() {
        let bp = Backpressure::new(100);

        // Two half-size reservations exactly fill the buffer...
        for _ in 0..2 {
            assert!(bp.try_acquire(50));
        }
        // ...so even a single extra unit is rejected.
        assert!(!bp.try_acquire(1));

        // Handing back half the space makes room for another half-size chunk.
        bp.release(50);
        assert!(bp.try_acquire(50));
    }

    #[test]
    fn backpressure_permits() {
        let bp = Backpressure::new(3);

        // Take every permit and keep them alive in a vector.
        let mut held: Vec<_> = (0..3)
            .map(|_| bp.try_acquire_permit().unwrap())
            .collect();

        // Capacity is exhausted.
        assert!(bp.try_acquire_permit().is_err());
        assert_eq!(bp.available_permits(), 0);

        // Dropping one permit hands it back to the semaphore...
        drop(held.remove(0));
        assert_eq!(bp.available_permits(), 1);

        // ...which makes a fresh acquisition possible again.
        let _again = bp.try_acquire_permit().unwrap();

        // Release the remaining originals explicitly.
        drop(held);
    }

    #[tokio::test]
    async fn backpressure_async_permit() {
        let bp = Backpressure::new(2);

        let first = bp.acquire_permit().await;
        let second = bp.acquire_permit().await;
        assert_eq!(bp.available_permits(), 0);

        // Each drop returns exactly one permit.
        for (permit, expected_free) in [(first, 1), (second, 2)] {
            drop(permit);
            assert_eq!(bp.available_permits(), expected_free);
        }
    }

    #[test]
    fn rate_limiter_basic() {
        let limiter = RateLimiter::new(5, std::time::Duration::from_secs(1));

        // The whole budget for the window is available up front.
        assert!((0..5).all(|_| limiter.try_acquire()));
        // The sixth request inside the same window is refused.
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn token_bucket_basic() {
        let bucket = TokenBucket::new(10, 5.0);

        // Drain the initial capacity in two equal bites.
        for _ in 0..2 {
            assert!(bucket.try_consume(5));
        }
        // Nothing is left; far too little time has passed to refill a token.
        assert!(!bucket.try_consume(1));
    }
}