utils_atomics/flag/
mpsc.rs1use crate::locks::{lock, Lock};
2use alloc::sync::{Arc, Weak};
3use core::{cell::UnsafeCell, fmt::Debug};
4use docfg::docfg;
5
6#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
8pub fn flag() -> (Flag, Subscribe) {
9 let waker = FlagWaker {
10 waker: UnsafeCell::new(None),
11 };
12
13 let flag = Arc::new(waker);
14 let sub = Arc::downgrade(&flag);
15 (Flag { inner: flag }, Subscribe { inner: sub })
16}
17
18#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
23#[derive(Debug, Clone)]
24pub struct Flag {
25 #[allow(unused)]
26 inner: Arc<FlagWaker>,
27}
28
29#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
31#[derive(Debug)]
32pub struct Subscribe {
33 inner: Weak<FlagWaker>,
34}
35
36impl Flag {
37 #[inline]
39 pub unsafe fn into_raw(self) -> *const () {
40 Arc::into_raw(self.inner).cast()
41 }
42
43 #[inline]
45 pub unsafe fn from_raw(ptr: *const ()) -> Self {
46 Self {
47 inner: Arc::from_raw(ptr.cast()),
48 }
49 }
50
51 #[inline]
52 pub fn has_subscriber(&self) -> bool {
53 return Arc::weak_count(&self.inner) > 0;
54 }
55
56 #[inline]
58 pub fn mark(self) {}
59
60 #[inline]
63 pub fn silent_drop(self) {
64 if let Ok(inner) = Arc::try_unwrap(self.inner) {
65 if let Some(inner) = inner.waker.into_inner() {
66 inner.silent_drop();
67 }
68 }
69 }
70}
71
72impl Subscribe {
73 #[inline]
75 pub fn is_marked(&self) -> bool {
76 return self.inner.strong_count() == 0;
77 }
78
79 #[inline]
81 pub fn wait(self) {
82 if let Some(queue) = self.inner.upgrade() {
83 let (lock, sub) = lock();
84 unsafe { *queue.waker.get() = Some(lock) }
85 drop(queue);
86 sub.wait();
87 }
88 }
89
90 #[docfg(feature = "std")]
95 #[inline]
96 pub fn wait_timeout(&self, dur: core::time::Duration) -> Result<(), crate::Timeout> {
97 if let Some(queue) = self.inner.upgrade() {
98 let (lock, sub) = lock();
99 unsafe { *queue.waker.get() = Some(lock) }
100 drop(queue);
101 sub.wait_timeout(dur);
102 return match self.is_marked() {
103 true => Ok(()),
104 false => Err(crate::Timeout),
105 };
106 }
107 return Ok(());
108 }
109}
110
111struct FlagWaker {
112 waker: UnsafeCell<Option<Lock>>,
113}
114
115impl Debug for FlagWaker {
116 #[inline]
117 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
118 f.debug_struct("FlagWaker").finish_non_exhaustive()
119 }
120}
121
122unsafe impl Send for FlagWaker where Lock: Send {}
123unsafe impl Sync for FlagWaker where Lock: Sync {}
124
125cfg_if::cfg_if! {
126 if #[cfg(feature = "futures")] {
127 use core::{future::Future, task::{Waker, Poll}};
128 use futures::future::FusedFuture;
129
130 #[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
132 #[inline]
133 pub fn async_flag () -> (AsyncFlag, AsyncSubscribe) {
134 let waker = AsyncFlagWaker {
135 waker: UnsafeCell::new(None)
136 };
137
138 let flag = Arc::new(waker);
139 let sub = Arc::downgrade(&flag);
140 (AsyncFlag { inner: flag }, AsyncSubscribe { inner: Some(sub) })
141 }
142
143 #[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
148 #[derive(Debug, Clone)]
149 pub struct AsyncFlag {
150 inner: Arc<AsyncFlagWaker>
151 }
152
153 impl AsyncFlag {
154 #[inline]
156 pub unsafe fn into_raw (self) -> *const Option<Waker> {
157 Arc::into_raw(self.inner).cast()
158 }
159
160 #[inline]
162 pub unsafe fn from_raw (ptr: *const Option<Waker>) -> Self {
163 Self { inner: Arc::from_raw(ptr.cast()) }
164 }
165
166 #[inline]
167 pub fn has_subscriber(&self) -> bool {
168 return Arc::weak_count(&self.inner) > 0
169 }
170
171 #[inline]
173 pub fn mark (self) {}
174
175 #[inline]
178 pub fn silent_drop (self) {
179 if let Ok(inner) = Arc::try_unwrap(self.inner) {
180 inner.silent_drop();
181 }
182 }
183 }
184
185 #[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
186 #[derive(Debug)]
188 pub struct AsyncSubscribe {
189 inner: Option<Weak<AsyncFlagWaker>>
190 }
191
192 impl AsyncSubscribe {
193 #[inline]
195 pub fn is_marked (&self) -> bool {
196 return !crate::is_some_and(self.inner.as_ref(), |x| x.strong_count() > 0)
197 }
198 }
199
200 impl Future for AsyncSubscribe {
201 type Output = ();
202
203 #[inline]
204 fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> {
205 if let Some(ref queue) = self.inner {
206 if let Some(queue) = queue.upgrade() {
207 unsafe { *queue.waker.get() = Some(cx.waker().clone()) };
210 return Poll::Pending;
211 }
212
213 self.inner = None;
214 return Poll::Ready(())
215 }
216 return Poll::Ready(())
217 }
218 }
219
220 impl FusedFuture for AsyncSubscribe {
221 #[inline]
222 fn is_terminated(&self) -> bool {
223 self.inner.is_none()
224 }
225 }
226
227 struct AsyncFlagWaker {
228 waker: UnsafeCell<Option<Waker>>
229 }
230
231 impl AsyncFlagWaker {
232 #[inline]
233 pub fn silent_drop (self) {
234 let mut this = core::mem::ManuallyDrop::new(self);
235 unsafe { core::ptr::drop_in_place(&mut this.waker) }
236 }
237 }
238
239 impl Drop for AsyncFlagWaker {
240 #[inline]
241 fn drop(&mut self) {
242 if let Some(waker) = self.waker.get_mut().take() {
243 waker.wake()
244 }
245 }
246 }
247
248 impl Debug for AsyncFlagWaker {
249 #[inline]
250 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
251 f.debug_struct("AsyncFlagWaker").finish_non_exhaustive()
252 }
253 }
254
255 unsafe impl Send for AsyncFlagWaker where Option<Waker>: Send {}
256 unsafe impl Sync for AsyncFlagWaker where Option<Waker>: Sync {}
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "std")]
    use std::thread;

    /// A freshly created pair is not marked while its flag is alive.
    #[test]
    fn test_flag_creation() {
        let (tx, rx) = flag();
        let marked = rx.is_marked();
        assert!(!marked);
        drop(tx);
    }

    /// Marking the only flag makes the subscriber report completion.
    #[test]
    fn test_flag_mark() {
        let (tx, rx) = flag();
        tx.mark();
        assert!(rx.is_marked());
    }

    /// A silent drop must not cut the subscriber's timed wait short.
    #[cfg(feature = "std")]
    #[test]
    fn test_flag_silent_drop() {
        use core::time::Duration;
        use std::time::Instant;

        let (tx, rx) = flag();

        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(100));
            tx.silent_drop();
        });

        let start = Instant::now();
        let _ = rx.wait_timeout(Duration::from_millis(200));
        let elapsed = start.elapsed();

        handle.join().unwrap();
        assert!(elapsed >= Duration::from_millis(200), "{elapsed:?}");
    }

    /// A blocking wait returns after another thread marks the flag.
    #[cfg(feature = "std")]
    #[test]
    fn test_subscribe_wait() {
        let (tx, rx) = flag();

        let handle = thread::spawn(move || {
            thread::sleep(std::time::Duration::from_millis(100));
            tx.mark();
        });

        rx.wait();
        handle.join().unwrap();
    }

    /// Many clones marked from many threads still release a single waiter.
    #[cfg(feature = "std")]
    #[test]
    fn test_flag_stress() {
        const THREADS: usize = 10;
        const ITERATIONS: usize = 100;

        for _ in 0..ITERATIONS {
            let (tx, rx) = flag();

            let handles: Vec<_> = (0..THREADS)
                .map(|_| {
                    let tx = tx.clone();
                    thread::spawn(move || {
                        tx.mark();
                    })
                })
                .collect();

            drop(tx);
            rx.wait();

            handles
                .into_iter()
                .for_each(|handle| handle.join().unwrap());
        }
    }

    #[cfg(feature = "futures")]
    mod async_tests {
        use super::*;

        /// A freshly created async pair is not marked while its flag is alive.
        #[test]
        fn test_async_flag_creation() {
            let (tx, rx) = async_flag();
            let marked = rx.is_marked();
            assert!(!marked);
            drop(tx);
        }

        /// Marking the only async flag makes the subscriber report completion.
        #[test]
        fn test_async_flag_mark() {
            let (tx, rx) = async_flag();
            tx.mark();
            assert!(rx.is_marked());
        }

        /// A silent drop must not complete the subscriber before the timeout.
        #[tokio::test]
        async fn test_flag_silent_drop() {
            use core::time::Duration;
            use std::time::Instant;

            let (tx, rx) = async_flag();

            let handle = tokio::spawn(async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                tx.silent_drop();
            });

            let waited = async move {
                let start = Instant::now();
                rx.await;
                start.elapsed()
            };
            let elapsed = tokio::time::timeout(Duration::from_millis(200), waited).await;

            handle.await.unwrap();
            if let Ok(t) = elapsed {
                if t < Duration::from_millis(200) {
                    panic!("{t:?}")
                }
            }
        }

        /// Awaiting the subscriber after the marking task finished completes.
        #[tokio::test]
        async fn test_async_subscribe_wait() {
            let (tx, rx) = async_flag();

            let handle = tokio::spawn(async move {
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                tx.mark();
            });

            handle.await.unwrap();

            rx.await;
        }

        /// Many clones marked from many tasks still complete a single subscriber.
        #[tokio::test]
        async fn test_async_flag_stress() {
            const TASKS: usize = 10;
            const ITERATIONS: usize = 10;

            for _ in 0..ITERATIONS {
                let (tx, rx) = async_flag();

                let handles: Vec<_> = (0..TASKS)
                    .map(|_| {
                        let tx = tx.clone();
                        tokio::spawn(async move {
                            tx.mark();
                        })
                    })
                    .collect();

                drop(tx);
                rx.await;

                for handle in handles {
                    handle.await.unwrap();
                }
            }
        }
    }
}