Skip to main content

typhoon/utils/
sync.rs

1#[cfg(all(test, feature = "tokio"))]
2#[path = "../../tests/utils/sync.rs"]
3mod tests;
4
5use std::future::Future;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::time::Duration;
9
10use cfg_if::cfg_if;
11#[cfg(all(feature = "server", feature = "tokio"))]
12use crossbeam::queue::ArrayQueue;
13#[cfg(feature = "server")]
14use log::debug;
15cfg_if! {
16    if #[cfg(feature = "client")] {
17        use std::pin::Pin;
18        use futures::stream::{FuturesUnordered, StreamExt};
19    }
20}
21cfg_if! {
22    if #[cfg(feature = "tokio")] {
23        use crossbeam::queue::SegQueue;
24        use tokio::sync::Notify;
25        use tokio::time::sleep as tokio_sleep;
26        use tokio::runtime::{Handle, RuntimeFlavor};
27        pub use tokio::sync::{RwLock, Mutex};
28    } else if #[cfg(feature = "async-std")] {
29        pub use async_lock::{RwLock, Mutex};
30        use async_channel::{Receiver, Sender, bounded, unbounded};
31        use async_io::Timer;
32    }
33}
34
/// Check that the ambient runtime can host the protocol's multi-flow,
/// multi-task design (tokio backend).
///
/// # Errors
/// Returns a static description of the problem when no tokio runtime is in
/// scope, or when the runtime in scope is not the multi-threaded flavor.
#[cfg(feature = "tokio")]
pub fn assert_runtime() -> Result<(), &'static str> {
    let Ok(handle) = Handle::try_current() else {
        return Err("no tokio runtime in scope");
    };
    match handle.runtime_flavor() {
        RuntimeFlavor::MultiThread => Ok(()),
        _ => Err("TYPHOON requires a multi-threaded tokio runtime (use `#[tokio::main]` or `Builder::new_multi_thread()`)"),
    }
}

/// Check that the ambient runtime can host the protocol (async-std backend).
///
/// There is nothing to verify for this backend, so the check always passes.
#[cfg(feature = "async-std")]
pub fn assert_runtime() -> Result<(), &'static str> {
    Ok(())
}
51
/// Runtime-agnostic interface for spawning and driving futures.
///
/// Implementors must be `Clone + Send + Sync` so handles can be shared freely.
pub trait AsyncExecutor: Clone + Send + Sync {
    /// Construct an executor handle.
    fn new() -> Self;

    /// Hand `future` to the runtime and return without waiting for it
    /// (fire-and-forget).
    fn spawn<F>(&self, future: F)
    where
        F: Future<Output = ()> + Send + 'static;

