use crate::locks::{lock, Lock};
use alloc::sync::{Arc, Weak};
use core::{cell::UnsafeCell, fmt::Debug};
use docfg::docfg;
/// Creates a connected [`Flag`] / [`Subscribe`] pair.
///
/// Both halves refer to one heap-allocated waker slot that starts out empty.
/// The [`Flag`] keeps the slot alive through a strong [`Arc`] reference, while
/// the [`Subscribe`] only observes it through a [`Weak`] reference.
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub fn flag() -> (Flag, Subscribe) {
    let shared = Arc::new(FlagWaker {
        waker: UnsafeCell::new(None),
    });
    let subscriber = Subscribe {
        inner: Arc::downgrade(&shared),
    };
    (Flag { inner: shared }, subscriber)
}
/// Marking half of a pair created by [`flag`].
///
/// Holds a strong reference to the waker slot shared with the [`Subscribe`].
/// Cloning yields another handle to the same slot; [`Subscribe::is_marked`]
/// reports `true` only once no strong reference is left.
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
#[derive(Debug, Clone)]
pub struct Flag {
// NOTE(review): `inner` is read by the methods of `Flag`, so this `allow`
// looks unnecessary — confirm before removing.
#[allow(unused)]
inner: Arc<FlagWaker>,
}
/// Waiting half of a pair created by [`flag`].
///
/// Only holds a weak reference to the shared waker slot, so it never keeps the
/// slot alive between calls; it upgrades the reference temporarily while
/// registering a waker in [`Subscribe::wait`] / `wait_timeout`.
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
#[derive(Debug)]
pub struct Subscribe {
// Weak handle to the slot owned by the `Flag` clones.
inner: Weak<FlagWaker>,
}
impl Flag {
#[inline]
pub unsafe fn into_raw(self) -> *const () {
Arc::into_raw(self.inner).cast()
}
#[inline]
pub unsafe fn from_raw(ptr: *const ()) -> Self {
Self {
inner: Arc::from_raw(ptr.cast()),
}
}
#[inline]
pub fn has_subscriber(&self) -> bool {
return Arc::weak_count(&self.inner) > 0;
}
#[inline]
pub fn mark(self) {}
#[inline]
pub fn silent_drop(self) {
if let Ok(inner) = Arc::try_unwrap(self.inner) {
if let Some(inner) = inner.waker.into_inner() {
inner.silent_drop();
}
}
}
}
impl Subscribe {
    /// Returns `true` once no [`Flag`] handle keeps the shared state alive.
    ///
    /// Only the strong reference count is inspected, so a [`Flag::mark`] and a
    /// [`Flag::silent_drop`] of the last handle are indistinguishable here.
    #[inline]
    pub fn is_marked(&self) -> bool {
        Weak::strong_count(&self.inner) == 0
    }

    /// Blocks the current thread until the waiter obtained from
    /// `crate::locks::lock` returns.
    ///
    /// Returns immediately when every [`Flag`] handle is already gone.
    #[inline]
    pub fn wait(self) {
        let shared = match self.inner.upgrade() {
            Some(shared) => shared,
            // Nothing keeps the flag alive any more: it is already marked.
            None => return,
        };
        let (waker, waiter) = lock();
        // SAFETY: the slot is only written from `Subscribe` methods, and this
        // one consumes `self`.
        // NOTE(review): soundness relies on no other thread writing the slot
        // at the same time — confirm against `wait_timeout(&self)`.
        unsafe { *shared.waker.get() = Some(waker) };
        // Release the temporary strong reference before blocking so that the
        // subscriber itself never keeps the flag alive.
        drop(shared);
        waiter.wait();
    }

