stream_broadcast/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3
4use futures::stream::{FusedStream, Stream};
5use pin_project::pin_project;
6use std::{
7    ops::DerefMut,
8    pin::{pin, Pin},
9    sync::{atomic::AtomicU64, Arc, Mutex},
10    task::Poll,
11};
12
13mod weak;
14
15pub use weak::*;
16
17pub trait StreamBroadcastExt: FusedStream + Sized {
18    fn broadcast(self, size: usize) -> StreamBroadcast<Self>;
19}
20
21impl<T: FusedStream + Sized> StreamBroadcastExt for T
22where
23    T::Item: Clone,
24{
25    fn broadcast(self, size: usize) -> StreamBroadcast<Self> {
26        StreamBroadcast::new(self, size)
27    }
28}
29
30#[pin_project]
31pub struct StreamBroadcast<T: FusedStream> {
32    pos: u64,
33    id: u64,
34    state: Arc<Mutex<Pin<Box<StreamBroadcastState<T>>>>>,
35}
36
37impl<T: FusedStream> std::fmt::Debug for StreamBroadcast<T> {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        let pending = self.state.lock().unwrap().global_pos - self.pos;
40        f.debug_struct("WeakStreamBroadcast")
41            .field("pending_messages", &pending)
42            .field("strong_count", &Arc::strong_count(&self.state))
43            .finish()
44    }
45}
46
47impl<T: FusedStream> Clone for StreamBroadcast<T> {
48    fn clone(&self) -> Self {
49        Self {
50            state: self.state.clone(),
51            id: create_id(),
52            pos: self.pos,
53        }
54    }
55}
56
57impl<T: FusedStream> StreamBroadcast<T>
58where
59    T::Item: Clone,
60{
61    pub fn new(outer: T, size: usize) -> Self {
62        Self {
63            state: Arc::new(Mutex::new(Box::pin(StreamBroadcastState::new(outer, size)))),
64            id: create_id(),
65            pos: 0,
66        }
67    }
68
69    /// Creates a weak broadcast which terminates its stream, if all 'strong' [StreamBroadcast] went out of scope
70    ///
71    /// ```
72    /// # #[tokio::main]
73    /// # async fn main() {
74    /// use futures::StreamExt;
75    /// use stream_broadcast::StreamBroadcastExt;
76    ///
77    /// let stream = futures::stream::iter(0..).fuse().broadcast(5);
78    /// let mut weak = std::pin::pin!(stream.downgrade());
79    /// assert_eq!(Some((0, 0)), weak.next().await);
80    /// drop(stream);
81    /// assert_eq!(None, weak.next().await);
82    /// # }
83    /// ```
84    pub fn downgrade(&self) -> WeakStreamBroadcast<T> {
85        WeakStreamBroadcast::new(Arc::downgrade(&self.state), self.pos)
86    }
87
88    /// In contrast to clone, this method only shows new messages provided by the source stream
89    pub fn re_subscribe(&self) -> Self {
90        Self {
91            state: self.state.clone(),
92            id: create_id(),
93            pos: self.state.lock().unwrap().global_pos,
94        }
95    }
96}
97
98impl<T: FusedStream> Stream for StreamBroadcast<T>
99where
100    T::Item: Clone,
101{
102    type Item = (u64, T::Item);
103
104    fn poll_next(
105        self: Pin<&mut Self>,
106        cx: &mut std::task::Context<'_>,
107    ) -> Poll<Option<Self::Item>> {
108        let this = self.project();
109        let mut lock = this.state.lock().unwrap();
110        broadast_next(lock.deref_mut().as_mut(), cx, this.pos, *this.id)
111    }
112}
/// Hands out the next id from a global, monotonically increasing counter.
fn create_id() -> u64 {
    use std::sync::atomic::Ordering;

    static NEXT_ID: AtomicU64 = AtomicU64::new(0);
    NEXT_ID.fetch_add(1, Ordering::SeqCst)
}
117fn broadast_next<T: FusedStream>(
118    pinned: Pin<&mut StreamBroadcastState<T>>,
119    cx: &mut std::task::Context<'_>,
120    pos: &mut u64,
121    id: u64,
122) -> Poll<Option<(u64, T::Item)>>
123where
124    T::Item: Clone,
125{
126    match pinned.poll(cx, *pos, id) {
127        Poll::Ready(Some((new_pos, x))) => {
128            debug_assert!(new_pos > *pos, "Must always grow {} > {}", new_pos, *pos);
129            let offset = new_pos - *pos - 1;
130            *pos = new_pos;
131            Poll::Ready(Some((offset, x)))
132        }
133        Poll::Ready(None) => {
134            *pos += 1;
135            Poll::Ready(None)
136        }
137        Poll::Pending => Poll::Pending,
138    }
139}
140
141impl<T: FusedStream> FusedStream for StreamBroadcast<T>
142where
143    T::Item: Clone,
144{
145    fn is_terminated(&self) -> bool {
146        self.state.lock().unwrap().stream.is_terminated()
147    }
148}
149
150#[pin_project]
151struct StreamBroadcastState<T: FusedStream> {
152    #[pin]
153    stream: T,
154    global_pos: u64,
155    cache: Vec<T::Item>,
156    wakable: Vec<(u64, std::task::Waker)>,
157}
158
159impl<T: FusedStream> StreamBroadcastState<T>
160where
161    T::Item: Clone,
162{
163    fn new(outer: T, size: usize) -> Self {
164        Self {
165            stream: outer,
166            cache: Vec::with_capacity(size), // Could be improved with  Box<[MaybeUninit<T::Item>]>
167            global_pos: Default::default(),
168            wakable: Default::default(),
169        }
170    }
171    fn poll(
172        self: Pin<&mut Self>,
173        cx: &mut std::task::Context<'_>,
174        request_pos: u64,
175        id: u64,
176    ) -> Poll<Option<(u64, T::Item)>> {
177        let this = self.project();
178        if *this.global_pos > request_pos {
179            let cap = this.cache.capacity();
180            let return_pos = if *this.global_pos - request_pos > cap as u64 {
181                *this.global_pos - cap as u64
182            } else {
183                request_pos
184            };
185
186            let result = this.cache[(return_pos % cap as u64) as usize].clone();
187            return Poll::Ready(Some((return_pos + 1, result)));
188        }
189
190        match this.stream.poll_next(cx) {
191            Poll::Ready(Some(x)) => {
192                this.wakable.drain(..).for_each(|(k, w)| {
193                    if k != id {
194                        w.wake();
195                    }
196                });
197
198                let cap = this.cache.capacity();
199                if this.cache.len() < cap {
200                    this.cache.push(x.clone());
201                } else {
202                    this.cache[(*this.global_pos % cap as u64) as usize] = x.clone();
203                }
204                *this.global_pos += 1;
205                let result = (*this.global_pos, x);
206                Poll::Ready(Some(result))
207            }
208            Poll::Ready(None) => Poll::Ready(None),
209            Poll::Pending => {
210                this.wakable.push((id, cx.waker().clone()));
211                Poll::Pending
212            }
213        }
214    }
215}