// stream_utils/copied_multi_stream.rs

use std::{
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll, Waker},
};

use futures_util::{Stream, StreamExt};

8#[derive(Clone)]
9struct CopiedMultiStreamState<S>
10where
11    S: Stream,
12{
13    cache: Box<[Option<S::Item>]>,
14    wakers: Box<[Option<Waker>]>,
15    stream: Option<S>,
16}
17
18/// Stream for the [`copied_multi_stream`](crate::StreamUtils::copied_multi_stream) method.
19#[must_use = "streams do nothing unless polled"]
20#[derive(Clone)]
21pub struct CopiedMultiStream<S>
22where
23    S: Stream,
24{
25    state: Arc<Mutex<CopiedMultiStreamState<S>>>,
26    pos: usize,
27}
28
29/// Copies values from the inner stream into multiple new streams. Polls from inner stream one
30/// value and waits till all new streams have pulled a copied value.
31/// Note that not pulling from all new streams will result in a blocking state.
32///
33/// When the underlying stream terminates, all new streams which have allready pulled the last value will be [`Pending`].
34/// When all new streams have pulled the last value, all streams will terminate on next pull.
35///
36/// [`Pending`]: std::task::Poll#variant.Pending
37pub fn copied_multi_stream<S>(stream: S, i: usize) -> Vec<CopiedMultiStream<S>>
38where
39    S: Stream,
40{
41    let state = Arc::new(Mutex::new(CopiedMultiStreamState {
42        stream: Some(stream),
43        cache: (0..i).map(|_| None).collect(),
44        wakers: (0..i).map(|_| None).collect(),
45    }));
46    (0..i)
47        .map(|pos| CopiedMultiStream {
48            pos,
49            state: state.clone(),
50        })
51        .collect()
52}
53
54impl<S> Stream for CopiedMultiStream<S>
55where
56    S: Stream + Unpin,
57    S::Item: Clone,
58{
59    type Item = S::Item;
60
61    fn poll_next(
62        self: std::pin::Pin<&mut Self>,
63        cx: &mut std::task::Context<'_>,
64    ) -> std::task::Poll<Option<Self::Item>> {
65        let mut state = self.state.lock().unwrap();
66        if let Some(v) = state.cache[self.pos].take() {
67            std::task::Poll::Ready(Some(v))
68        } else if state.cache.iter().any(Option::is_some) {
69            state.wakers[self.pos] = Some(cx.waker().clone());
70            std::task::Poll::Pending
71        } else if let Some(ref mut stream) = state.stream {
72            match stream.poll_next_unpin(cx) {
73                std::task::Poll::Ready(Some(v)) => {
74                    state.cache.iter_mut().for_each(|c| *c = Some(v.clone()));
75                    state.wakers.iter_mut().for_each(|waker| {
76                        if let Some(waker) = waker.take() {
77                            waker.wake_by_ref()
78                        }
79                    });
80                    std::task::Poll::Ready(state.cache[self.pos].take())
81                }
82                std::task::Poll::Ready(None) => {
83                    state.stream = None;
84                    state.wakers.iter_mut().for_each(|waker| {
85                        if let Some(waker) = waker.take() {
86                            waker.wake_by_ref()
87                        }
88                    });
89                    std::task::Poll::Ready(None)
90                }
91                std::task::Poll::Pending => {
92                    state.wakers[self.pos] = Some(cx.waker().clone());
93                    std::task::Poll::Pending
94                }
95            }
96        } else {
97            std::task::Poll::Ready(None)
98        }
99    }
100}
101
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures_util::stream::{self, BoxStream};
    use ntest_timeout::timeout;

    use crate::StreamUtils;

    use super::*;

    /// Expected merged output: every item of `values`, each repeated `copies` times in order.
    fn repeated(values: impl IntoIterator<Item = usize>, copies: usize) -> Vec<usize> {
        values
            .into_iter()
            .flat_map(|value| std::iter::repeat(value).take(copies))
            .collect()
    }

    #[tokio::test]
    async fn test_stream() {
        let copies = 3;
        let source = stream::iter(0..3);
        let outputs = source.copied_multi_stream(copies);
        assert_eq!(outputs.len(), copies);

        let merged: Vec<usize> = stream::select_all(outputs).collect().await;
        assert_eq!(merged, repeated(0..3, copies));
    }

    #[tokio::test]
    async fn test_box_stream() {
        let copies = 3;
        let source: BoxStream<usize> = Box::pin(stream::iter(0..3));
        let outputs = source.copied_multi_stream(copies);
        assert_eq!(outputs.len(), copies);

        let merged: Vec<usize> = stream::select_all(outputs).collect().await;
        assert_eq!(merged, repeated(0..3, copies));
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let copies = 3;
        let source = Box::pin(stream::iter(0..0));
        let outputs = source.copied_multi_stream(copies);
        assert_eq!(outputs.len(), copies);

        let merged: Vec<usize> = stream::select_all(outputs).collect().await;
        assert!(merged.is_empty(), "expected no items, got {merged:?}");
    }

    #[tokio::test]
    async fn test_zero_streams() {
        let copies = 0;
        let source = stream::iter(0..3);
        let outputs = source.copied_multi_stream(copies);
        assert_eq!(outputs.len(), copies);

        let merged: Vec<usize> = stream::select_all(outputs).collect().await;
        assert!(merged.is_empty(), "expected no items, got {merged:?}");
    }

    #[tokio::test]
    async fn test_future_stream() {
        let copies = 3;
        let source = stream::unfold(0, |n| async move { (n <= 2).then(|| (n * 2, n + 1)) });
        let source = pin!(source);
        let outputs = source.copied_multi_stream(copies);
        assert_eq!(outputs.len(), copies);

        let merged: Vec<usize> = stream::select_all(outputs).collect().await;
        assert_eq!(merged, repeated([0, 2, 4], copies));
    }

    #[tokio::test]
    #[timeout(200)]
    async fn test_async_pull() {
        let copies = 5;
        let source = stream::iter(0..3);
        let handles: Vec<_> = source
            .copied_multi_stream(copies)
            .into_iter()
            .map(|copy| tokio::task::spawn(async move { copy.collect::<Vec<usize>>().await }))
            .collect();

        for handle in handles {
            let received = handle.await.unwrap();
            assert_eq!(received, vec![0, 1, 2]);
        }
    }
}