1use crate::{
2 locks::{lock, Lock},
3 FillQueue,
4};
5use alloc::sync::{Arc, Weak};
6
7pub fn notify() -> (Notify, Listener) {
9 let inner = Arc::new(Inner {
10 wakers: FillQueue::new(),
11 });
12
13 let listener = Listener {
14 inner: Arc::downgrade(&inner),
15 };
16 return (Notify { inner }, listener);
17}
18
19#[derive(Debug)]
20struct Inner {
21 wakers: FillQueue<Lock>,
22}
23
24#[derive(Debug, Clone)]
30pub struct Notify {
31 inner: Arc<Inner>,
32}
33
34#[derive(Debug, Clone)]
35pub struct Listener {
36 inner: Weak<Inner>,
37}
38
39impl Notify {
40 pub unsafe fn into_raw(self) -> *const () {
41 Arc::into_raw(self.inner).cast()
42 }
43
44 pub unsafe fn from_raw(ptr: *const ()) -> Self {
45 Self {
46 inner: Arc::from_raw(ptr.cast()),
47 }
48 }
49
50 #[inline]
51 pub fn listeners(&self) -> usize {
52 return Arc::weak_count(&self.inner);
53 }
54
55 #[inline]
56 pub fn notify_all(&self) {
57 self.inner.wakers.chop().for_each(Lock::wake)
58 }
59
60 #[inline]
61 pub fn listen(&self) -> Listener {
62 return Listener {
63 inner: Arc::downgrade(&self.inner),
64 };
65 }
66
67 #[inline]
70 pub fn silent_drop(self) {
71 if let Ok(mut inner) = Arc::try_unwrap(self.inner) {
72 inner.wakers.chop_mut().for_each(Lock::silent_drop);
73 }
74 }
75}
76
77impl Listener {
78 pub unsafe fn into_raw(self) -> *const () {
79 Weak::into_raw(self.inner).cast()
80 }
81
82 pub unsafe fn from_raw(ptr: *const ()) -> Self {
83 Self {
84 inner: Weak::from_raw(ptr.cast()),
85 }
86 }
87
88 #[inline]
89 pub fn listeners(&self) -> usize {
90 return Weak::weak_count(&self.inner);
91 }
92
93 #[inline]
94 pub fn recv(&self) {
95 let _: bool = self.try_recv();
96 }
97
98 #[inline]
99 pub fn try_recv(&self) -> bool {
100 if let Some(inner) = self.inner.upgrade() {
101 let (lock, sub) = lock();
102 inner.wakers.push(lock);
103 sub.wait();
104 return true;
105 }
106 return false;
107 }
108}
109
110cfg_if::cfg_if! {
111 if #[cfg(feature = "futures")] {
112 use futures::{FutureExt, Stream};
113 use crate::flag::mpsc::{AsyncFlag, AsyncSubscribe, async_flag};
114 use core::task::Poll;
115 use futures::stream::FusedStream;
116
117 pub fn async_notify() -> (AsyncNotify, AsyncListener) {
119 let inner = Arc::new(AsyncInner {
120 wakers: FillQueue::new(),
121 });
122
123 let listener = AsyncListener {
124 inner: Some(Arc::downgrade(&inner)),
125 sub: None
126 };
127
128 return (AsyncNotify { inner }, listener);
129 }
130
131 #[derive(Debug)]
132 struct AsyncInner {
133 wakers: FillQueue<AsyncFlag>,
134 }
135
136 #[derive(Debug, Clone)]
142 pub struct AsyncNotify {
143 inner: Arc<AsyncInner>,
144 }
145
146 #[derive(Debug)]
147 pub struct AsyncListener {
148 inner: Option<Weak<AsyncInner>>,
149 sub: Option<AsyncSubscribe>
150 }
151
152 impl AsyncNotify {
153 pub unsafe fn into_raw(self) -> *const () {
154 Arc::into_raw(self.inner).cast()
155 }
156
157 pub unsafe fn from_raw(ptr: *const ()) -> Self {
158 Self {
159 inner: Arc::from_raw(ptr.cast()),
160 }
161 }
162
163 #[inline]
164 pub fn listeners(&self) -> usize {
165 return Arc::weak_count(&self.inner);
166 }
167
168 #[inline]
169 pub fn notify_all(&self) {
170 self.inner.wakers.chop().for_each(AsyncFlag::mark)
171 }
172
173 #[inline]
174 pub fn listen(&self) -> AsyncListener {
175 return AsyncListener {
176 inner: Some(Arc::downgrade(&self.inner)),
177 sub: None
178 };
179 }
180
181 #[inline]
184 pub fn silent_drop (self) {
185 if let Ok(mut inner) = Arc::try_unwrap(self.inner) {
186 inner.wakers.chop_mut().for_each(AsyncFlag::silent_drop);
187 }
188 }
189 }
190
191 impl AsyncListener {
192 #[inline]
193 pub fn listeners(&self) -> usize {
194 return match self.inner {
195 Some(ref inner) => Weak::weak_count(inner),
196 None => 0
197 }
198 }
199 }
200
201 impl Stream for AsyncListener {
202 type Item = ();
203
204 fn poll_next(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Option<Self::Item>> {
205 if let Some(ref mut sub) = self.sub {
206 return match sub.poll_unpin(cx) {
207 Poll::Ready(_) => {
208 self.sub = None;
209 Poll::Ready(Some(()))
210 },
211 Poll::Pending => Poll::Pending
212 }
213 } else if let Some(inner) = self.inner.as_ref().and_then(Weak::upgrade) {
214 let (flag, sub) = async_flag();
215 inner.wakers.push(flag);
216 self.sub = Some(sub);
217 return self.poll_next(cx)
218 }
219
220 self.inner = None;
221 return core::task::Poll::Ready(None)
222 }
223
224 #[inline]
225 fn size_hint(&self) -> (usize, Option<usize>) {
226 match (&self.inner, &self.sub) {
227 (None, None) => (0, Some(0)),
228 (Some(inner), None) if inner.upgrade().is_none() => (0, Some(0)),
229 (None, Some(_)) => (1, Some(1)),
230 (Some(inner), Some(_)) if inner.upgrade().is_none() => (1, Some(1)),
231 (Some(_), Some(_)) => (1, None),
232 _ => (0, None)
233 }
234 }
235 }
236
237 impl FusedStream for AsyncListener {
238 #[inline]
239 fn is_terminated(&self) -> bool {
240 match (&self.inner, &self.sub) {
241 (_, Some(_)) => false,
242 (None, None) => true,
243 (Some(inner), None) => inner.upgrade().is_none(),
244 }
245 }
246 }
247
248 impl Clone for AsyncListener {
249 #[inline]
250 fn clone(&self) -> Self {
251 return Self {
252 inner: self.inner.clone(),
253 sub: None
254 }
255 }
256 }
257 }
258}
259
// Tests for the blocking `Notify` / `Listener` pair; compiled only with the
// `std` feature because they spawn OS threads and sleep.
#[cfg(all(feature = "std", test))]
mod tests {
    use super::notify;
    use std::{
        thread::{self},
        time::Duration,
    };

