// wait_counter/lib.rs

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use tokio::sync::Notify;

5struct Inner {
6    count: AtomicUsize,
7    notify: Notify,
8}
9
10pub struct WaitCounter {
11    inner: Arc<Inner>,
12}
13
14impl WaitCounter {
15    pub fn new() -> Self {
16        let inner = Inner {
17            count: AtomicUsize::new(1),
18            notify: Notify::new(),
19        };
20        Self {
21            inner: Arc::new(inner),
22        }
23    }
24
25    pub fn wake_clone(&self) -> Self {
26        Self {
27            inner: self.inner.clone(),
28        }
29    }
30
31    pub async fn wait(&self) {
32        loop {
33            // Use Acquire ordering to ensure the latest write is seen
34            let current = self.inner.count.load(Ordering::Acquire);
35            if current == 1 {
36                break;
37            }
38            // Waiting for notification, may cause false wakeup, so loop checking is required
39            self.inner.notify.notified().await;
40        }
41    }
42}
43
44impl Clone for WaitCounter {
45    fn clone(&self) -> Self {
46        // Use Relaxed ordering because no other memory synchronization is required here
47        self.inner.count.fetch_add(1, Ordering::Relaxed);
48        Self {
49            inner: self.inner.clone(),
50        }
51    }
52}
53
54impl Drop for WaitCounter {
55    fn drop(&mut self) {
56        // Use Release ordering to ensure previous writes complete before decrementing count
57        let prev = self.inner.count.fetch_sub(1, Ordering::Release);
58        // When the previous value is 2, it decreases to 1, triggering a notification
59        if prev == 2 {
60            self.inner.notify.notify_waiters();
61        }
62    }
63}
#[cfg(test)]
mod test {
    use crate::WaitCounter;
    use std::time::Duration;

    /// Upper bound for any `wait` that is expected to finish. A lost wakeup or
    /// a broken count then fails the test instead of hanging the test run.
    const TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_wait_counter() {
        let counter = WaitCounter::new();
        let cloned = counter.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            drop(cloned);
        });

        tokio::time::timeout(TIMEOUT, counter.wait())
            .await
            .expect("wait() did not return after the clone was dropped");
    }

    #[tokio::test]
    async fn wait_returns_immediately_without_clones() {
        let counter = WaitCounter::new();

        tokio::time::timeout(TIMEOUT, counter.wait())
            .await
            .expect("wait() should return when no clones are outstanding");
    }

    #[tokio::test]
    async fn wait_blocks_until_every_clone_is_dropped() {
        let counter = WaitCounter::new();
        let clones: Vec<WaitCounter> = (0..4).map(|_| counter.clone()).collect();

        // While clones are alive, wait() must not complete.
        assert!(
            tokio::time::timeout(Duration::from_millis(20), counter.wait())
                .await
                .is_err(),
            "wait() returned while clones were still alive"
        );

        // Drop the clones one after another from separate tasks.
        for (i, clone) in clones.into_iter().enumerate() {
            tokio::spawn(async move {
                tokio::time::sleep(Duration::from_millis(10 * (i as u64 + 1))).await;
                drop(clone);
            });
        }

        tokio::time::timeout(TIMEOUT, counter.wait())
            .await
            .expect("wait() did not return after every clone was dropped");
    }
}