1use std::{
2 collections::LinkedList,
3 future::Future,
4 pin::Pin,
5 sync::{
6 Arc,
7 atomic::{AtomicBool, Ordering},
8 },
9 task::{Context, Poll, Waker},
10};
11
12use parking_lot::Mutex;
13
14struct NotifyOnceInner {
15 loaded: AtomicBool,
16 wakers: Mutex<LinkedList<Waker>>,
17}
18
19#[derive(Clone)]
49pub struct NotifyOnce(Arc<NotifyOnceInner>);
50
51impl NotifyOnce {
52 pub fn new() -> Self {
53 Self(Arc::new(NotifyOnceInner {
54 loaded: AtomicBool::new(false),
55 wakers: Mutex::new(LinkedList::new()),
56 }))
57 }
58
59 #[inline]
60 pub fn done(&self) {
61 let _self = self.0.as_ref();
62 _self.loaded.store(true, Ordering::Release);
63 {
64 let mut guard = _self.wakers.lock();
65 while let Some(waker) = guard.pop_front() {
66 waker.wake();
67 }
68 }
69 }
70
71 #[inline]
72 pub async fn wait(&self) {
73 NotifyOnceWaitFuture { inner: self.0.as_ref(), is_new: true }.await;
74 }
75}
76
77struct NotifyOnceWaitFuture<'a> {
78 inner: &'a NotifyOnceInner,
79 is_new: bool,
80}
81
82impl<'a> Future for NotifyOnceWaitFuture<'a> {
83 type Output = ();
84
85 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
86 let _self = self.get_mut();
87 if _self.inner.loaded.load(Ordering::Acquire) {
88 return Poll::Ready(());
89 }
90 if _self.is_new {
91 {
92 let mut guard = _self.inner.wakers.lock();
93 guard.push_back(ctx.waker().clone());
94 }
95 _self.is_new = false;
96 if _self.inner.loaded.load(Ordering::Acquire) {
97 return Poll::Ready(());
98 }
99 }
100 Poll::Pending
101 }
102}
103
#[cfg(test)]
mod tests {

    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use tokio::time::{Duration, sleep};

    use super::*;

    /// Ten tasks wait on one notifier: none may get past `wait()` before
    /// `done()` is called, and all ten must get past it afterwards.
    #[test]
    fn test_notify_once() {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();

        runtime.block_on(async move {
            let notifier = NotifyOnce::new();
            // Flipped right before `done()` so each task can check ordering.
            let released = Arc::new(AtomicBool::new(false));
            // Number of tasks that have returned from `wait()`.
            let woken = Arc::new(AtomicUsize::new(0));

            let handles: Vec<_> = (0..10)
                .map(|_| {
                    let notifier = notifier.clone();
                    let released = released.clone();
                    let woken = woken.clone();
                    tokio::spawn(async move {
                        assert_eq!(released.load(Ordering::Acquire), false);
                        notifier.wait().await;
                        woken.fetch_add(1, Ordering::SeqCst);
                        assert_eq!(released.load(Ordering::Acquire), true);
                    })
                })
                .collect();

            // Give every task time to reach `wait()`; none may have passed it.
            sleep(Duration::from_secs(1)).await;
            assert_eq!(woken.load(Ordering::Acquire), 0);

            released.store(true, Ordering::Release);
            notifier.done();

            for handle in handles {
                let _ = handle.await.expect("");
            }
            assert_eq!(woken.load(Ordering::Acquire), 10);
        });
    }
}