1use std::{
4 pin::Pin,
5 sync::{
6 atomic::{AtomicU32, Ordering},
7 Arc,
8 },
9 task::{Context, Poll},
10};
11
12use futures::{channel::oneshot, stream::Fuse, Future, FutureExt, Stream, StreamExt};
13use parking_lot::Mutex;
14use tokio::sync::Notify;
15
16#[derive(Clone)]
17pub struct StopBroadcaster {
18 txs: Arc<Mutex<Vec<oneshot::Sender<()>>>>,
19 pub(crate) num: Arc<AtomicU32>,
20 notify: Arc<Notify>,
21}
22
23impl StopBroadcaster {
24 pub fn new() -> Self {
25 Self {
26 txs: Arc::new(Mutex::new(vec![])),
27 num: Arc::new(AtomicU32::new(0)),
28 notify: Arc::new(Notify::new()),
29 }
30 }
31
32 pub fn listener(&self) -> StopListener {
33 self.num.fetch_add(1, Ordering::SeqCst);
34 let notify = self.notify.clone();
35 let (tx, rx) = oneshot::channel();
36
37 self.txs.lock().push(tx);
38
39 StopListener {
40 receiver: rx,
41 num: self.num.clone(),
42 notify,
43 }
44 }
45
46 pub fn emit(&mut self) {
47 for tx in self.txs.lock().drain(..) {
49 tx.send(()).ok();
50 }
51 }
52
53 pub fn len(&self) -> u32 {
54 self.num.load(Ordering::SeqCst)
55 }
56
57 pub async fn until_empty(&self) {
58 while self.len() > 0 {
59 self.notify.notified().await
60 }
61 }
62}
63
64pub struct StopListener {
71 receiver: oneshot::Receiver<()>,
72 num: Arc<AtomicU32>,
73 notify: Arc<Notify>,
74}
75
76impl StopListener {
77 pub fn fuse_with<T, S: Unpin + Stream<Item = T>>(
80 self,
81 stream: S,
82 ) -> Fuse<Pin<Box<StopListenerFuse<T, S>>>> {
83 StreamExt::fuse(Box::pin(StopListenerFuse { stream, stop: self }))
84 }
85
86 pub fn receiver(&mut self) -> &mut oneshot::Receiver<()> {
88 &mut self.receiver
89 }
90
91 pub fn ready(&mut self) -> bool {
94 !matches!(self.receiver.try_recv(), Ok(None))
95 }
96}
97
98impl Drop for StopListener {
99 fn drop(&mut self) {
100 self.num.fetch_sub(1, Ordering::SeqCst);
101 self.notify.notify_one();
102 }
103}
104
105impl Future for StopListener {
106 type Output = ();
107
108 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
109 match Box::pin(&mut self.receiver).poll_unpin(cx) {
110 Poll::Ready(_) => Poll::Ready(()),
111 Poll::Pending => Poll::Pending,
112 }
113 }
114}
115
116pub struct StopListenerFuse<T, S: Stream<Item = T>> {
117 stream: S,
118 stop: StopListener,
119}
120
121impl<T, S: Unpin + Stream<Item = T>> Stream for StopListenerFuse<T, S> {
122 type Item = T;
123
124 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
125 if let Poll::Ready(_) = Box::pin(&mut self.stop).poll_unpin(cx) {
126 return Poll::Ready(None);
127 }
128
129 Stream::poll_next(Pin::new(&mut self.stream), cx)
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `ready` / `not_ready` come from `crate::test_util`, which is not
    // shown here. Presumably they poll the given future and report whether it
    // completed or stayed pending without blocking; confirm against that module.
    use crate::test_util::*;

    /// A fresh broadcaster has no listeners, so `until_empty` should complete at once.
    #[tokio::test]
    async fn test_stop_empty() {
        let x = StopBroadcaster::new();
        assert_eq!(x.len(), 0);
        assert!(ready(x.until_empty()).await);
    }

    /// Walks the listener lifecycle: counting, dropping, emitting, and dropping the
    /// broadcaster itself.
    #[tokio::test]
    async fn test_stop() {
        let mut x = StopBroadcaster::new();
        let a = x.listener();
        let mut b = x.listener();
        let c = x.listener();
        assert_eq!(x.len(), 3);
        assert!(not_ready(x.until_empty()).await);

        // `a` is moved into `not_ready`; once it is dropped, the live count falls to 2.
        assert!(not_ready(a).await);
        assert_eq!(x.len(), 2);
        assert!(!b.ready());

        // After `emit`, the listeners that existed before it (`b`, `c`) resolve.
        x.emit();
        assert!(b.ready());
        assert!(ready(b).await);
        assert_eq!(x.len(), 1);
        assert!(not_ready(x.until_empty()).await);

        assert!(ready(c).await);
        assert_eq!(x.len(), 0);
        assert!(ready(x.until_empty()).await);

        // Dropping the broadcaster (and so the listener's sender) also counts as "stop".
        let y = StopBroadcaster::new();
        let mut d = y.listener();
        drop(y);
        assert!(d.ready());
        assert!(ready(d).await);
    }

    /// A stream wrapped with `fuse_with` ends on `emit`, or earlier if the inner stream
    /// ends on its own. Either way it stays ended, and dropping it releases the listener.
    #[tokio::test]
    async fn test_fuse_with() {
        // Infinite inner stream: only `emit` can end it.
        {
            let mut tx = StopBroadcaster::new();
            let rx = tx.listener();
            let mut fused = rx.fuse_with(futures::stream::repeat(0));
            assert_eq!(fused.next().await, Some(0));
            assert_eq!(fused.next().await, Some(0));
            tx.emit();
            assert_eq!(fused.next().await, None);
            assert_eq!(fused.next().await, None);
            drop(fused);
            tx.until_empty().await;
            assert_eq!(tx.len(), 0);
        }
        // Single-item inner stream: it ends by itself before `emit`, and stays ended after.
        {
            let mut tx = StopBroadcaster::new();
            let rx = tx.listener();
            let mut fused = rx.fuse_with(futures::stream::repeat(0).take(1));
            assert_eq!(fused.next().await, Some(0));
            assert_eq!(fused.next().await, None);
            assert_eq!(fused.next().await, None);
            tx.emit();
            assert_eq!(fused.next().await, None);
            assert_eq!(fused.next().await, None);
            drop(fused);
            tx.until_empty().await;
            assert_eq!(tx.len(), 0);
        }
    }
}