    /// Run `future` to completion on the calling thread.
    fn block_on<F>(&self, future: F)
    where
        F: Future<Output = ()>;
}
61
// ── Watch channel (latest-value-wins, destructive read) ──────────────────────
//
// Semantics: the sender stores the latest value (overwriting any value not yet
// taken) and wakes all current receivers; the first receiver to reach the slot
// *takes* (consumes) the value on recv(), so receivers created via subscribe()
// compete for each value rather than each seeing a copy.
// No off-the-shelf watch channel provides destructive single-consumer reads
// across both runtimes, so this stays hand-rolled.
68
/// State shared by a watch sender and every receiver subscribed to it.
struct WatchState<T> {
    /// Set by the sender's `Drop`; receivers return `None` once it is set and
    /// the slot is empty.
    closed: AtomicBool,
    /// Number of live receivers (incremented on subscribe, decremented on drop).
    receiver_count: AtomicUsize,
    /// Latest value that no receiver has taken yet.
    value: std::sync::Mutex<Option<T>>,
    /// tokio: wakes every receiver currently waiting in `recv`.
    #[cfg(feature = "tokio")]
    notify: Notify,
    /// async-std: one wakeup channel per receiver, pinged by the sender.
    #[cfg(feature = "async-std")]
    notifiers: std::sync::Mutex<Vec<Sender<()>>>,
}
78
79/// Watch channel sender: stores the latest value, wakes all receivers on change.
80pub struct WatchSender<T: Send> {
81    state: Arc<WatchState<T>>,
82}
83
84/// Watch channel receiver: waits for the next change and takes the latest value.
85pub struct WatchReceiver<T> {
86    state: Arc<WatchState<T>>,
87    #[cfg(feature = "async-std")]
88    notify: Receiver<()>,
89}
90
91impl<T: Send> WatchSender<T> {
92    /// Send a new value, overwriting the previous one.
93    /// Returns false if all receivers have been dropped.
94    pub fn send(&self, value: T) -> bool {
95        *self.state.value.lock().unwrap() = Some(value);
96        #[cfg(feature = "tokio")]
97        self.state.notify.notify_waiters();
98        #[cfg(feature = "async-std")]
99        {
100            let notifiers = self.state.notifiers.lock().unwrap();
101            for tx in notifiers.iter() {
102                let _ = tx.try_send(());
103            }
104        }
105        self.state.receiver_count.load(Ordering::Relaxed) > 0
106    }
107
108    /// Create a new receiver watching the same sender.
109    #[cfg(feature = "client")]
110    pub fn subscribe(&self) -> WatchReceiver<T> {
111        self.state.receiver_count.fetch_add(1, Ordering::Relaxed);
112        #[cfg(feature = "tokio")]
113        return WatchReceiver {
114            state: Arc::clone(&self.state),
115        };
116        #[cfg(feature = "async-std")]
117        {
118            let (tx, rx) = bounded(1);
119            self.state.notifiers.lock().unwrap().push(tx);
120            WatchReceiver {
121                state: Arc::clone(&self.state),
122                notify: rx,
123            }
124        }
125    }
126}
127
128impl<T: Send> Drop for WatchSender<T> {
129    fn drop(&mut self) {
130        self.state.closed.store(true, Ordering::Release);
131        #[cfg(feature = "tokio")]
132        self.state.notify.notify_waiters();
133        #[cfg(feature = "async-std")]
134        {
135            let mut notifiers = self.state.notifiers.lock().unwrap();
136            for tx in notifiers.drain(..) {
137                let _ = tx.try_send(());
138            }
139        }
140    }
141}
142
143impl<T> Drop for WatchReceiver<T> {
144    fn drop(&mut self) {
145        self.state.receiver_count.fetch_sub(1, Ordering::Release);
146    }
147}
148
149impl<T: Send> WatchReceiver<T> {
150    /// Wait for the next value change and take it (destructive read),
151    /// or `None` if the sender is dropped.
152    pub async fn recv(&mut self) -> Option<T> {
153        loop {
154            #[cfg(feature = "tokio")]
155            let mut notified = std::pin::pin!(self.state.notify.notified());
156            #[cfg(feature = "tokio")]
157            notified.as_mut().enable();
158
159            {
160                let mut guard = self.state.value.lock().unwrap();
161                if let Some(v) = guard.take() {
162                    return Some(v);
163                }
164                if self.state.closed.load(Ordering::Acquire) {
165                    return None;
166                }
167            }
168
169            #[cfg(feature = "tokio")]
170            notified.await;
171            #[cfg(feature = "async-std")]
172            {
173                self.notify.recv().await.ok();
174            }
175        }
176    }
177}
178
/// Create a watch channel (tokio backend), returning the sender and its first
/// receiver.
#[cfg(feature = "tokio")]
pub fn create_watch<T: Send>() -> (WatchSender<T>, WatchReceiver<T>) {
    let shared = Arc::new(WatchState {
        value: std::sync::Mutex::new(None),
        closed: AtomicBool::new(false),
        // Accounts for the receiver returned below.
        receiver_count: AtomicUsize::new(1),
        notify: Notify::new(),
    });
    let sender = WatchSender {
        state: Arc::clone(&shared),
    };
    (sender, WatchReceiver { state: shared })
}
197
/// Create a watch channel (async-std backend), returning the sender and its
/// first receiver.
#[cfg(feature = "async-std")]
pub fn create_watch<T: Send>() -> (WatchSender<T>, WatchReceiver<T>) {
    // Single-slot wakeup channel for the initial receiver; the shared state
    // keeps the sending end so `send` can ping it.
    let (ping_tx, ping_rx) = bounded(1);
    let shared = Arc::new(WatchState {
        value: std::sync::Mutex::new(None),
        closed: AtomicBool::new(false),
        // Accounts for the receiver returned below.
        receiver_count: AtomicUsize::new(1),
        notifiers: std::sync::Mutex::new(vec![ping_tx]),
    });
    let receiver = WatchReceiver {
        state: Arc::clone(&shared),
        notify: ping_rx,
    };
    (WatchSender { state: shared }, receiver)
}
218
// ── Notifying queues ──────────────────────────────────────────────────────────
//
// Crossbeam SegQueue/ArrayQueue for O(1) lock-free storage (ArrayQueue allocates
// once up front; SegQueue allocates one segment per batch of pushes rather than
// per item), plus a runtime-native Notify/channel for async wakeup. The obvious
// alternative (tokio mpsc) was avoided because of its allocation overhead on push.
224
225cfg_if! {
226    if #[cfg(feature = "tokio")] {
227
228        // ── Shared state ─────────────────────────────────────────────────────
229
230        struct NotifyQueueState<T> {
231            queue: SegQueue<T>,
232            notify: Notify,
233            closed: AtomicBool,
234        }
235
236        // ── Unbounded ────────────────────────────────────────────────────────
237
238        /// Push side of an unbounded notifying queue.
239        pub struct NotifyQueueSender<T: Send>(Arc<NotifyQueueState<T>>);
240
241        /// Pop side of an unbounded notifying queue.
242        pub struct NotifyQueueReceiver<T: Send>(Arc<NotifyQueueState<T>>);
243
244        impl<T: Send> NotifyQueueSender<T> {
245            /// Push an item; never blocks.
246            pub fn push(&self, item: T) {
247                self.0.queue.push(item);
248                self.0.notify.notify_one();
249            }
250        }
251
252        impl<T: Send> Drop for NotifyQueueSender<T> {
253            fn drop(&mut self) {
254                self.0.closed.store(true, Ordering::Release);
255                self.0.notify.notify_waiters();
256            }
257        }
258
259        impl<T: Send> NotifyQueueReceiver<T> {
260            /// Pop the next item, waiting asynchronously until one is pushed.
261            /// Returns `None` if the sender has been dropped and the queue is empty.
262            pub async fn recv(&mut self) -> Option<T> {
263                loop {
264                    if let Some(item) = self.0.queue.pop() {
265                        return Some(item);
266                    }
267                    if self.0.closed.load(Ordering::Acquire) && self.0.queue.is_empty() {
268                        return None;
269                    }
270                    // Pre-register the waker before the second pop so we cannot
271                    // miss a push that arrives between the two checks.
272                    let mut notified = std::pin::pin!(self.0.notify.notified());
273                    notified.as_mut().enable();
274                    if let Some(item) = self.0.queue.pop() {
275                        return Some(item);
276                    }
277                    notified.await;
278                }
279            }
280        }
281
282        /// Create an unbounded notifying queue.
283        pub fn create_notify_queue<T: Send>() -> (NotifyQueueSender<T>, NotifyQueueReceiver<T>) {
284            let state = Arc::new(NotifyQueueState {
285                queue: SegQueue::new(),
286                notify: Notify::new(),
287                closed: AtomicBool::new(false),
288            });
289            (NotifyQueueSender(Arc::clone(&state)), NotifyQueueReceiver(state))
290        }
291
292        // ── Bounded ──────────────────────────────────────────────────────────
293
294        #[cfg(feature = "server")]
295        struct BoundedNotifyQueueState<T> {
296            queue: ArrayQueue<T>,
297            notify: Notify,
298            closed: AtomicBool,
299        }
300
301        /// Push side of a bounded notifying queue.
302        #[cfg(feature = "server")]
303        pub struct BoundedNotifyQueueSender<T: Send>(Arc<BoundedNotifyQueueState<T>>);
304
305        /// Pop side of a bounded notifying queue.
306        #[cfg(feature = "server")]
307        pub struct BoundedNotifyQueueReceiver<T: Send>(Arc<BoundedNotifyQueueState<T>>);
308
309        #[cfg(feature = "server")]
310        impl<T: Send> BoundedNotifyQueueSender<T> {
311            /// Push an item; silently drops it (with a debug log) if the queue is full.
312            pub fn push(&self, item: T) {
313                if self.0.queue.push(item).is_err() {
314                    debug!("BoundedNotifyQueue: queue full, dropping item");
315                    return;
316                }
317                self.0.notify.notify_one();
318            }
319        }
320
321        #[cfg(feature = "server")]
322        impl<T: Send> Drop for BoundedNotifyQueueSender<T> {
323            fn drop(&mut self) {
324                self.0.closed.store(true, Ordering::Release);
325                self.0.notify.notify_waiters();
326            }
327        }
328
329        #[cfg(feature = "server")]
330        impl<T: Send> BoundedNotifyQueueReceiver<T> {
331            /// Pop the next item, waiting asynchronously until one is pushed.
332            /// Returns `None` if the sender has been dropped and the queue is empty.
333            pub async fn recv(&mut self) -> Option<T> {
334                loop {
335                    if let Some(item) = self.0.queue.pop() {
336                        return Some(item);
337                    }
338                    if self.0.closed.load(Ordering::Acquire) && self.0.queue.is_empty() {
339                        return None;
340                    }
341                    let mut notified = std::pin::pin!(self.0.notify.notified());
342                    notified.as_mut().enable();
343                    if let Some(item) = self.0.queue.pop() {
344                        return Some(item);
345                    }
346                    notified.await;
347                }
348            }
349        }
350
351        /// Create a bounded notifying queue with the given capacity.
352        #[cfg(feature = "server")]
353        pub fn create_bounded_notify_queue<T: Send>(cap: usize) -> (BoundedNotifyQueueSender<T>, BoundedNotifyQueueReceiver<T>) {
354            let state = Arc::new(BoundedNotifyQueueState {
355                queue: ArrayQueue::new(cap),
356                notify: Notify::new(),
357                closed: AtomicBool::new(false),
358            });
359            (BoundedNotifyQueueSender(Arc::clone(&state)), BoundedNotifyQueueReceiver(state))
360        }
361
362    } else if #[cfg(feature = "async-std")] {
363
364        // Under async-std there is no standalone Notify equivalent, so we use
365        // async_channel which is already a dependency.
366
367        /// Push side of an unbounded notifying queue.
368        pub struct NotifyQueueSender<T: Send>(Sender<T>);
369
370        /// Pop side of an unbounded notifying queue.
371        pub struct NotifyQueueReceiver<T: Send>(Receiver<T>);
372
373        impl<T: Send> NotifyQueueSender<T> {
374            pub fn push(&self, item: T) {
375                let _ = self.0.try_send(item);
376            }
377        }
378
379        impl<T: Send> NotifyQueueReceiver<T> {
380            pub async fn recv(&mut self) -> Option<T> {
381                self.0.recv().await.ok()
382            }
383        }
384
385        /// Create an unbounded notifying queue.
386        pub fn create_notify_queue<T: Send>() -> (NotifyQueueSender<T>, NotifyQueueReceiver<T>) {
387            let (tx, rx) = unbounded();
388            (NotifyQueueSender(tx), NotifyQueueReceiver(rx))
389        }
390
391        /// Push side of a bounded notifying queue.
392        #[cfg(feature = "server")]
393        pub struct BoundedNotifyQueueSender<T: Send>(Sender<T>);
394
395        /// Pop side of a bounded notifying queue.
396        #[cfg(feature = "server")]
397        pub struct BoundedNotifyQueueReceiver<T: Send>(Receiver<T>);
398
399        #[cfg(feature = "server")]
400        impl<T: Send> BoundedNotifyQueueSender<T> {
401            pub fn push(&self, item: T) {
402                if self.0.try_send(item).is_err() {
403                    debug!("BoundedNotifyQueue: queue full, dropping item");
404                }
405            }
406        }
407
408        #[cfg(feature = "server")]
409        impl<T: Send> BoundedNotifyQueueReceiver<T> {
410            pub async fn recv(&mut self) -> Option<T> {
411                self.0.recv().await.ok()
412            }
413        }
414
415        /// Create a bounded notifying queue with the given capacity.
416        #[cfg(feature = "server")]
417        pub fn create_bounded_notify_queue<T: Send>(cap: usize) -> (BoundedNotifyQueueSender<T>, BoundedNotifyQueueReceiver<T>) {
418            let (tx, rx) = bounded(cap);
419            (BoundedNotifyQueueSender(tx), BoundedNotifyQueueReceiver(rx))
420        }
421    }
422}
423
424// ── Future pool ───────────────────────────────────────────────────────────────
425
/// Pool of concurrently polled futures that yields their outputs in
/// completion order, not insertion order.
///
/// The `'f` lifetime lets pooled futures borrow data that outlives the pool.
#[cfg(feature = "client")]
pub struct FuturePool<'f, T> {
    tasks: FuturesUnordered<Pin<Box<dyn Future<Output = T> + Send + 'f>>>,
}

