1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3
4use futures::stream::{FusedStream, Stream};
5use pin_project::pin_project;
6use std::{
7 ops::DerefMut,
8 pin::{pin, Pin},
9 sync::{atomic::AtomicU64, Arc, Mutex},
10 task::Poll,
11};
12
13mod weak;
14
15pub use weak::*;
16
17pub trait StreamBroadcastExt: FusedStream + Sized {
18 fn broadcast(self, size: usize) -> StreamBroadcast<Self>;
19}
20
21impl<T: FusedStream + Sized> StreamBroadcastExt for T
22where
23 T::Item: Clone,
24{
25 fn broadcast(self, size: usize) -> StreamBroadcast<Self> {
26 StreamBroadcast::new(self, size)
27 }
28}
29
30#[pin_project]
31pub struct StreamBroadcast<T: FusedStream> {
32 pos: u64,
33 id: u64,
34 state: Arc<Mutex<Pin<Box<StreamBroadcastState<T>>>>>,
35}
36
37impl<T: FusedStream> std::fmt::Debug for StreamBroadcast<T> {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 let pending = self.state.lock().unwrap().global_pos - self.pos;
40 f.debug_struct("WeakStreamBroadcast")
41 .field("pending_messages", &pending)
42 .field("strong_count", &Arc::strong_count(&self.state))
43 .finish()
44 }
45}
46
47impl<T: FusedStream> Clone for StreamBroadcast<T> {
48 fn clone(&self) -> Self {
49 Self {
50 state: self.state.clone(),
51 id: create_id(),
52 pos: self.pos,
53 }
54 }
55}
56
57impl<T: FusedStream> StreamBroadcast<T>
58where
59 T::Item: Clone,
60{
61 pub fn new(outer: T, size: usize) -> Self {
62 Self {
63 state: Arc::new(Mutex::new(Box::pin(StreamBroadcastState::new(outer, size)))),
64 id: create_id(),
65 pos: 0,
66 }
67 }
68
69 pub fn downgrade(&self) -> WeakStreamBroadcast<T> {
85 WeakStreamBroadcast::new(Arc::downgrade(&self.state), self.pos)
86 }
87
88 pub fn re_subscribe(&self) -> Self {
90 Self {
91 state: self.state.clone(),
92 id: create_id(),
93 pos: self.state.lock().unwrap().global_pos,
94 }
95 }
96}
97
98impl<T: FusedStream> Stream for StreamBroadcast<T>
99where
100 T::Item: Clone,
101{
102 type Item = (u64, T::Item);
103
104 fn poll_next(
105 self: Pin<&mut Self>,
106 cx: &mut std::task::Context<'_>,
107 ) -> Poll<Option<Self::Item>> {
108 let this = self.project();
109 let mut lock = this.state.lock().unwrap();
110 broadast_next(lock.deref_mut().as_mut(), cx, this.pos, *this.id)
111 }
112}
/// Returns a fresh subscriber identifier from a process-wide counter.
///
/// Consecutive calls yield strictly increasing values (the counter wraps only after `u64`
/// overflow).
fn create_id() -> u64 {
    use std::sync::atomic::Ordering;

    static NEXT_ID: AtomicU64 = AtomicU64::new(0);
    NEXT_ID.fetch_add(1, Ordering::SeqCst)
}
117fn broadast_next<T: FusedStream>(
118 pinned: Pin<&mut StreamBroadcastState<T>>,
119 cx: &mut std::task::Context<'_>,
120 pos: &mut u64,
121 id: u64,
122) -> Poll<Option<(u64, T::Item)>>
123where
124 T::Item: Clone,
125{
126 match pinned.poll(cx, *pos, id) {
127 Poll::Ready(Some((new_pos, x))) => {
128 debug_assert!(new_pos > *pos, "Must always grow {} > {}", new_pos, *pos);
129 let offset = new_pos - *pos - 1;
130 *pos = new_pos;
131 Poll::Ready(Some((offset, x)))
132 }
133 Poll::Ready(None) => {
134 *pos += 1;
135 Poll::Ready(None)
136 }
137 Poll::Pending => Poll::Pending,
138 }
139}
140
141impl<T: FusedStream> FusedStream for StreamBroadcast<T>
142where
143 T::Item: Clone,
144{
145 fn is_terminated(&self) -> bool {
146 self.state.lock().unwrap().stream.is_terminated()
147 }
148}
149
150#[pin_project]
151struct StreamBroadcastState<T: FusedStream> {
152 #[pin]
153 stream: T,
154 global_pos: u64,
155 cache: Vec<T::Item>,
156 wakable: Vec<(u64, std::task::Waker)>,
157}
158
159impl<T: FusedStream> StreamBroadcastState<T>
160where
161 T::Item: Clone,
162{
163 fn new(outer: T, size: usize) -> Self {
164 Self {
165 stream: outer,
166 cache: Vec::with_capacity(size), global_pos: Default::default(),
168 wakable: Default::default(),
169 }
170 }
171 fn poll(
172 self: Pin<&mut Self>,
173 cx: &mut std::task::Context<'_>,
174 request_pos: u64,
175 id: u64,
176 ) -> Poll<Option<(u64, T::Item)>> {
177 let this = self.project();
178 if *this.global_pos > request_pos {
179 let cap = this.cache.capacity();
180 let return_pos = if *this.global_pos - request_pos > cap as u64 {
181 *this.global_pos - cap as u64
182 } else {
183 request_pos
184 };
185
186 let result = this.cache[(return_pos % cap as u64) as usize].clone();
187 return Poll::Ready(Some((return_pos + 1, result)));
188 }
189
190 match this.stream.poll_next(cx) {
191 Poll::Ready(Some(x)) => {
192 this.wakable.drain(..).for_each(|(k, w)| {
193 if k != id {
194 w.wake();
195 }
196 });
197
198 let cap = this.cache.capacity();
199 if this.cache.len() < cap {
200 this.cache.push(x.clone());
201 } else {
202 this.cache[(*this.global_pos % cap as u64) as usize] = x.clone();
203 }
204 *this.global_pos += 1;
205 let result = (*this.global_pos, x);
206 Poll::Ready(Some(result))
207 }
208 Poll::Ready(None) => Poll::Ready(None),
209 Poll::Pending => {
210 this.wakable.push((id, cx.waker().clone()));
211 Poll::Pending
212 }
213 }
214 }
215}