shared_stream/
lib.rs

//! A crate for easily cloneable [`Stream`]s, similar to `FutureExt::shared` for futures.
//!
//! Every item is pulled from the underlying stream only once and cached; each clone of a
//! [`Shared`] stream then yields its own clone of every item.
//!
//! # Examples
//! ```
//! # futures::executor::block_on(async {
//! use futures::stream::{self, StreamExt};
//! use shared_stream::Share;
//!
//! let shared = stream::iter(1..=3).shared();
//! assert_eq!(shared.clone().take(1).collect::<Vec<_>>().await, [1]);
//! assert_eq!(shared.clone().skip(1).take(1).collect::<Vec<_>>().await, [2]);
//! assert_eq!(shared.collect::<Vec<_>>().await, [1, 2, 3]);
//! # });
//! ```
// Lint configuration, grouped by origin; every entry is raised to `warn`.
// NOTE(review): newer rustc releases appear to report some of these names as renamed
// (`rustdoc`, `invalid_html_tags`, `missing_crate_level_docs`, `missing_doc_code_examples`
// and `private_doc_tests` moved under the `rustdoc::` tool namespace) or removed
// (`box_pointers`, `pointer_structural_match`, `unaligned_references`).
// TODO confirm the crate's minimum supported Rust version before renaming or dropping them.

