streamtools/
flatten_switch.rs1use futures::task;
2use futures::Stream;
3use pin_project_lite::pin_project;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use crate::outer_waker::OuterWaker;
8
9pin_project! {
10 #[must_use = "streams do nothing unless polled"]
12 pub struct FlattenSwitch<St>
13 where
14 St: Stream,
15 St::Item: Stream
16 {
17 #[pin]
18 outer: St,
19
20 outer_waker: Arc<OuterWaker>,
21
22 #[pin]
23 inner: Option<<St as Stream>::Item>
24 }
25}
26
27impl<St> FlattenSwitch<St>
28where
29 St: Stream,
30 St::Item: Stream,
31{
32 pub(super) fn new(stream: St) -> Self {
33 Self {
34 outer: stream,
35 outer_waker: Arc::default(),
36 inner: None,
37 }
38 }
39}
40
41impl<St> Stream for FlattenSwitch<St>
42where
43 St: Stream,
44 St::Item: Stream,
45{
46 type Item = <St::Item as Stream>::Item;
47
48 fn poll_next(
49 self: std::pin::Pin<&mut Self>,
50 cx: &mut std::task::Context<'_>,
51 ) -> std::task::Poll<Option<Self::Item>> {
52 let mut this = self.project();
53
54 let outer_ready = this.outer_waker.set_parent_waker(cx.waker().clone());
57 if outer_ready {
58 let waker = task::waker(Arc::clone(this.outer_waker));
59 let mut cx = Context::from_waker(&waker);
60 while let Poll::Ready(inner) = this.outer.as_mut().poll_next(&mut cx) {
61 match inner {
62 Some(inner) => this.inner.set(Some(inner)),
63 None => {
64 return Poll::Ready(None);
66 }
67 }
68 }
69 };
70
71 match this.inner.as_mut().as_pin_mut() {
72 Some(inner) => match inner.poll_next(cx) {
73 Poll::Ready(value) => match value {
74 Some(value) => Poll::Ready(Some(value)),
75 None => {
76 this.inner.set(None);
80
81 Poll::Pending
83 }
84 },
85
86 Poll::Pending => Poll::Pending,
88 },
89
90 None => Poll::Pending,
92 }
93 }
94}
95
96impl<S> std::fmt::Debug for FlattenSwitch<S>
97where
98 S: Stream + std::fmt::Debug,
99 S::Item: Stream + std::fmt::Debug,
100{
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("FlattenSwitch")
103 .field("stream", &self.outer)
104 .field("inner", &self.inner)
105 .finish()
106 }
107}
108
#[cfg(test)]
mod tests {
    use std::future;

    use futures::{stream, FutureExt, StreamExt};
    use parking_lot::Mutex;
    use tokio_test::{assert_pending, assert_ready_eq};

    use super::*;

    pin_project! {
        /// Test wrapper that forwards to `inner` and records that it has been polled.
        struct MockStream<S: Stream> {
            #[pin]
            inner: S,
            polled: Arc<Mutex<bool>>
        }
    }

    impl<S: Stream> Stream for MockStream<S> {
        type Item = S::Item;

        fn poll_next(
            self: std::pin::Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            let this = self.project();
            let result = this.inner.poll_next(cx);

            // Mark the wrapped stream as polled so tests can assert on it.
            *this.polled.lock() = true;

            result
        }
    }

