semrs/
lib.rs

use std::sync::{Condvar, Mutex, MutexGuard, PoisonError};
2
/// A counting semaphore guarding a pool of interchangeable resources.
///
/// [`down`](Semaphore::down) takes one resource, blocking while none are
/// available; [`up`](Semaphore::up) returns one and wakes a single waiter.
#[derive(Debug)]
pub struct Semaphore {
    /// Notified by `up` each time a resource is returned.
    barrier: Condvar,
    /// Number of resources currently available.
    lock: Mutex<u32>,
}

impl Semaphore {
    /// Create a new semaphore representing `count` resources.
    pub fn new(count: u32) -> Semaphore {
        Semaphore {
            barrier: Condvar::new(),
            lock: Mutex::new(count),
        }
    }

    /// Consumes one resource from the semaphore.
    /// If no resources are available, blocks the current thread until
    /// one becomes present.
    pub fn down(&self) {
        let guard = self.lock_count();
        // `wait_while` re-checks the predicate after every (possibly spurious)
        // wakeup and returns the still-held guard, so the decrement happens in
        // the same critical section that observed a non-zero count.
        let mut count = self
            .barrier
            .wait_while(guard, |available| *available == 0)
            .unwrap_or_else(PoisonError::into_inner);
        *count -= 1;
    }

    /// Provides one resource back to the semaphore.
    ///
    /// # Panics
    ///
    /// Panics if the number of available resources would exceed `u32::MAX`.
    /// The count is left unchanged in that case and the semaphore remains
    /// usable afterwards.
    pub fn up(&self) {
        {
            let mut count = self.lock_count();
            // Check before writing so an overflow panic never modifies the
            // counter (a plain `+= 1` would silently wrap to 0 in release).
            let next = count
                .checked_add(1)
                .expect("Semaphore::up: resource count overflowed u32");
            *count = next;
        }
        // The guard is released above, so the woken thread can take the lock
        // without immediately blocking on us.
        self.barrier.notify_one();
    }

    /// Locks the resource counter, ignoring mutex poisoning.
    ///
    /// The counter is only written after every fallible step has succeeded,
    /// so a panic while the lock is held cannot leave it inconsistent.
    /// Recovering the guard keeps the semaphore usable after such a panic
    /// instead of turning every later call into a `PoisonError` panic.
    fn lock_count(&self) -> MutexGuard<'_, u32> {
        self.lock.lock().unwrap_or_else(PoisonError::into_inner)
    }
}
39
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;

    use super::*;

    /// Runs `threads` workers that each pass through a semaphore created with
    /// `permits` resources `rounds` times, and returns the largest number of
    /// workers observed inside the guarded section at the same moment.
    ///
    /// Also asserts that every worker finished and every round was counted.
    fn max_concurrency(permits: u32, threads: usize, rounds: usize) -> usize {
        let sem = Arc::new(Semaphore::new(permits));
        let inside = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let entries = Arc::new(AtomicUsize::new(0));

        let workers: Vec<_> = (0..threads)
            .map(|_| {
                let sem = Arc::clone(&sem);
                let inside = Arc::clone(&inside);
                let peak = Arc::clone(&peak);
                let entries = Arc::clone(&entries);
                thread::spawn(move || {
                    for _ in 0..rounds {
                        sem.down();
                        let now = inside.fetch_add(1, Ordering::SeqCst) + 1;
                        peak.fetch_max(now, Ordering::SeqCst);
                        entries.fetch_add(1, Ordering::SeqCst);
                        // Yield while holding the resource so that other
                        // workers get a chance to (incorrectly) slip in.
                        thread::yield_now();
                        inside.fetch_sub(1, Ordering::SeqCst);
                        sem.up();
                    }
                })
            })
            .collect();

        for worker in workers {
            // Propagate worker panics instead of silently discarding them.
            worker.join().expect("worker thread panicked");
        }
        assert_eq!(entries.load(Ordering::SeqCst), threads * rounds);
        peak.load(Ordering::SeqCst)
    }

    #[test]
    fn test_single_thread_bin_semaphore() {
        let s = Semaphore::new(1);
        s.down();
        s.up();
    }

    #[test]
    fn test_two_thread_bin_semaphore() {
        // A binary semaphore must never admit more than one thread at a time.
        assert_eq!(max_concurrency(1, 2, 1000), 1);
    }

    #[test]
    fn test_counting_semaphore_limits_concurrency() {
        let peak = max_concurrency(3, 8, 200);
        assert!((1..=3).contains(&peak), "peak concurrency was {peak}");
    }

    #[test]
    fn test_down_blocks_until_up() {
        let sem = Arc::new(Semaphore::new(0));
        let acquired = Arc::new(AtomicBool::new(false));
        let waiter = {
            let sem = Arc::clone(&sem);
            let acquired = Arc::clone(&acquired);
            thread::spawn(move || {
                sem.down();
                acquired.store(true, Ordering::SeqCst);
            })
        };
        // With zero resources, `down` cannot have returned before `up`.
        assert!(!acquired.load(Ordering::SeqCst));
        sem.up();
        waiter.join().expect("waiter thread panicked");
        assert!(acquired.load(Ordering::SeqCst));
    }
}