// Clippy lint groups.
#![warn(clippy::pedantic, clippy::nursery, clippy::cargo)]
// Built-in rustc lint groups.
#![warn(
    future_incompatible,
    nonstandard_style,
    rust_2018_compatibility,
    rust_2018_idioms,
    unused
)]
// Documentation lints.
#![warn(
    rustdoc,
    invalid_html_tags,
    missing_crate_level_docs,
    missing_doc_code_examples,
    private_doc_tests
)]
// Individual rustc lints opted into for stricter checking.
#![warn(
    absolute_paths_not_starting_with_crate,
    anonymous_parameters,
    box_pointers,
    elided_lifetimes_in_paths,
    explicit_outlives_requirements,
    keyword_idents,
    macro_use_extern_crate,
    meta_variable_misuse,
    missing_copy_implementations,
    missing_debug_implementations,
    missing_docs,
    non_ascii_idents,
    pointer_structural_match,
    single_use_lifetimes,
    trivial_casts,
    trivial_numeric_casts,
    unaligned_references,
    unreachable_pub,
    unstable_features,
    unused_crate_dependencies,
    unused_extern_crates,
    unused_import_braces,
    unused_lifetimes,
    unused_qualifications,
    unused_results,
    variant_size_differences
)]
56
57use core::pin::Pin;
58use core::task::Context;
59use core::task::Poll;
60use futures_core::ready;
61use futures_core::{FusedStream, Stream};
62use pin_project_lite::pin_project;
63use std::cell::RefCell;
64use std::fmt;
65use std::mem;
66use std::rc::Rc;
67
pin_project! {
    // State shared by every `Shared` handle created from one upstream stream.
    //
    // `Running` still owns the upstream stream (structurally pinned, because
    // `Stream::poll_next` takes `Pin<&mut S>`) together with every item pulled from it so
    // far. `get_item` switches to `Finished` once the upstream stream returns `None`; that
    // drops the stream in place and keeps only the cached items.
    #[project = InnerStateProj]
    #[derive(Debug)]
    enum InnerState<S: Stream> {
        // Upstream is still live; `values` caches every item it has produced so far.
        Running { values: Vec<S::Item>, #[pin] stream: S },
        // Upstream has ended; `values` is the complete sequence of items.
        Finished { values: Vec<S::Item> },
    }
}
76impl<S: Stream> InnerState<S>
77where
78    S::Item: Clone,
79{
80    fn get_item(
81        mut self: Pin<&mut Self>,
82        idx: usize,
83        cx: &mut Context<'_>,
84    ) -> Poll<Option<S::Item>> {
85        loop {
86            let this = self.as_mut().project();
87            return Poll::Ready(match this {
88                InnerStateProj::Running { stream, values } => {
89                    let value = values.get(idx).cloned();
90                    if value.is_none() {
91                        let result = ready!(stream.poll_next(cx));
92                        if let Some(v) = result {
93                            values.push(v);
94                            continue;
95                        } else {
96                            let values = mem::take(values);
97                            self.set(Self::Finished { values });
98                        }
99                    }
100                    value
101                }
102                InnerStateProj::Finished { values } => values.get(idx).cloned(),
103            });
104        }
105    }
106}
107
108/// Stream for the [`shared`](Share::shared) method.
109#[must_use = "streams do nothing unless polled"]
110pub struct Shared<S: Stream> {
111    inner: Rc<RefCell<InnerState<S>>>,
112    idx: usize,
113}
114
115impl<S: Stream> fmt::Debug for Shared<S>
116where
117    S: fmt::Debug,
118    S::Item: fmt::Debug,
119{
120    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121        f.debug_struct("Shared")
122            .field("inner", &self.inner)
123            .field("idx", &self.idx)
124            .finish()
125    }
126}
127
128impl<S: Stream> Shared<S> {
129    pub(crate) fn new(stream: S) -> Self {
130        Self {
131            inner: Rc::new(RefCell::new(InnerState::Running {
132                stream,
133                values: vec![],
134            })),
135            idx: 0,
136        }
137    }
138}
139
140impl<S: Stream> Clone for Shared<S> {
141    fn clone(&self) -> Self {
142        Self {
143            inner: Rc::clone(&self.inner),
144            idx: self.idx,
145        }
146    }
147}
148
149impl<S: Stream> Stream for Shared<S>
150where
151    S::Item: Clone,
152{
153    type Item = S::Item;
154    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
155        // pin project Pin<&mut Self> -> Pin<&mut InnerState<I, S>>
156        // this is only safe because we don't do anything else with Self::inner except
157        // cloning (the Rc) which doesn't move its content or make it accessible.
158        let result = unsafe {
159            let inner: &RefCell<InnerState<S>> =
160                Pin::into_inner_unchecked(self.as_ref()).inner.as_ref();
161            Pin::new_unchecked(&mut *inner.borrow_mut()).get_item(self.idx, cx)
162        };
163        if let Poll::Ready(Some(_)) = result {
164            // trivial safe pin projection
165            unsafe { Pin::get_unchecked_mut(self).idx += 1 }
166        }
167        result
168    }
169
170    fn size_hint(&self) -> (usize, Option<usize>) {
171        match &*self.inner.borrow() {
172            InnerState::Running { values, stream } => {
173                let upstream_cached = values.len() - self.idx;
174                let upstream = stream.size_hint();
175                (
176                    upstream.0 + upstream_cached,
177                    upstream.1.map(|v| v + upstream_cached),
178                )
179            }
180            InnerState::Finished { values } => {
181                (values.len() - self.idx, Some(values.len() - self.idx))
182            }
183        }
184    }
185}
186
187impl<S: Stream> FusedStream for Shared<S>
188where
189    S::Item: Clone,
190{
191    fn is_terminated(&self) -> bool {
192        match &*self.inner.borrow() {
193            InnerState::Running { .. } => false,
194            InnerState::Finished { values } => values.len() <= self.idx,
195        }
196    }
197}
198
199/// An extension trait implemented for [`Stream`]s that provides the [`shared`](Share::shared) method.
200pub trait Share: Stream {
201    /// Turns this stream into a cloneable stream. Polled items are cached and cloned.
202    ///
203    /// Note that this function consumes the stream passed into it and returns a wrapped version of it.
204    fn shared(self) -> Shared<Self>
205    where
206        Self: Sized,
207        Self::Item: Clone;
208}
209
210impl<T: Stream> Share for T
211where
212    T::Item: Clone,
213{
214    fn shared(self) -> Shared<Self> {
215        Shared::new(self)
216    }
217}
218
#[cfg(test)]
mod test {
    use super::Share;
    use core::cell::RefCell;
    use futures::executor::block_on;
    use futures::future;
    use futures::stream::{self, StreamExt};
    use futures_core::stream::{FusedStream, Stream};

    /// Drives `stream` to completion on the current thread and gathers its items.
    fn collect<V, S: Stream<Item = V>>(stream: S) -> Vec<V> {
        block_on(stream.collect::<Vec<_>>())
    }