    #[tokio::test]
    async fn test_flatten_switch() {
        use futures::{channel::mpsc, SinkExt};
        use tokio::sync::broadcast::{self, error::SendError};
        use tokio_stream::wrappers::BroadcastStream;

        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        let (tx_inner1, rx_inner1) = broadcast::channel(32);
        let (tx_inner2, rx_inner2) = broadcast::channel(32);
        let (tx_inner3, rx_inner3) = broadcast::channel(32);
        let (mut tx, rx) = mpsc::unbounded();

        let outer_polled = Arc::new(Mutex::new(false));

        // Returns whether the outer stream was polled since the last call, and resets the flag.
        let take_outer_polled = || -> bool {
            let mut guard = outer_polled.lock();
            std::mem::replace(&mut *guard, false)
        };

        let assert_outer_polled = || assert!(take_outer_polled());
        let assert_outer_not_polled = || assert!(!take_outer_polled());

        let outer_stream = MockStream {
            inner: rx,
            polled: Arc::clone(&outer_polled),
        };

        let mut switch_stream = FlattenSwitch::new(outer_stream);

        // No inner stream yet: the outer stream is polled and nothing is ready.
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_polled();

        tx.send(
            BroadcastStream::new(rx_inner1)
                .map(|r: Result<_, _>| r.unwrap())
                .boxed(),
        )
        .await
        .unwrap();

        // The outer stream yields inner stream 1, which has nothing ready yet.
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_polled();

        tx_inner1.send(10).unwrap();
        assert_eq!(
            switch_stream.poll_next_unpin(&mut cx),
            Poll::Ready(Some(10))
        );
        // Only the inner stream had activity, so the outer stream is not polled.
        assert_outer_not_polled();

        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_not_polled();

        tx_inner1.send(20).unwrap();
        assert_eq!(
            switch_stream.poll_next_unpin(&mut cx),
            Poll::Ready(Some(20))
        );
        assert_outer_not_polled();

        tx.send(
            BroadcastStream::new(rx_inner2)
                .map(|r: Result<_, _>| r.unwrap())
                .boxed(),
        )
        .await
        .unwrap();

        // Switching to inner stream 2 drops inner stream 1.
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_polled();

        // Inner stream 1 (and with it its receiver) has been dropped, so sending fails.
        assert!(matches!(tx_inner1.send(30), Err(SendError(_))));
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_not_polled();

        // Ending the active inner stream does not end the switch stream.
        drop(tx_inner2);
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_not_polled();

        tx.send(
            BroadcastStream::new(rx_inner3)
                .map(|r: Result<_, _>| r.unwrap())
                .boxed(),
        )
        .await
        .unwrap();

        tx_inner3.send(100).unwrap();
        assert_eq!(
            switch_stream.poll_next_unpin(&mut cx),
            Poll::Ready(Some(100))
        );
        assert_outer_polled();

        tx_inner3.send(110).unwrap();
        assert_eq!(
            switch_stream.poll_next_unpin(&mut cx),
            Poll::Ready(Some(110))
        );
        assert_outer_not_polled();

        // Ending the outer stream ends the switch stream.
        drop(tx);
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Ready(None));
        assert_outer_polled();
    }

    #[tokio::test]
    async fn test_inner_not_polled_twice_after_termination() {
        let inner_polled = Arc::new(Mutex::new(false));

        // Returns whether the inner stream was polled since the last call, and resets the flag.
        let take_inner_polled = || -> bool {
            let mut guard = inner_polled.lock();
            std::mem::replace(&mut *guard, false)
        };

        let assert_inner_polled = || assert!(take_inner_polled());
        let assert_inner_not_polled = || assert!(!take_inner_polled());

        let first_inner = MockStream {
            inner: stream::once(future::ready(1)),
            polled: Arc::clone(&inner_polled),
        };

        // The outer stream yields a single inner stream and then stays pending forever,
        // so it never terminates the switch stream.
        let outer_stream =
            stream::once(future::ready(first_inner)).chain(future::pending().into_stream());

        let mut stream = FlattenSwitch::new(outer_stream);

        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));
        assert_inner_polled();
        // The inner stream ends here; the switch stream stays pending.
        assert_pending!(stream.poll_next_unpin(&mut cx));
        assert_inner_polled();
        // The ended inner stream must not be polled again.
        assert_pending!(stream.poll_next_unpin(&mut cx));
        assert_inner_not_polled();
    }
}