1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3use tokio::sync::Notify;
4
5struct Inner {
6 count: AtomicUsize,
7 notify: Notify,
8}
9
10pub struct WaitCounter {
11 inner: Arc<Inner>,
12}
13
14impl WaitCounter {
15 pub fn new() -> Self {
16 let inner = Inner {
17 count: AtomicUsize::new(1),
18 notify: Notify::new(),
19 };
20 Self {
21 inner: Arc::new(inner),
22 }
23 }
24
25 pub fn wake_clone(&self) -> Self {
26 Self {
27 inner: self.inner.clone(),
28 }
29 }
30
31 pub async fn wait(&self) {
32 loop {
33 let current = self.inner.count.load(Ordering::Acquire);
35 if current == 1 {
36 break;
37 }
38 self.inner.notify.notified().await;
40 }
41 }
42}
43
44impl Clone for WaitCounter {
45 fn clone(&self) -> Self {
46 self.inner.count.fetch_add(1, Ordering::Relaxed);
48 Self {
49 inner: self.inner.clone(),
50 }
51 }
52}
53
54impl Drop for WaitCounter {
55 fn drop(&mut self) {
56 let prev = self.inner.count.fetch_sub(1, Ordering::Release);
58 if prev == 2 {
60 self.inner.notify.notify_waiters();
61 }
62 }
63}
#[cfg(test)]
mod test {
    // `super::` resolves wherever this module lives; `crate::` only worked when
    // this file is the crate root.
    use super::WaitCounter;
    use std::time::Duration;

    /// Upper bound on any single wait, so a lost wakeup fails the test
    /// instead of hanging the test run.
    const WAIT_LIMIT: Duration = Duration::from_secs(5);

    /// `wait` completes once the only clone is dropped from another task.
    #[tokio::test]
    async fn test_wait_counter() {
        let counter = WaitCounter::new();
        let cloned = counter.clone();

        let dropper = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            drop(cloned);
        });

        tokio::time::timeout(WAIT_LIMIT, counter.wait())
            .await
            .expect("wait() did not observe the clone being dropped");
        dropper.await.expect("dropper task panicked");
    }

    /// `wait` returns immediately when no other handle exists.
    #[tokio::test]
    async fn wait_without_clones_returns_immediately() {
        let counter = WaitCounter::new();
        tokio::time::timeout(WAIT_LIMIT, counter.wait())
            .await
            .expect("wait() blocked with no outstanding clones");
    }
}