utils_atomics/flag/
mpmc.rs

1use crate::{
2    locks::{lock, Lock},
3    FillQueue,
4};
5use alloc::sync::{Arc, Weak};
6use core::mem::ManuallyDrop;
7use docfg::docfg;
8
9/// A flag type that will be completed when all its references have been dropped or marked.
10///
11/// This flag drops loudly by default (a.k.a will complete when dropped),
12/// but can be droped silently with [`silent_drop`](Flag::silent_drop)
13#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
14#[derive(Debug, Clone)]
15pub struct Flag {
16    inner: Arc<FlagQueue>,
17}
18
19/// Subscriber of a [`Flag`]
20#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
21#[derive(Debug, Clone)]
22pub struct Subscribe {
23    inner: Weak<FlagQueue>,
24}
25
26impl Flag {
27    /// See [`Arc::into_raw`]
28    #[inline]
29    pub unsafe fn into_raw(self) -> *const FillQueue<Lock> {
30        Arc::into_raw(self.inner).cast()
31    }
32
33    /// See [`Arc::from_raw`]
34    #[inline]
35    pub unsafe fn from_raw(ptr: *const FillQueue<Lock>) -> Self {
36        Self {
37            inner: Arc::from_raw(ptr.cast()),
38        }
39    }
40
41    #[inline]
42    pub fn has_subscriber(&self) -> bool {
43        return Arc::weak_count(&self.inner) > 0;
44    }
45
46    /// Mark this flag as completed, consuming it
47    #[inline]
48    pub fn mark(self) {}
49
50    /// Drops the flag without **notifying** it as completed.
51    /// This method may leak memory.
52    #[inline]
53    pub fn silent_drop(self) {
54        if let Ok(inner) = Arc::try_unwrap(self.inner) {
55            inner.silent_drop()
56        }
57    }
58}
59
60impl Subscribe {
61    #[inline]
62    pub fn is_marked(&self) -> bool {
63        return self.inner.strong_count() == 0;
64    }
65
66    /// Blocks the current thread until the flag gets marked.
67    #[inline]
68    pub fn wait(self) {
69        if let Some(queue) = self.inner.upgrade() {
70            let (waker, sub) = lock();
71            queue.0.push(waker);
72            drop(queue);
73            sub.wait()
74        }
75    }
76
77    /// Blocks the current thread until the flag gets marked or the timeout expires.
78    ///
79    /// # Errors
80    /// This method returns an error if the wait didn't conclude before the specified duration
81    #[docfg(feature = "std")]
82    #[inline]
83    pub fn wait_timeout(self, dur: core::time::Duration) -> Result<(), crate::Timeout> {
84        if let Some(queue) = self.inner.upgrade() {
85            let (waker, sub) = lock();
86            queue.0.push(waker);
87            drop(queue);
88            sub.wait_timeout(dur);
89            return match self.is_marked() {
90                true => Ok(()),
91                false => Err(crate::Timeout),
92            };
93        }
94        return Ok(());
95    }
96}
97
98/// Creates a new pair of [`Flag`] and [`Subscribe`].
99///
100/// The flag will be completed when all references to [`Flag`] have been dropped or marked.
101#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
102pub fn flag() -> (Flag, Subscribe) {
103    let flag = Arc::new(FlagQueue(FillQueue::new()));
104    let sub = Arc::downgrade(&flag);
105    (Flag { inner: flag }, Subscribe { inner: sub })
106}
107
108#[repr(transparent)]
109#[derive(Debug)]
110struct FlagQueue(pub FillQueue<Lock>);
111
112impl FlagQueue {
113    #[inline]
114    pub fn silent_drop(self) {
115        let mut this = ManuallyDrop::new(self);
116        this.0.chop_mut().for_each(Lock::silent_drop);
117        unsafe { core::ptr::drop_in_place(&mut this) };
118    }
119}
120
121impl Drop for FlagQueue {
122    #[inline]
123    fn drop(&mut self) {
124        self.0.chop_mut().for_each(Lock::wake);
125    }
126}
127
128cfg_if::cfg_if! {
129    if #[cfg(feature = "futures")] {
130        use core::{future::Future, task::{Waker, Poll}};
131        use futures::future::FusedFuture;
132
133        /// Creates a new pair of [`AsyncFlag`] and [`AsyncSubscribe`]
134        #[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
135        #[inline]
136        pub fn async_flag () -> (AsyncFlag, AsyncSubscribe) {
137            let flag = Arc::new(AsyncFlagQueue(FillQueue::new()));
138            let sub = Arc::downgrade(&flag);
139            return (AsyncFlag { inner: flag }, AsyncSubscribe { inner: Some(sub) })
140        }
141
142        /// Async flag that will be completed when all references to [`Flag`] have been dropped or marked.
143        #[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
144        #[derive(Debug, Clone)]
145        pub struct AsyncFlag {
146            inner: Arc<AsyncFlagQueue>
147        }
148
149        impl AsyncFlag {
150            /// See [`Arc::into_raw`]
151            #[inline]
152            pub unsafe fn into_raw (self) -> *const FillQueue<Waker> {
153                Arc::into_raw(self.inner).cast()
154            }
155
156            /// See [`Arc::from_raw`]
157            #[inline]
158            pub unsafe fn from_raw (ptr: *const FillQueue<Waker>) -> Self {
159                Self { inner: Arc::from_raw(ptr.cast()) }
160            }
161
162            #[inline]
163            pub fn has_subscriber(&self) -> bool {
164                return Arc::weak_count(&self.inner) > 0
165            }
166
167            /// Marks this flag as complete, consuming it
168            #[inline]
169            pub fn mark (self) {}
170
171            /// Creates a new subscriber to this flag.
172            #[inline]
173            pub fn subscribe (&self) -> AsyncSubscribe {
174                AsyncSubscribe {
175                    inner: Some(Arc::downgrade(&self.inner))
176                }
177            }
178
179            /// Drops the flag without **notifying** it as completed.
180            /// This method may leak memory.
181            #[inline]
182            pub fn silent_drop (self) {
183                if let Ok(inner) = Arc::try_unwrap(self.inner) {
184                    inner.silent_drop()
185                }
186            }
187        }
188
189        #[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
190        /// Subscriber of an [`AsyncFlag`]
191        #[derive(Debug, Clone)]
192        pub struct AsyncSubscribe {
193            inner: Option<Weak<AsyncFlagQueue>>
194        }
195
196        impl AsyncSubscribe {
197            /// Creates a new subscriber that has already completed
198            #[inline]
199            pub fn marked () -> AsyncSubscribe {
200                return Self { inner: None }
201            }
202
203            /// Returns `true` if the flag has been marked, and `false` otherwise
204            #[inline]
205            pub fn is_marked (&self) -> bool {
206                return !crate::is_some_and(self.inner.as_ref(), |x| x.strong_count() > 0)
207            }
208        }
209
210        impl Future for AsyncSubscribe {
211            type Output = ();
212
213            #[inline]
214            fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> {
215                if let Some(ref queue) = self.inner {
216                    if let Some(queue) = queue.upgrade() {
217                        queue.0.push(cx.waker().clone());
218                        return Poll::Pending;
219                    }
220
221                    self.inner = None;
222                    return Poll::Ready(())
223                }
224
225                return Poll::Ready(())
226            }
227        }
228
229        impl FusedFuture for AsyncSubscribe {
230            #[inline]
231            fn is_terminated(&self) -> bool {
232                self.inner.is_none()
233            }
234        }
235
236        #[derive(Debug)]
237        struct AsyncFlagQueue (pub FillQueue<Waker>);
238
239        impl AsyncFlagQueue {
240            #[inline]
241            pub fn silent_drop (self) {
242                let mut this = ManuallyDrop::new(self);
243                let _: crate::prelude::ChopIter<Waker> = this.0.chop_mut();
244                unsafe { core::ptr::drop_in_place(&mut this.0) }
245            }
246        }
247
248        impl Drop for AsyncFlagQueue {
249            #[inline]
250            fn drop(&mut self) {
251                self.0.chop_mut().for_each(Waker::wake);
252            }
253        }
254    }
255}
256
#[cfg(all(feature = "std", test))]
mod tests {
    use super::{flag, Flag};
    use core::time::Duration;
    use std::thread;

