// sync_utils/notifier.rs

use std::{
    collections::{HashMap, LinkedList},
    future::Future,
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
    task::{Context, Poll, Waker},
};

use parking_lot::Mutex;

14struct NotifyOnceInner {
15    loaded: AtomicBool,
16    wakers: Mutex<LinkedList<Waker>>,
17}
18
19/// NotifyOnce Assumes:
20///
21/// One coroutine issue some job, multiple coroutines wait for it to complete.
22///
23#[derive(Clone)]
24pub struct NotifyOnce(Arc<NotifyOnceInner>);
25
26impl NotifyOnce {
27    pub fn new() -> Self {
28        Self(Arc::new(NotifyOnceInner {
29            loaded: AtomicBool::new(false),
30            wakers: Mutex::new(LinkedList::new()),
31        }))
32    }
33
34    #[inline]
35    pub fn done(&self) {
36        let _self = self.0.as_ref();
37        _self.loaded.store(true, Ordering::Release);
38        {
39            let mut guard = _self.wakers.lock();
40            while let Some(waker) = guard.pop_front() {
41                waker.wake();
42            }
43        }
44    }
45
46    #[inline]
47    pub async fn wait(&self) {
48        NotifyOnceWaitFuture { inner: self.0.as_ref(), is_new: true }.await;
49    }
50}
51
52struct NotifyOnceWaitFuture<'a> {
53    inner: &'a NotifyOnceInner,
54    is_new: bool,
55}
56
57impl<'a> Future for NotifyOnceWaitFuture<'a> {
58    type Output = ();
59
60    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
61        let _self = self.get_mut();
62        if _self.inner.loaded.load(Ordering::Acquire) {
63            return Poll::Ready(());
64        }
65        if _self.is_new {
66            {
67                let mut guard = _self.inner.wakers.lock();
68                guard.push_back(ctx.waker().clone());
69            }
70            _self.is_new = false;
71            if _self.inner.loaded.load(Ordering::Acquire) {
72                return Poll::Ready(());
73            }
74        }
75        Poll::Pending
76    }
77}
78
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use tokio::time::{Duration, sleep, timeout};

    use super::*;

    /// Builds the two-worker multi-threaded runtime the tests run on.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .expect("failed to build tokio runtime")
    }

    /// Ten tasks stay blocked in `wait` until `done` is called, then all resume.
    #[test]
    fn test_notify_once() {
        runtime().block_on(async move {
            let noti = NotifyOnce::new();
            let finished = Arc::new(AtomicUsize::new(0));
            let waiters: Vec<_> = (0..10)
                .map(|_| {
                    let noti = noti.clone();
                    let finished = Arc::clone(&finished);
                    tokio::spawn(async move {
                        noti.wait().await;
                        finished.fetch_add(1, Ordering::SeqCst);
                    })
                })
                .collect();
            // Give the waiters time to park; none may finish before `done`.
            sleep(Duration::from_secs(1)).await;
            assert_eq!(finished.load(Ordering::Acquire), 0);
            noti.done();
            for waiter in waiters {
                waiter.await.expect("waiter task panicked");
            }
            assert_eq!(finished.load(Ordering::Acquire), 10);
        });
    }

    /// `wait` on an already-signalled notifier completes without blocking.
    #[test]
    fn test_wait_after_done() {
        runtime().block_on(async {
            let noti = NotifyOnce::new();
            noti.done();
            timeout(Duration::from_secs(1), noti.wait())
                .await
                .expect("wait after done should not block");
        });
    }

    /// Waits abandoned by a timeout do not stop a later waiter from being woken.
    #[test]
    fn test_cancelled_wait_then_done() {
        runtime().block_on(async {
            let noti = NotifyOnce::new();
            for _ in 0..3 {
                assert!(timeout(Duration::from_millis(10), noti.wait()).await.is_err());
            }
            let waiter = {
                let noti = noti.clone();
                tokio::spawn(async move { noti.wait().await })
            };
            // Let the waiter park (correctness does not depend on it).
            sleep(Duration::from_millis(50)).await;
            noti.done();
            timeout(Duration::from_secs(1), waiter)
                .await
                .expect("waiter was not woken")
                .expect("waiter task panicked");
        });
    }
}