    /// Waits for the flag for at most `dur`.
    ///
    /// # Errors
    ///
    /// Returns [`crate::Timeout`] when strong references to the shared state
    /// still exist after the wait has returned.
    #[docfg(feature = "std")]
    #[inline]
    pub fn wait_timeout(&self, dur: core::time::Duration) -> Result<(), crate::Timeout> {
        let shared = match self.inner.upgrade() {
            Some(shared) => shared,
            // Already marked before we even started waiting.
            None => return Ok(()),
        };
        let (waker, waiter) = lock();
        // SAFETY: NOTE(review) — this writes the slot through `&self` without
        // synchronization; it is only sound if `wait_timeout` is never called
        // concurrently on the same `Subscribe`. Confirm with the callers.
        unsafe { *shared.waker.get() = Some(waker) };
        drop(shared);
        // The outcome of the wait is intentionally not inspected; the strong
        // count below decides the result.
        waiter.wait_timeout(dur);
        if self.is_marked() {
            Ok(())
        } else {
            Err(crate::Timeout)
        }
    }
}
/// State shared between the `Flag` clones (strong references) and the
/// `Subscribe` (weak reference).
struct FlagWaker {
// Lock registered by `Subscribe::wait` / `Subscribe::wait_timeout`; `None`
// until a subscriber starts waiting. Taken out again by `Flag::silent_drop`.
waker: UnsafeCell<Option<Lock>>,
}
impl Debug for FlagWaker {
    /// Formats the type name only; the waker slot sits behind an `UnsafeCell`
    /// and is deliberately not printed.
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("FlagWaker");
        builder.finish_non_exhaustive()
    }
}
// SAFETY: `FlagWaker` only contains an `UnsafeCell<Option<Lock>>`, so moving it
// to another thread moves nothing but the `Lock`.
unsafe impl Send for FlagWaker where Lock: Send {}
// SAFETY: the slot is written through `&FlagWaker` by `Subscribe::wait` and
// `Subscribe::wait_timeout` and is otherwise only consumed by value.
// NOTE(review): this is sound only if those writes can never overlap;
// `wait_timeout` takes `&self`, so confirm a `Subscribe` is never used from
// several threads at once.
unsafe impl Sync for FlagWaker where Lock: Sync {}
cfg_if::cfg_if! {
if #[cfg(feature = "futures")] {
use core::{future::Future, task::{Waker, Poll}};
use futures::future::FusedFuture;
/// Creates a connected [`AsyncFlag`] / [`AsyncSubscribe`] pair.
///
/// The [`AsyncFlag`] owns the shared waker slot (initially empty) through a
/// strong reference; the [`AsyncSubscribe`] starts with a weak reference to it.
#[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
#[inline]
pub fn async_flag() -> (AsyncFlag, AsyncSubscribe) {
    let shared = Arc::new(AsyncFlagWaker {
        waker: UnsafeCell::new(None),
    });
    let observer = Arc::downgrade(&shared);
    (
        AsyncFlag { inner: shared },
        AsyncSubscribe {
            inner: Some(observer),
        },
    )
}
/// Marking half of a pair created by [`async_flag`].
///
/// Holds a strong reference to the waker slot shared with the
/// [`AsyncSubscribe`]; cloning yields another handle to the same slot.
#[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
#[derive(Debug, Clone)]
pub struct AsyncFlag {
// Strong handle keeping the shared waker slot alive.
inner: Arc<AsyncFlagWaker>
}
impl AsyncFlag {
#[inline]
pub unsafe fn into_raw (self) -> *const Option<Waker> {
Arc::into_raw(self.inner).cast()
}
#[inline]
pub unsafe fn from_raw (ptr: *const Option<Waker>) -> Self {
Self { inner: Arc::from_raw(ptr.cast()) }
}
#[inline]
pub fn has_subscriber(&self) -> bool {
return Arc::weak_count(&self.inner) > 0
}
#[inline]
pub fn mark (self) {}
#[inline]
pub fn silent_drop (self) {
if let Ok(inner) = Arc::try_unwrap(self.inner) {
inner.silent_drop();
}
}
}
/// Awaitable half of a pair created by [`async_flag`].
///
/// Polling registers the task's waker in the shared slot while an
/// [`AsyncFlag`] is alive and completes once none is left.
#[cfg_attr(docsrs, doc(cfg(all(feature = "alloc", feature = "futures"))))]
#[derive(Debug)]
pub struct AsyncSubscribe {
// Weak handle to the shared slot; set to `None` by `poll` once the future
// has completed (this is what `is_terminated` reports).
inner: Option<Weak<AsyncFlagWaker>>
}
impl AsyncSubscribe {
    /// Returns `true` unless a weak handle is still held *and* it still has at
    /// least one strong reference behind it.
    // NOTE(review): assumes `crate::is_some_and` mirrors
    // `Option::is_some_and` — confirm at its definition.
    #[inline]
    pub fn is_marked(&self) -> bool {
        let still_alive =
            crate::is_some_and(self.inner.as_ref(), |weak| weak.strong_count() > 0);
        !still_alive
    }
}
impl Future for AsyncSubscribe {
type Output = ();
#[inline]
fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> {
if let Some(ref queue) = self.inner {
if let Some(queue) = queue.upgrade() {
unsafe { *queue.waker.get() = Some(cx.waker().clone()) };
return Poll::Pending;
}
self.inner = None;
return Poll::Ready(())
}
return Poll::Ready(())
}
}
impl FusedFuture for AsyncSubscribe {
    /// Reports completion: `poll` clears the weak handle right before it
    /// returns `Poll::Ready`.
    #[inline]
    fn is_terminated(&self) -> bool {
        matches!(self.inner, None)
    }
}
/// State shared between the `AsyncFlag` clones (strong references) and the
/// `AsyncSubscribe` (weak reference).
///
/// `AsyncFlag::into_raw` exposes a pointer to this type as
/// `*const Option<Waker>`. `#[repr(transparent)]` guarantees that the struct
/// has the same layout as its only field, and `UnsafeCell<T>` is documented to
/// share the in-memory representation of `T`, so that pointer type is accurate
/// instead of relying on the unspecified default struct layout.
#[repr(transparent)]
struct AsyncFlagWaker {
    /// Waker of the task currently awaiting the flag; `None` until the
    /// subscriber has been polled.
    waker: UnsafeCell<Option<Waker>>,
}
impl AsyncFlagWaker {
    /// Destroys the shared state without running its `Drop` implementation,
    /// so a registered waker is released but never woken.
    #[inline]
    pub fn silent_drop(self) {
        // Suppress the destructor of `Self`; the field is disposed of by hand.
        let this = core::mem::ManuallyDrop::new(self);
        // SAFETY: `this` is never touched or dropped again, so moving the field
        // out bitwise cannot lead to a double drop.
        let slot = unsafe { core::ptr::read(&this.waker) };
        drop(slot);
    }
}
impl Drop for AsyncFlagWaker {
    /// Wakes the registered task, if any, when the shared state is destroyed.
    #[inline]
    fn drop(&mut self) {
        // `&mut self` gives exclusive access, so the slot can be emptied safely.
        let registered = core::mem::take(self.waker.get_mut());
        if let Some(waker) = registered {
            waker.wake()
        }
    }
}
impl Debug for AsyncFlagWaker {
    /// Formats the type name only; the waker slot sits behind an `UnsafeCell`
    /// and is deliberately not printed.
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("AsyncFlagWaker");
        builder.finish_non_exhaustive()
    }
}
// SAFETY: `AsyncFlagWaker` only contains an `UnsafeCell<Option<Waker>>`, so
// moving it to another thread moves nothing but the optional `Waker`.
unsafe impl Send for AsyncFlagWaker where Option<Waker>: Send {}
// SAFETY: the slot is written through a shared reference only by
// `AsyncSubscribe::poll`, which needs exclusive access to the subscriber; the
// other accesses in this file (`Drop`, `silent_drop`) own the value.
// NOTE(review): confirm no other code reads the slot through `&AsyncFlagWaker`.
unsafe impl Sync for AsyncFlagWaker where Option<Waker>: Sync {}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "std")]
    use std::thread;

    /// A fresh pair is not marked while its flag is alive.
    #[test]
    fn test_flag_creation() {
        let (marker, listener) = flag();
        assert!(!listener.is_marked());
        drop(marker);
    }

    /// Marking the only flag handle is visible to the subscriber.
    #[test]
    fn test_flag_mark() {
        let (marker, listener) = flag();
        marker.mark();
        assert!(listener.is_marked());
    }

    /// A silent drop must not cut the subscriber's timed wait short.
    #[cfg(feature = "std")]
    #[test]
    fn test_flag_silent_drop() {
        use std::time::{Duration, Instant};

        let timeout = Duration::from_millis(200);
        let (marker, listener) = flag();
        let dropper = thread::spawn(move || {
            thread::sleep(Duration::from_millis(100));
            marker.silent_drop();
        });

        let started = Instant::now();
        let _ = listener.wait_timeout(timeout);
        let elapsed = started.elapsed();

        dropper.join().unwrap();
        assert!(elapsed >= timeout, "{elapsed:?}");
    }

    /// `wait` returns after another thread marks the flag.
    #[cfg(feature = "std")]
    #[test]
    fn test_subscribe_wait() {
        let (marker, listener) = flag();
        let worker = thread::spawn(move || {
            thread::sleep(std::time::Duration::from_millis(100));
            marker.mark();
        });
        listener.wait();
        worker.join().unwrap();
    }

    /// Many clones marked from many threads still release the waiter.
    #[cfg(feature = "std")]
    #[test]
    fn test_flag_stress() {
        const THREADS: usize = 10;
        const ITERATIONS: usize = 100;

        for _ in 0..ITERATIONS {
            let (marker, listener) = flag();
            let workers: Vec<_> = (0..THREADS)
                .map(|_| {
                    let clone = marker.clone();
                    thread::spawn(move || {
                        clone.mark();
                    })
                })
                .collect();

            // Release the original handle so only the clones keep the flag alive.
            drop(marker);
            listener.wait();

            workers
                .into_iter()
                .for_each(|worker| worker.join().unwrap());
        }
    }

    #[cfg(feature = "futures")]
    mod async_tests {
        use super::*;

        /// A fresh async pair is not marked while its flag is alive.
        #[test]
        fn test_async_flag_creation() {
            let (marker, listener) = async_flag();
            assert!(!listener.is_marked());
            drop(marker);
        }

        /// Marking the only async flag handle is visible to the subscriber.
        #[test]
        fn test_async_flag_mark() {
            let (marker, listener) = async_flag();
            marker.mark();
            assert!(listener.is_marked());
        }

        /// A silent drop must not complete the subscriber before the timeout.
        #[tokio::test]
        async fn test_flag_silent_drop() {
            use std::time::{Duration, Instant};

            let limit = Duration::from_millis(200);
            let (marker, listener) = async_flag();
            let dropper = tokio::spawn(async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                marker.silent_drop();
            });

            let outcome = tokio::time::timeout(limit, async move {
                let started = Instant::now();
                listener.await;
                started.elapsed()
            })
            .await;

            dropper.await.unwrap();
            // Completing early means the silent drop woke the subscriber.
            if let Ok(t) = outcome {
                if t < limit {
                    panic!("{t:?}")
                }
            }
        }

        /// Awaiting the subscriber after the marking task has finished resolves.
        #[tokio::test]
        async fn test_async_subscribe_wait() {
            let (marker, listener) = async_flag();
            let worker = tokio::spawn(async move {
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                marker.mark();
            });
            worker.await.unwrap();
            listener.await;
        }

        /// Many clones marked from many tasks still complete the subscriber.
        #[tokio::test]
        async fn test_async_flag_stress() {
            const TASKS: usize = 10;
            const ITERATIONS: usize = 10;

            for _ in 0..ITERATIONS {
                let (marker, listener) = async_flag();
                let workers: Vec<_> = (0..TASKS)
                    .map(|_| {
                        let clone = marker.clone();
                        tokio::spawn(async move {
                            clone.mark();
                        })
                    })
                    .collect();

                // Release the original handle so only the clones keep the flag alive.
                drop(marker);
                listener.await;

                for worker in workers {
                    worker.await.unwrap();
                }
            }
        }
    }
}