    /// Marking works, and `wait` returns once another thread marks the flag
    /// (after a round trip through `into_raw` / `from_raw`).
    #[test]
    fn test_normal_conditions() {
        // Marking a flag whose subscriber is discarded right away.
        let (first, _) = flag();
        first.mark();

        // Waiting for a flag that a second thread marks after 100ms.
        let (second, waiter) = flag();
        let second = unsafe { Flag::from_raw(Flag::into_raw(second)) };

        thread::spawn(move || {
            thread::sleep(Duration::from_millis(100));
            second.mark();
        });

        waiter.wait();
    }

    /// The waiter's 100ms timeout elapses before the flag is silently dropped
    /// at 200ms, so the wait must report a timeout.
    #[test]
    fn test_silent_drop() {
        let (marker, waiter) = flag();

        let handle = thread::spawn(move || waiter.wait_timeout(Duration::from_millis(100)));

        thread::sleep(Duration::from_millis(200));
        marker.silent_drop();

        let outcome = handle.join().unwrap();
        assert!(outcome.is_err());
    }

    /// Many threads wait repeatedly on clones of one subscriber while the flag
    /// is marked through several clones.
    #[test]
    fn test_stressed_conditions() {
        let (marker, waiter) = flag();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let local = waiter.clone();
                thread::spawn(move || {
                    for _ in 0..10 {
                        local.clone().wait();
                    }
                })
            })
            .collect();

        thread::sleep(Duration::from_millis(100));

        (0..9).for_each(|_| marker.clone().mark());
        marker.mark();

        handles.into_iter().for_each(|handle| handle.join().unwrap());
    }
}
323
#[cfg(all(feature = "futures", test))]
mod async_tests {
    use super::{async_flag, AsyncFlag};
    use core::time::Duration;
    use std::time::Instant;

