sync_utils/
notifier.rs

1use std::{
2    collections::LinkedList,
3    future::Future,
4    pin::Pin,
5    sync::{
6        Arc,
7        atomic::{AtomicBool, Ordering},
8    },
9    task::{Context, Poll, Waker},
10};
11
12use parking_lot::Mutex;
13
14struct NotifyOnceInner {
15    loaded: AtomicBool,
16    wakers: Mutex<LinkedList<Waker>>,
17}
18
19/// NotifyOnce Assumes:
20///
21/// One coroutine issue some loading job, multiple coroutines wait for it to complete.
22///
23/// ## example:
24///
25/// ``` rust
26///
27/// async fn foo() {
28///     use sync_utils::notifier::NotifyOnce;
29///     use tokio::time::*;
30///     use std::sync::{Arc, atomic::{AtomicBool, AtomicUsize, Ordering}};
31///     let noti = NotifyOnce::new();
32///     let done = Arc::new(AtomicBool::new(false));
33///     for _ in 0..10 {
34///         let _noti = noti.clone();
35///         let _done = done.clone();
36///         tokio::spawn(async move {
37///             assert_eq!(_done.load(Ordering::Acquire), false);
38///             _noti.wait().await;
39///             assert_eq!(_done.load(Ordering::Acquire), true);
40///         });
41///     }
42///     sleep(Duration::from_secs(1)).await;
43///     done.store(true, Ordering::Release);
44///     noti.done();
45/// }
46/// ```
47
48#[derive(Clone)]
49pub struct NotifyOnce(Arc<NotifyOnceInner>);
50
51impl NotifyOnce {
52    pub fn new() -> Self {
53        Self(Arc::new(NotifyOnceInner {
54            loaded: AtomicBool::new(false),
55            wakers: Mutex::new(LinkedList::new()),
56        }))
57    }
58
59    #[inline]
60    pub fn done(&self) {
61        let _self = self.0.as_ref();
62        _self.loaded.store(true, Ordering::Release);
63        {
64            let mut guard = _self.wakers.lock();
65            while let Some(waker) = guard.pop_front() {
66                waker.wake();
67            }
68        }
69    }
70
71    #[inline]
72    pub async fn wait(&self) {
73        NotifyOnceWaitFuture { inner: self.0.as_ref(), is_new: true }.await;
74    }
75}
76
77struct NotifyOnceWaitFuture<'a> {
78    inner: &'a NotifyOnceInner,
79    is_new: bool,
80}
81
82impl<'a> Future for NotifyOnceWaitFuture<'a> {
83    type Output = ();
84
85    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
86        let _self = self.get_mut();
87        if _self.inner.loaded.load(Ordering::Acquire) {
88            return Poll::Ready(());
89        }
90        if _self.is_new {
91            {
92                let mut guard = _self.inner.wakers.lock();
93                guard.push_back(ctx.waker().clone());
94            }
95            _self.is_new = false;
96            if _self.inner.loaded.load(Ordering::Acquire) {
97                return Poll::Ready(());
98            }
99        }
100        Poll::Pending
101    }
102}
103
#[cfg(test)]
mod tests {

    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use tokio::time::{Duration, sleep};

    use super::*;

    /// Builds the small multi-threaded runtime shared by these tests.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .expect("failed to build tokio runtime")
    }

    /// Ten tasks block in `wait()` until `done()` is called, then all resume.
    #[test]
    fn test_notify_once() {
        runtime().block_on(async move {
            let noti = NotifyOnce::new();
            let done = Arc::new(AtomicBool::new(false));
            let wait_count = Arc::new(AtomicUsize::new(0));
            let handles: Vec<_> = (0..10)
                .map(|_| {
                    let noti = noti.clone();
                    let done = done.clone();
                    let wait_count = wait_count.clone();
                    tokio::spawn(async move {
                        assert!(!done.load(Ordering::Acquire));
                        noti.wait().await;
                        wait_count.fetch_add(1, Ordering::SeqCst);
                        assert!(done.load(Ordering::Acquire));
                    })
                })
                .collect();
            // Give every waiter time to park before signalling.
            sleep(Duration::from_secs(1)).await;
            assert_eq!(wait_count.load(Ordering::Acquire), 0);
            done.store(true, Ordering::Release);
            noti.done();
            for handle in handles {
                handle.await.expect("waiter task panicked");
            }
            assert_eq!(wait_count.load(Ordering::Acquire), 10);
        });
    }

    /// `wait()` after `done()` completes immediately on any clone, and calling
    /// `done()` again is harmless.
    #[test]
    fn test_wait_after_done() {
        runtime().block_on(async {
            let noti = NotifyOnce::new();
            let other = noti.clone();
            noti.done();
            other.wait().await;
            noti.done();
            noti.wait().await;
        });
    }
}
150}