    #[test]
    fn test_everything() {
        let seen = RefCell::new(vec![]);
        let orig_stream = stream::iter(["a", "b", "c"].iter().map(|v| v.to_string()))
            .inspect(|v| {
                seen.borrow_mut().push(v.clone());
            })
            .shared();
        // Nothing is pulled from upstream until a handle is polled.
        assert!(seen.borrow().is_empty());
        assert_eq!(orig_stream.size_hint(), (3, Some(3)));
        assert!(!orig_stream.is_terminated());

        let stream = orig_stream.clone().take(1);
        assert_eq!(stream.size_hint(), (1, Some(1)));
        assert!(!stream.is_terminated());
        let result = collect(stream);
        assert_eq!(result, ["a"]);
        assert_eq!(*seen.borrow(), ["a"]);
        assert_eq!(orig_stream.size_hint(), (3, Some(3)));
        assert!(!orig_stream.is_terminated());

        let stream = orig_stream.clone();
        assert_eq!(stream.size_hint(), (3, Some(3)));
        assert!(!stream.is_terminated());
        let result = collect(stream);
        assert_eq!(result, ["a", "b", "c"]);
        // Already-cached items are not pulled from upstream a second time.
        assert_eq!(*seen.borrow(), ["a", "b", "c"]);
        assert_eq!(orig_stream.size_hint(), (3, Some(3)));
        assert!(!orig_stream.is_terminated());

        let stream1 = orig_stream.clone().skip(1);
        assert_eq!(stream1.size_hint(), (2, Some(2)));
        assert!(!stream1.is_terminated());
        let stream2 = orig_stream.clone();
        assert_eq!(stream2.size_hint(), (3, Some(3)));
        assert!(!stream2.is_terminated());
        let (result1, result2): (Vec<_>, Vec<_>) =
            block_on(future::join(stream1.collect(), stream2.collect()));
        assert_eq!(result1, ["b", "c"]);
        assert_eq!(result2, ["a", "b", "c"]);
        assert_eq!(*seen.borrow(), ["a", "b", "c"]);
        assert_eq!(orig_stream.size_hint(), (3, Some(3)));
        assert!(!orig_stream.is_terminated());

        let mut stream1 = orig_stream.clone();
        assert_eq!(Some("a".to_string()), block_on(stream1.next()));
        assert_eq!(stream1.size_hint(), (2, Some(2)));
        assert!(!stream1.is_terminated());
        assert_eq!(Some("b".to_string()), block_on(stream1.next()));
        assert_eq!(stream1.size_hint(), (1, Some(1)));
        assert!(!stream1.is_terminated());
        assert_eq!(Some("c".to_string()), block_on(stream1.next()));
        assert_eq!(stream1.size_hint(), (0, Some(0)));
        assert!(stream1.is_terminated());
        assert_eq!(orig_stream.size_hint(), (3, Some(3)));
        assert!(!orig_stream.is_terminated());
    }

    #[test]
    fn test_size_hint_for_unfinished() {
        let mut stream = stream::iter(["a", "b", "c"].iter().map(|v| v.to_string())).shared();
        assert_eq!(stream.size_hint(), (3, Some(3)));
        assert!(!stream.is_terminated());
        assert_eq!(block_on(stream.next()), Some("a".to_string()));
        assert_eq!(stream.size_hint(), (2, Some(2)));
        assert!(!stream.is_terminated());
    }

    #[test]
    fn test_empty_stream() {
        let shared = stream::iter(Vec::<u8>::new()).shared();
        assert_eq!(shared.size_hint(), (0, Some(0)));
        // Upstream has not been observed to end yet, so the handle is not terminated.
        assert!(!shared.is_terminated());
        assert_eq!(collect(shared.clone()), Vec::<u8>::new());
        assert_eq!(shared.size_hint(), (0, Some(0)));
        assert!(shared.is_terminated());
    }

    #[test]
    fn test_clone_resumes_at_current_position() {
        let mut first = stream::iter(1..=3).shared();
        assert_eq!(block_on(first.next()), Some(1));
        let second = first.clone();
        assert_eq!(collect(second), [2, 3]);
        assert_eq!(collect(first), [2, 3]);
    }
}