    /// `is_marked` tracks marking, and awaiting resolves once another task marks
    /// the flag (after a round trip through `into_raw` / `from_raw`).
    #[tokio::test]
    async fn test_async_normal_conditions() {
        let (first, first_sub) = async_flag();
        assert!(!first_sub.is_marked());

        // Test marking the flag.
        first.mark();
        assert!(first_sub.is_marked());

        // Test waiting for the flag.
        let (second, mut second_sub) = async_flag();
        let second = unsafe { AsyncFlag::from_raw(AsyncFlag::into_raw(second)) };

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            second.mark();
        });

        (&mut second_sub).await;
        assert!(second_sub.is_marked());
    }

    /// A silent drop at 100ms must not wake the subscriber. The await should
    /// run into the 200ms timeout instead of finishing early.
    #[tokio::test]
    async fn test_silent_drop() {
        let (marker, subscriber) = async_flag();

        let handle = tokio::spawn(tokio::time::timeout(
            Duration::from_millis(200),
            async move {
                let started = Instant::now();
                subscriber.await;
                started.elapsed()
            },
        ));
        tokio::time::sleep(Duration::from_millis(100)).await;
        marker.silent_drop();

        // `Err` (timeout elapsed) is the expected outcome. An early `Ok` means the
        // silent drop woke the subscriber.
        if let Ok(elapsed) = handle.await.unwrap() {
            assert!(elapsed >= Duration::from_millis(200), "{elapsed:?}");
        }
    }

    /// A hundred tasks await clones of one subscriber. All must resolve after a
    /// single `mark`.
    #[tokio::test]
    async fn test_async_stressed_conditions() {
        let (marker, subscriber) = async_flag();

        let handles: Vec<_> = (0..100)
            .map(|_| {
                let mut local = subscriber.clone();
                tokio::spawn(async move {
                    (&mut local).await;
                    assert!(local.is_marked());
                })
            })
            .collect();

        tokio::time::sleep(Duration::from_millis(100)).await;
        marker.mark();

        for handle in handles {
            handle.await.unwrap();
        }
    }
}