use std::{
    collections::{HashMap, LinkedList},
    future::Future,
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
    task::{Context, Poll, Waker},
};
11
12use parking_lot::Mutex;
13
14struct NotifyOnceInner {
15 loaded: AtomicBool,
16 wakers: Mutex<LinkedList<Waker>>,
17}
18
19#[derive(Clone)]
24pub struct NotifyOnce(Arc<NotifyOnceInner>);
25
26impl NotifyOnce {
27 pub fn new() -> Self {
28 Self(Arc::new(NotifyOnceInner {
29 loaded: AtomicBool::new(false),
30 wakers: Mutex::new(LinkedList::new()),
31 }))
32 }
33
34 #[inline]
35 pub fn done(&self) {
36 let _self = self.0.as_ref();
37 _self.loaded.store(true, Ordering::Release);
38 {
39 let mut guard = _self.wakers.lock();
40 while let Some(waker) = guard.pop_front() {
41 waker.wake();
42 }
43 }
44 }
45
46 #[inline]
47 pub async fn wait(&self) {
48 NotifyOnceWaitFuture {
49 inner: self.0.as_ref(),
50 is_new: true,
51 }
52 .await;
53 }
54}
55
56struct NotifyOnceWaitFuture<'a> {
57 inner: &'a NotifyOnceInner,
58 is_new: bool,
59}
60
61impl<'a> Future for NotifyOnceWaitFuture<'a> {
62 type Output = ();
63
64 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
65 let _self = self.get_mut();
66 if _self.inner.loaded.load(Ordering::Acquire) {
67 return Poll::Ready(());
68 }
69 if _self.is_new {
70 {
71 let mut guard = _self.inner.wakers.lock();
72 guard.push_back(ctx.waker().clone());
73 }
74 _self.is_new = false;
75 if _self.inner.loaded.load(Ordering::Acquire) {
76 return Poll::Ready(());
77 }
78 }
79 Poll::Pending
80 }
81}
82
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use tokio::time::{Duration, sleep, timeout};

    use super::*;

    /// Builds a two-worker runtime so waiters are parked and woken across threads.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .expect("failed to build tokio runtime")
    }

    /// Every waiter parked before `done` is released by it, and none finishes earlier.
    #[test]
    fn test_notify_once() {
        runtime().block_on(async move {
            let noti = NotifyOnce::new();
            let done = Arc::new(AtomicUsize::new(0));
            let ths: Vec<_> = (0..10)
                .map(|_| {
                    let noti = noti.clone();
                    let done = Arc::clone(&done);
                    tokio::spawn(async move {
                        noti.wait().await;
                        done.fetch_add(1, Ordering::SeqCst);
                    })
                })
                .collect();
            // Give the waiters a chance to park; none may finish before `done` is called.
            sleep(Duration::from_millis(100)).await;
            assert_eq!(done.load(Ordering::Acquire), 0);
            noti.done();
            for th in ths {
                th.await.expect("waiter task panicked");
            }
            assert_eq!(done.load(Ordering::Acquire), 10);
        });
    }

    /// Waiting after the notification fired completes immediately on any clone, and a
    /// repeated `done` is harmless.
    #[test]
    fn test_wait_after_done() {
        runtime().block_on(async {
            let noti = NotifyOnce::new();
            noti.done();
            let clone = noti.clone();
            timeout(Duration::from_secs(5), async {
                noti.wait().await;
                clone.wait().await;
                noti.done();
                noti.wait().await;
            })
            .await
            .expect("wait() did not complete after done()");
        });
    }

    /// A waiter cancelled while parked must not prevent later waiters from being woken.
    #[test]
    fn test_cancelled_wait() {
        runtime().block_on(async {
            let noti = NotifyOnce::new();
            let cancelled = timeout(Duration::from_millis(20), noti.wait()).await;
            assert!(cancelled.is_err(), "wait() completed before done() was called");
            let waiter = {
                let noti = noti.clone();
                tokio::spawn(async move { noti.wait().await })
            };
            noti.done();
            timeout(Duration::from_secs(5), waiter)
                .await
                .expect("waiter was not woken by done()")
                .expect("waiter task panicked");
        });
    }
}