#[cfg(feature = "client")]
impl<'f, T> FuturePool<'f, T> {
    /// Create an empty pool.
    pub fn new() -> Self {
        Self {
            tasks: FuturesUnordered::new(),
        }
    }

    /// Add a future to the pool. Pooled futures make progress while
    /// [`next`](Self::next) is being awaited.
    pub fn add<F: Future<Output = T> + Send + 'f>(&mut self, future: F) {
        self.tasks.push(Box::pin(future));
    }

    /// Wait for the next pooled future to complete and return its output, or
    /// `None` if the pool is empty.
    pub async fn next(&mut self) -> Option<T> {
        self.tasks.next().await
    }
}

/// An empty pool, equivalent to [`FuturePool::new`].
#[cfg(feature = "client")]
impl<T> Default for FuturePool<'_, T> {
    fn default() -> Self {
        Self::new()
    }
}
448
449// ── Sleep ─────────────────────────────────────────────────────────────────────
450
/// Suspend the calling task for `duration` using tokio's timer.
#[cfg(feature = "tokio")]
pub async fn sleep(duration: Duration) {
    let delay = tokio_sleep(duration);
    delay.await
}

/// Suspend the calling task for `duration` using async-io's timer.
#[cfg(feature = "async-std")]
pub async fn sleep(duration: Duration) {
    // The timer resolves to the instant it fired; callers only need the delay.
    let _fired_at = Timer::after(duration).await;
}