// utils_atomics/flag/mpsc.rs

1use crate::locks::{lock, Lock};
2use alloc::sync::{Arc, Weak};
3use core::{cell::UnsafeCell, fmt::Debug};
4use docfg::docfg;
5
6/// Creates a new pair of [`Flag`] and [`Subscribe`]
7#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
8pub fn flag() -> (Flag, Subscribe) {
9    let waker = FlagWaker {
10        waker: UnsafeCell::new(None),
11    };
12
13    let flag = Arc::new(waker);
14    let sub = Arc::downgrade(&flag);
15    (Flag { inner: flag }, Subscribe { inner: sub })
16}
17
18/// A flag type that completes when all it's references are marked or dropped.
19///
20/// This flag drops loudly by default (a.k.a will complete when dropped),
21/// but can be droped silently with [`silent_drop`](Flag::silent_drop)
22#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
23#[derive(Debug, Clone)]
24pub struct Flag {
25    #[allow(unused)]
26    inner: Arc<FlagWaker>,
27}
28
29/// Subscriber of a [`Flag`]
30#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
31#[derive(Debug)]
32pub struct Subscribe {
33    inner: Weak<FlagWaker>,
34}
35
36impl Flag {
37    /// See [`Arc::into_raw`]
38    #[inline]
39    pub unsafe fn into_raw(self) -> *const () {
40        Arc::into_raw(self.inner).cast()
41    }
42
43    /// See [`Arc::from_raw`]
44    #[inline]
45    pub unsafe fn from_raw(ptr: *const ()) -> Self {
46        Self {
47            inner: Arc::from_raw(ptr.cast()),
48        }
49    }
50
51    #[inline]
52    pub fn has_subscriber(&self) -> bool {
53        return Arc::weak_count(&self.inner) > 0;
54    }
55
56    /// Mark this flag reference as completed, consuming it
57    #[inline]
58    pub fn mark(self) {}
59
60    /// Drops the flag without **notifying** it as completed.
61    /// This method may leak memory.
62    #[inline]
63    pub fn silent_drop(self) {
64        if let Ok(inner) = Arc::try_unwrap(self.inner) {
65            if let Some(inner) = inner.waker.into_inner() {
66                inner.silent_drop();
67            }
68        }
69    }
70}
71
72impl Subscribe {
73    /// Returns `true` if the flag has been fully marked, and `false` otherwise
74    #[inline]
75    pub fn is_marked(&self) -> bool {
76        return self.inner.strong_count() == 0;
77    }
78
79    /// Blocks the current thread until the flag gets fully marked.
80    #[inline]
81    pub fn wait(self) {
82        if let Some(queue) = self.inner.upgrade() {
83            let (lock, sub) = lock();
84            unsafe { *queue.waker.get() = Some(lock) }
85            drop(queue);
86            sub.wait();
87        }
88    }
89
90    /// Blocks the current thread until the flag gets fully marked or the timeout expires.
91    ///
92    /// # Errors
93    /// This method returns an error if the wait didn't conclude before the specified duration
94    #[docfg(feature = "std")]
95    #[inline]
96    pub fn wait_timeout(&self, dur: core::time::Duration) -> Result<(), crate::Timeout> {
97        if let Some(queue) = self.inner.upgrade() {
98            let (lock, sub) = lock();
99            unsafe { *queue.waker.get() = Some(lock) }
100            drop(queue);
101            sub.wait_timeout(dur);
102            return match self.is_marked() {
103                true => Ok(()),
104                false => Err(crate::Timeout),
105            };
106        }
107        return Ok(());
108    }
109}
110
111struct FlagWaker {
112    waker: UnsafeCell<Option<Lock>>,
113}
114
115impl Debug for FlagWaker {
116    #[inline]
117    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
118        f.debug_struct("FlagWaker").finish_non_exhaustive()
119    }
120}
121
122unsafe impl Send for FlagWaker where Lock: Send {}
123unsafe impl Sync for FlagWaker where Lock: Sync {}
124
125cfg_if::cfg_if! {
126    if #[cfg(feature = "futures")] {
127        use core::{future::Future, task::{Waker, Poll}};
128        use futures::future::FusedFuture;
129
130        /// Creates a new pair of [`AsyncFlag`] and [`AsyncSubscribe`]
131        #[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
132        #[inline]
133        pub fn async_flag () -> (AsyncFlag, AsyncSubscribe) {
134            let waker = AsyncFlagWaker {
135                waker: UnsafeCell::new(None)
136            };
137
138            let flag = Arc::new(waker);
139            let sub = Arc::downgrade(&flag);
140            (AsyncFlag { inner: flag }, AsyncSubscribe { inner: Some(sub) })
141        }
142
143        /// Async flag that completes when all it's references are marked or droped.
144        ///
145        /// This flag drops loudly by default (a.k.a will complete when dropped),
146        /// but can be droped silently with [`silent_drop`](Flag::silent_drop)
147        #[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
148        #[derive(Debug, Clone)]
149        pub struct AsyncFlag {
150            inner: Arc<AsyncFlagWaker>
151        }
152
153        impl AsyncFlag {
154            /// See [`Arc::into_raw`]
155            #[inline]
156            pub unsafe fn into_raw (self) -> *const Option<Waker> {
157                Arc::into_raw(self.inner).cast()
158            }
159
160            /// See [`Arc::from_raw`]
161            #[inline]
162            pub unsafe fn from_raw (ptr: *const Option<Waker>) -> Self {
163                Self { inner: Arc::from_raw(ptr.cast()) }
164            }
165
166            #[inline]
167            pub fn has_subscriber(&self) -> bool {
168                return Arc::weak_count(&self.inner) > 0
169            }
170
171            /// Marks this flag as complete, consuming it
172            #[inline]
173            pub fn mark (self) {}
174
175            /// Drops the flag without marking it as completed.
176            /// This method may leak memory.
177            #[inline]
178            pub fn silent_drop (self) {
179                if let Ok(inner) = Arc::try_unwrap(self.inner) {
180                    inner.silent_drop();
181                }
182            }
183        }
184
185        #[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
186        /// Subscriber of an [`AsyncFlag`]
187        #[derive(Debug)]
188        pub struct AsyncSubscribe {
189            inner: Option<Weak<AsyncFlagWaker>>
190        }
191
192        impl AsyncSubscribe {
193            /// Returns `true` if the flag has been marked, and `false` otherwise
194            #[inline]
195            pub fn is_marked (&self) -> bool {
196                return !crate::is_some_and(self.inner.as_ref(), |x| x.strong_count() > 0)
197            }
198        }
199
200        impl Future for AsyncSubscribe {
201            type Output = ();
202
203            #[inline]
204            fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> {
205                if let Some(ref queue) = self.inner {
206                    if let Some(queue) = queue.upgrade() {
207                        // SAFETY: If we upgraded, we are the only thread with access to the value,
208                        //         since the only other owner of the waker is it's destructor.
209                        unsafe { *queue.waker.get() = Some(cx.waker().clone()) };
210                        return Poll::Pending;
211                    }
212
213                    self.inner = None;
214                    return Poll::Ready(())
215                }
216                return Poll::Ready(())
217            }
218        }
219
220        impl FusedFuture for AsyncSubscribe {
221            #[inline]
222            fn is_terminated(&self) -> bool {
223                self.inner.is_none()
224            }
225        }
226
227        struct AsyncFlagWaker {
228            waker: UnsafeCell<Option<Waker>>
229        }
230
231        impl AsyncFlagWaker {
232            #[inline]
233            pub fn silent_drop (self) {
234                let mut this = core::mem::ManuallyDrop::new(self);
235                unsafe { core::ptr::drop_in_place(&mut this.waker) }
236            }
237        }
238
239        impl Drop for AsyncFlagWaker {
240            #[inline]
241            fn drop(&mut self) {
242                if let Some(waker) = self.waker.get_mut().take() {
243                    waker.wake()
244                }
245            }
246        }
247
248        impl Debug for AsyncFlagWaker {
249            #[inline]
250            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
251                f.debug_struct("AsyncFlagWaker").finish_non_exhaustive()
252            }
253        }
254
255        unsafe impl Send for AsyncFlagWaker where Option<Waker>: Send {}
256        unsafe impl Sync for AsyncFlagWaker where Option<Waker>: Sync {}
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "std")]
    use std::thread;

    #[test]
    fn test_flag_creation() {
        let (flag, subscribe) = flag();
        assert!(!subscribe.is_marked());
        assert!(flag.has_subscriber());
        drop(flag);
        assert!(subscribe.is_marked());
    }

    #[test]
    fn test_flag_mark() {
        let (flag, subscribe) = flag();
        flag.mark();
        assert!(subscribe.is_marked());
    }

    #[test]
    fn test_flag_requires_every_clone() {
        let (flag, subscribe) = flag();
        let other = flag.clone();
        flag.mark();
        assert!(!subscribe.is_marked());
        other.mark();
        assert!(subscribe.is_marked());
    }

    #[test]
    fn test_has_subscriber_tracks_subscribe() {
        let (flag, subscribe) = flag();
        assert!(flag.has_subscriber());
        drop(subscribe);
        assert!(!flag.has_subscriber());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_flag_silent_drop() {
        use core::time::Duration;
        use std::time::Instant;

        let (flag, subscribe) = flag();

        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(100));
            flag.silent_drop();
        });

        let now = Instant::now();
        // The result is not checked: after a silent drop no strong reference
        // remains, so `is_marked` (and thus `wait_timeout`) reports completion
        // once the timeout elapses.
        let _ = subscribe.wait_timeout(Duration::from_millis(200));
        let elapsed = now.elapsed();

        handle.join().unwrap();
        assert!(elapsed >= Duration::from_millis(200), "{elapsed:?}");
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_wait_timeout_already_marked() {
        let (flag, subscribe) = flag();
        flag.mark();
        assert!(subscribe
            .wait_timeout(core::time::Duration::from_millis(10))
            .is_ok());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_subscribe_wait() {
        let (flag, subscribe) = flag();

        let handle = thread::spawn(move || {
            thread::sleep(std::time::Duration::from_millis(100));
            flag.mark();
        });

        subscribe.wait();
        handle.join().unwrap();
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_flag_stress() {
        const THREADS: usize = 10;
        const ITERATIONS: usize = 100;

        for _ in 0..ITERATIONS {
            let (flag, subscribe) = flag();
            let mut handles = Vec::with_capacity(THREADS);

            for _ in 0..THREADS {
                let flag_clone = flag.clone();
                let handle = std::thread::spawn(move || {
                    flag_clone.mark();
                });
                handles.push(handle);
            }

            drop(flag);
            subscribe.wait();

            for handle in handles {
                handle.join().unwrap();
            }
        }
    }

    #[cfg(feature = "futures")]
    mod async_tests {
        use super::*;

        #[test]
        fn test_async_flag_creation() {
            let (async_flag, async_subscribe) = async_flag();
            assert!(!async_subscribe.is_marked());
            drop(async_flag);
            assert!(async_subscribe.is_marked());
        }

        #[test]
        fn test_async_flag_mark() {
            let (async_flag, async_subscribe) = async_flag();
            async_flag.mark();
            assert!(async_subscribe.is_marked());
        }

        #[test]
        fn test_async_flag_requires_every_clone() {
            let (async_flag, async_subscribe) = async_flag();
            let other = async_flag.clone();
            async_flag.mark();
            assert!(!async_subscribe.is_marked());
            other.mark();
            assert!(async_subscribe.is_marked());
        }

        #[tokio::test]
        async fn test_flag_silent_drop() {
            use core::time::Duration;
            use std::time::Instant;

            let (flag, subscribe) = async_flag();

            let handle = tokio::spawn(async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                flag.silent_drop();
            });

            let elapsed = tokio::time::timeout(Duration::from_millis(200), async move {
                let now = Instant::now();
                subscribe.await;
                now.elapsed()
            })
            .await;

            handle.await.unwrap();
            match elapsed {
                Ok(t) if t < Duration::from_millis(200) => panic!("{t:?}"),
                _ => {}
            }
        }

        #[tokio::test]
        async fn test_async_subscribe_wait() {
            let (async_flag, async_subscribe) = async_flag();

            let handle = tokio::spawn(async move {
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                async_flag.mark();
            });

            // Await the subscriber *before* joining the task, so the waker
            // registration and wake-up path is actually exercised.
            async_subscribe.await;
            handle.await.unwrap();
        }

        #[tokio::test]
        async fn test_async_flag_stress() {
            const TASKS: usize = 10;
            const ITERATIONS: usize = 10;

            for _ in 0..ITERATIONS {
                let (async_flag, async_subscribe) = async_flag();
                let mut handles = Vec::with_capacity(TASKS);

                for _ in 0..TASKS {
                    let async_flag_clone = async_flag.clone();
                    let handle = tokio::spawn(async move {
                        async_flag_clone.mark();
                    });
                    handles.push(handle);
                }

                drop(async_flag);
                async_subscribe.await;

                for handle in handles {
                    handle.await.unwrap();
                }
            }
        }
    }
}