Skip to main content

worktrunk/
sync.rs

1//! Synchronization primitives for worktrunk.
2
3use std::sync::{Arc, Condvar, Mutex};
4
/// A counting semaphore for limiting concurrency.
///
/// Used to prevent resource exhaustion when many parallel operations need
/// to run. Provides RAII-based permit management through [`SemaphoreGuard`].
///
/// Cloning is cheap: the state lives behind an [`Arc`], so every clone draws
/// from the same pool of permits instead of getting a pool of its own.
#[derive(Clone, Debug)]
pub struct Semaphore {
    /// Count of permits currently available, paired with the condition
    /// variable that blocked acquirers wait on.
    state: Arc<(Mutex<usize>, Condvar)>,
}
13
/// RAII guard that releases a semaphore permit on drop.
///
/// Created by [`Semaphore::acquire`]. The permit is automatically released
/// when this guard is dropped, even if the code panics.
///
/// Bind the guard to a named variable (`let _guard = ...`). Discarding it, or
/// binding it to `_`, drops it immediately and hands the permit straight back.
#[derive(Debug)]
#[must_use = "the permit is released as soon as the guard is dropped"]
pub struct SemaphoreGuard {
    /// Shared state of the semaphore this permit was taken from.
    state: Arc<(Mutex<usize>, Condvar)>,
}
21
22impl Semaphore {
23    /// Create a new semaphore with the given number of permits.
24    pub fn new(permits: usize) -> Self {
25        Self {
26            state: Arc::new((Mutex::new(permits), Condvar::new())),
27        }
28    }
29
30    /// Acquire a permit, blocking until one is available.
31    ///
32    /// Returns a guard that releases the permit when dropped.
33    pub fn acquire(&self) -> SemaphoreGuard {
34        let (lock, cvar) = &*self.state;
35        let mut available = lock.lock().unwrap();
36
37        // Wait until a permit is available
38        while *available == 0 {
39            available = cvar.wait(available).unwrap();
40        }
41
42        // Take a permit
43        *available -= 1;
44
45        SemaphoreGuard {
46            state: Arc::clone(&self.state),
47        }
48    }
49}
50
51impl Drop for SemaphoreGuard {
52    fn drop(&mut self) {
53        let (lock, cvar) = &*self.state;
54        let mut available = lock.lock().unwrap();
55        *available += 1;
56        cvar.notify_one();
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration;

    /// Ten workers contend for two permits; the high-water mark of workers
    /// inside the guarded section must never exceed the permit count.
    #[test]
    fn test_semaphore_limits_concurrency() {
        let sem = Semaphore::new(2);
        let active = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let sem = sem.clone();
                let active = Arc::clone(&active);
                let peak = Arc::clone(&peak);

                thread::spawn(move || {
                    let _permit = sem.acquire();

                    // Entering the guarded section: count ourselves in and
                    // record the highest occupancy seen so far.
                    let now_active = active.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now_active, Ordering::SeqCst);

                    // Hold the permit long enough for other workers to pile up.
                    thread::sleep(Duration::from_millis(10));

                    // Leaving the guarded section.
                    active.fetch_sub(1, Ordering::SeqCst);
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // At most two workers may ever hold a permit at the same time.
        assert!(peak.load(Ordering::SeqCst) <= 2);
    }
}