    /// A thread blocked in `recv` on a second listener is released by
    /// `notify_all`, and that listener stops being counted once the thread ends.
    #[test]
    fn test_basic_functionality() {
        let (notify, listener) = notify();
        assert_eq!(notify.listeners(), 1);

        let listener2 = notify.listen();
        assert_eq!(notify.listeners(), 2);

        let handle = thread::spawn(move || {
            listener2.recv();
        });

        // NOTE(review): relies on the sleep to let the spawned thread register its
        // lock before `notify_all`; if it registers later, `join` may block.
        thread::sleep(Duration::from_millis(100));
        notify.notify_all();
        handle.join().unwrap();

        // `listener2` was moved into the thread and dropped when it finished.
        assert_eq!(notify.listeners(), 1);
        drop(listener);
    }

    /// Ten threads each block on a cloned listener; one `notify_all` must release
    /// all of them.
    #[test]
    fn test_multi_threaded() {
        use std::sync::{Arc, Barrier};
        use std::thread::JoinHandle;

        let (notify, listener) = notify();
        // 10 worker threads + the main thread.
        let barrier = Arc::new(Barrier::new(11));
        let mut handles = vec![];

        for _ in 0..10 {
            let barrier_clone = Arc::clone(&barrier);
            let listener_clone = listener.clone();
            handles.push(thread::spawn(move || {
                barrier_clone.wait();
                listener_clone.recv();
            }));
        }

        barrier.wait();
        // NOTE(review): as above, the sleep is what gives the workers time to
        // register before `notify_all`; a late registration could hang `join`.
        thread::sleep(Duration::from_millis(100));
        notify.notify_all();

        handles
            .into_iter()
            .map(JoinHandle::join)
            .for_each(Result::unwrap);

        // All clones were dropped with their threads; only `listener` remains.
        assert_eq!(listener.listeners(), 1);
    }
}
319
// Tests for the `futures`-based `AsyncNotify` / `AsyncListener` pair, run on tokio.
#[cfg(all(feature = "futures", test))]
mod async_tests {
    use crate::notify::async_notify;
    use core::time::Duration;
    use futures::stream::StreamExt;

    /// A task awaiting `next()` on a second listener receives `Some(())` after
    /// `notify_all`; once both listeners are gone the weak count drops to zero.
    #[tokio::test]
    async fn test_basic_functionality_async_tokio() {
        let (notify, listener) = async_notify();
        assert_eq!(notify.listeners(), 1);

        let mut listener2 = notify.listen();
        let handle = tokio::spawn(async move {
            assert_eq!(listener2.next().await, Some(()));
        });

        // NOTE(review): timing-based — assumes the task has polled and registered
        // its flag within 100 ms, before `notify_all` runs.
        tokio::time::sleep(Duration::from_millis(100)).await;
        notify.notify_all();

        drop(listener);
        handle.await.unwrap();
        // `listener2` was dropped with its task and `listener` explicitly above.
        assert_eq!(notify.listeners(), 0);
    }

    /// Ten tasks each await a cloned listener; one `notify_all` must wake them all.
    #[tokio::test]
    async fn test_multi_task_async_tokio() {
        let (notify, listener) = async_notify();
        let mut handles = vec![];

        for _ in 0..10 {
            let mut listener_clone = listener.clone();
            let handle = tokio::spawn(async move {
                assert_eq!(listener_clone.next().await, Some(()));
            });

            handles.push(handle);
        }

        drop(listener);
        // NOTE(review): same timing assumption as the test above.
        tokio::time::sleep(Duration::from_millis(100)).await;
        notify.notify_all();

        let _ = futures::future::try_join_all(handles).await.unwrap();
        assert_eq!(notify.listeners(), 0);
    }
}