sync_utils/
notifier.rs

1use std::{
2    collections::LinkedList,
3    future::Future,
4    pin::Pin,
5    sync::{
6        Arc,
7        atomic::{AtomicBool, Ordering},
8    },
9    task::{Context, Poll, Waker},
10};
11
12use parking_lot::Mutex;
13
14struct NotifyOnceInner {
15    loaded: AtomicBool,
16    wakers: Mutex<LinkedList<Waker>>,
17}
18
19/// NotifyOnce Assumes:
20///
21/// One coroutine issue some job, multiple coroutines wait for it to complete.
22///
23#[derive(Clone)]
24pub struct NotifyOnce(Arc<NotifyOnceInner>);
25
26impl NotifyOnce {
27    pub fn new() -> Self {
28        Self(Arc::new(NotifyOnceInner {
29            loaded: AtomicBool::new(false),
30            wakers: Mutex::new(LinkedList::new()),
31        }))
32    }
33
34    #[inline]
35    pub fn done(&self) {
36        let _self = self.0.as_ref();
37        _self.loaded.store(true, Ordering::Release);
38        {
39            let mut guard = _self.wakers.lock();
40            while let Some(waker) = guard.pop_front() {
41                waker.wake();
42            }
43        }
44    }
45
46    #[inline]
47    pub async fn wait(&self) {
48        NotifyOnceWaitFuture {
49            inner: self.0.as_ref(),
50            is_new: true,
51        }
52        .await;
53    }
54}
55
56struct NotifyOnceWaitFuture<'a> {
57    inner: &'a NotifyOnceInner,
58    is_new: bool,
59}
60
61impl<'a> Future for NotifyOnceWaitFuture<'a> {
62    type Output = ();
63
64    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
65        let _self = self.get_mut();
66        if _self.inner.loaded.load(Ordering::Acquire) {
67            return Poll::Ready(());
68        }
69        if _self.is_new {
70            {
71                let mut guard = _self.inner.wakers.lock();
72                guard.push_back(ctx.waker().clone());
73            }
74            _self.is_new = false;
75            if _self.inner.loaded.load(Ordering::Acquire) {
76                return Poll::Ready(());
77            }
78        }
79        Poll::Pending
80    }
81}
82
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use tokio::time::{Duration, sleep, timeout};

    use super::*;

    /// Small multi-threaded runtime so waiters genuinely run on other threads.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .expect("failed to build tokio runtime")
    }

    #[test]
    fn test_notify_once() {
        runtime().block_on(async move {
            let noti = NotifyOnce::new();
            let done = Arc::new(AtomicUsize::new(0));
            let waiters: Vec<_> = (0..10)
                .map(|_| {
                    let noti = noti.clone();
                    let done = Arc::clone(&done);
                    tokio::spawn(async move {
                        noti.wait().await;
                        done.fetch_add(1, Ordering::SeqCst);
                    })
                })
                .collect();

            // Give the waiters a chance to park; none may finish before `done()`.
            // A waiter that has not been polled yet is still correct: it will
            // observe the flag on its first poll.
            sleep(Duration::from_millis(100)).await;
            assert_eq!(done.load(Ordering::Acquire), 0);

            noti.done();
            for waiter in waiters {
                waiter.await.expect("waiter task panicked");
            }
            assert_eq!(done.load(Ordering::Acquire), 10);
        });
    }

    #[test]
    fn wait_after_done_completes_immediately() {
        runtime().block_on(async {
            let noti = NotifyOnce::new();
            noti.done();
            timeout(Duration::from_secs(1), noti.wait())
                .await
                .expect("wait() did not complete after done()");
        });
    }
}
124}