stream_utils/
copied_multi_stream.rs1use std::{
2 sync::{Arc, Mutex},
3 task::Waker,
4};
5
6use futures_util::{Stream, StreamExt};
7
8#[derive(Clone)]
9struct CopiedMultiStreamState<S>
10where
11 S: Stream,
12{
13 cache: Box<[Option<S::Item>]>,
14 wakers: Box<[Option<Waker>]>,
15 stream: Option<S>,
16}
17
18#[must_use = "streams do nothing unless polled"]
20#[derive(Clone)]
21pub struct CopiedMultiStream<S>
22where
23 S: Stream,
24{
25 state: Arc<Mutex<CopiedMultiStreamState<S>>>,
26 pos: usize,
27}
28
29pub fn copied_multi_stream<S>(stream: S, i: usize) -> Vec<CopiedMultiStream<S>>
38where
39 S: Stream,
40{
41 let state = Arc::new(Mutex::new(CopiedMultiStreamState {
42 stream: Some(stream),
43 cache: (0..i).map(|_| None).collect(),
44 wakers: (0..i).map(|_| None).collect(),
45 }));
46 (0..i)
47 .map(|pos| CopiedMultiStream {
48 pos,
49 state: state.clone(),
50 })
51 .collect()
52}
53
54impl<S> Stream for CopiedMultiStream<S>
55where
56 S: Stream + Unpin,
57 S::Item: Clone,
58{
59 type Item = S::Item;
60
61 fn poll_next(
62 self: std::pin::Pin<&mut Self>,
63 cx: &mut std::task::Context<'_>,
64 ) -> std::task::Poll<Option<Self::Item>> {
65 let mut state = self.state.lock().unwrap();
66 if let Some(v) = state.cache[self.pos].take() {
67 std::task::Poll::Ready(Some(v))
68 } else if state.cache.iter().any(Option::is_some) {
69 state.wakers[self.pos] = Some(cx.waker().clone());
70 std::task::Poll::Pending
71 } else if let Some(ref mut stream) = state.stream {
72 match stream.poll_next_unpin(cx) {
73 std::task::Poll::Ready(Some(v)) => {
74 state.cache.iter_mut().for_each(|c| *c = Some(v.clone()));
75 state.wakers.iter_mut().for_each(|waker| {
76 if let Some(waker) = waker.take() {
77 waker.wake_by_ref()
78 }
79 });
80 std::task::Poll::Ready(state.cache[self.pos].take())
81 }
82 std::task::Poll::Ready(None) => {
83 state.stream = None;
84 state.wakers.iter_mut().for_each(|waker| {
85 if let Some(waker) = waker.take() {
86 waker.wake_by_ref()
87 }
88 });
89 std::task::Poll::Ready(None)
90 }
91 std::task::Poll::Pending => {
92 state.wakers[self.pos] = Some(cx.waker().clone());
93 std::task::Poll::Pending
94 }
95 }
96 } else {
97 std::task::Poll::Ready(None)
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures_util::stream::{self, BoxStream};
    use ntest_timeout::timeout;

    use crate::StreamUtils;

    use super::*;

    // Merges every copy with `select_all` and gathers all the items they yield.
    async fn drain_all<I>(copies: I) -> Vec<usize>
    where
        I: IntoIterator,
        I::Item: Stream<Item = usize> + Unpin,
    {
        stream::select_all(copies).collect::<Vec<_>>().await
    }

    #[tokio::test]
    async fn test_stream() {
        let copy_count = 3;
        let copies = stream::iter(0..3).copied_multi_stream(copy_count);

        assert_eq!(copies.len(), copy_count);
        assert_eq!(drain_all(copies).await, vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);
    }

    #[tokio::test]
    async fn test_box_stream() {
        let copy_count = 3;
        let boxed: BoxStream<usize> = Box::pin(stream::iter(0..3));
        let copies = boxed.copied_multi_stream(copy_count);

        assert_eq!(copies.len(), copy_count);
        assert_eq!(drain_all(copies).await, vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let copy_count = 3;
        let copies = Box::pin(stream::iter(0..0)).copied_multi_stream(copy_count);

        assert_eq!(copies.len(), copy_count);
        assert!(drain_all(copies).await.is_empty());
    }

    #[tokio::test]
    async fn test_zero_streams() {
        let copy_count = 0;
        let copies = stream::iter(0..3).copied_multi_stream(copy_count);

        assert_eq!(copies.len(), copy_count);
        assert!(drain_all(copies).await.is_empty());
    }

    #[tokio::test]
    async fn test_future_stream() {
        let copy_count = 3;
        // Each state `n` in 0..=2 yields `n * 2` and advances to `n + 1`: 0, 2, 4.
        let doubled = stream::unfold(0, |n| async move { (n <= 2).then(|| (n * 2, n + 1)) });
        let doubled = pin!(doubled);
        let copies = doubled.copied_multi_stream(copy_count);

        assert_eq!(copies.len(), copy_count);
        assert_eq!(drain_all(copies).await, vec![0, 0, 0, 2, 2, 2, 4, 4, 4]);
    }

    #[tokio::test]
    #[timeout(200)]
    async fn test_async_pull() {
        let copies = stream::iter(0..3).copied_multi_stream(5);

        // Spawn every consumer before awaiting any, so they genuinely run concurrently.
        let handles: Vec<_> = copies
            .into_iter()
            .map(|copy| tokio::task::spawn(async move { copy.collect::<Vec<usize>>().await }))
            .collect();
        for handle in handles {
            assert_eq!(handle.await.unwrap(), vec![0, 1, 2]);
        }
    }
}