streamtools/
flatten_switch.rs

1use futures::task;
2use futures::Stream;
3use pin_project_lite::pin_project;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use crate::outer_waker::OuterWaker;
8
9pin_project! {
10    /// Stream for the [`flatten_switch`](crate::StreamTools::flatten_switch) method.
11    #[must_use = "streams do nothing unless polled"]
12    pub struct FlattenSwitch<St>
13    where
14        St: Stream,
15        St::Item: Stream
16    {
17        #[pin]
18        outer: St,
19
20        outer_waker: Arc<OuterWaker>,
21
22        #[pin]
23        inner: Option<<St as Stream>::Item>
24    }
25}
26
27impl<St> FlattenSwitch<St>
28where
29    St: Stream,
30    St::Item: Stream,
31{
32    pub(super) fn new(stream: St) -> Self {
33        Self {
34            outer: stream,
35            outer_waker: Arc::default(),
36            inner: None,
37        }
38    }
39}
40
41impl<St> Stream for FlattenSwitch<St>
42where
43    St: Stream,
44    St::Item: Stream,
45{
46    type Item = <St::Item as Stream>::Item;
47
48    fn poll_next(
49        self: std::pin::Pin<&mut Self>,
50        cx: &mut std::task::Context<'_>,
51    ) -> std::task::Poll<Option<Self::Item>> {
52        let mut this = self.project();
53
54        // We can avoid polling the outer stream if its waker has not been woken since
55        // we were last polled
56        let outer_ready = this.outer_waker.set_parent_waker(cx.waker().clone());
57        if outer_ready {
58            let waker = task::waker(Arc::clone(this.outer_waker));
59            let mut cx = Context::from_waker(&waker);
60            while let Poll::Ready(inner) = this.outer.as_mut().poll_next(&mut cx) {
61                match inner {
62                    Some(inner) => this.inner.set(Some(inner)),
63                    None => {
64                        // Terminate when the outer stream terminates
65                        return Poll::Ready(None);
66                    }
67                }
68            }
69        };
70
71        match this.inner.as_mut().as_pin_mut() {
72            Some(inner) => match inner.poll_next(cx) {
73                Poll::Ready(value) => match value {
74                    Some(value) => Poll::Ready(Some(value)),
75                    None => {
76                        // If the inner stream terminated, clear it so we don't poll it again.
77                        // This is important because some Streams don't support being polled again after
78                        // termination, e.g. stream::unfold.
79                        this.inner.set(None);
80
81                        // The inner stream can terminate but we don't terminate until the outer stream ends.
82                        Poll::Pending
83                    }
84                },
85
86                // Waiting on inner stream to emit next
87                Poll::Pending => Poll::Pending,
88            },
89
90            // We are still waiting for the first inner stream to be emitted by the outer
91            None => Poll::Pending,
92        }
93    }
94}
95
96impl<S> std::fmt::Debug for FlattenSwitch<S>
97where
98    S: Stream + std::fmt::Debug,
99    S::Item: Stream + std::fmt::Debug,
100{
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        f.debug_struct("FlattenSwitch")
103            .field("stream", &self.outer)
104            .field("inner", &self.inner)
105            .finish()
106    }
107}
108
109#[cfg(test)]
110mod tests {
111    use std::future;
112
113    use futures::{stream, FutureExt, StreamExt};
114    use parking_lot::Mutex;
115    use tokio_test::{assert_pending, assert_ready_eq};
116
117    use super::*;
118
119    pin_project! {
120        struct MockStream<S: Stream> {
121            #[pin]
122            inner: S,
123            polled: Arc<Mutex<bool>>
124        }
125    }
126
127    impl<S: Stream> Stream for MockStream<S> {
128        type Item = S::Item;
129
130        fn poll_next(
131            self: std::pin::Pin<&mut Self>,
132            cx: &mut Context<'_>,
133        ) -> Poll<Option<Self::Item>> {
134            let this = self.project();
135            let result = this.inner.poll_next(cx);
136
137            *this.polled.lock() = true;
138
139            result
140        }
141    }
142
143    #[tokio::test]
144    async fn test_flatten_switch() {
145        use futures::{channel::mpsc, SinkExt, StreamExt};
146        use tokio::sync::broadcast::{self, error::SendError};
147        use tokio_stream::wrappers::BroadcastStream;
148
149        let waker = futures::task::noop_waker_ref();
150        let mut cx = std::task::Context::from_waker(waker);
151
152        let (tx_inner1, rx_inner1) = broadcast::channel(32);
153        let (tx_inner2, rx_inner2) = broadcast::channel(32);
154        let (tx_inner3, rx_inner3) = broadcast::channel(32);
155        let (mut tx, rx) = mpsc::unbounded();
156
157        let outer_polled = Arc::new(Mutex::new(false));
158
159        let take_outer_polled = || -> bool {
160            let mut guard = outer_polled.lock();
161            std::mem::replace(&mut guard, false)
162        };
163
164        let assert_outer_polled = || assert!(take_outer_polled());
165        let assert_outer_not_polled = || assert!(!take_outer_polled());
166
167        let outer_stream = MockStream {
168            inner: rx,
169            polled: Arc::clone(&outer_polled),
170        };
171
172        let mut switch_stream = FlattenSwitch::new(outer_stream);
173
174        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
175        assert_outer_polled();
176
177        tx.send(
178            BroadcastStream::new(rx_inner1)
179                .map(|r: Result<_, _>| r.unwrap())
180                .boxed(),
181        )
182        .await
183        .unwrap();
184
185        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
186        assert_outer_polled();
187
188        tx_inner1.send(10).unwrap();
189        assert_eq!(
190            switch_stream.poll_next_unpin(&mut cx),
191            Poll::Ready(Some(10))
192        );
193        assert_outer_not_polled(); // Outer stream didn't change so shouldn't be polled
194        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
195        assert_outer_not_polled(); // Outer stream didn't change so shouldn't be polled
196
197        tx_inner1.send(20).unwrap();
198        assert_eq!(
199            switch_stream.poll_next_unpin(&mut cx),
200            Poll::Ready(Some(20))
201        );
202        assert_outer_not_polled();
203
204        tx.send(
205            BroadcastStream::new(rx_inner2)
206                .map(|r: Result<_, _>| r.unwrap())
207                .boxed(),
208        )
209        .await
210        .unwrap();
211
212        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
213        assert_outer_polled();
214
215        // We expect trying to send to the first inner stream to fail because
216        // rx_inner1 should have been dropped by SwitchStream once we started
217        // listening to rx_inner2.
218        matches!(tx_inner1.send(30), Err(SendError(_)));
219        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
220        assert_outer_not_polled(); // Outer stream didn't change so shouldn't be polled
221
222        // This should not cause the result stream to terminate.
223        // We only terminate on the outer stream terminating.
224        drop(tx_inner2);
225        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
226        assert_outer_not_polled(); // Outer stream didn't change so shouldn't be polled
227
228        tx.send(
229            BroadcastStream::new(rx_inner3)
230                .map(|r: Result<_, _>| r.unwrap())
231                .boxed(),
232        )
233        .await
234        .unwrap();
235
236        tx_inner3.send(100).unwrap();
237        assert_eq!(
238            switch_stream.poll_next_unpin(&mut cx),
239            Poll::Ready(Some(100))
240        );
241        assert_outer_polled();
242
243        tx_inner3.send(110).unwrap();
244        assert_eq!(
245            switch_stream.poll_next_unpin(&mut cx),
246            Poll::Ready(Some(110))
247        );
248        assert_outer_not_polled(); // Outer stream didn't change so shouldn't be polled
249
250        drop(tx);
251        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Ready(None));
252        assert_outer_polled();
253    }
254
255    #[tokio::test]
256    async fn test_inner_not_polled_twice_after_termination() {
257        let inner_polled = Arc::new(Mutex::new(false));
258
259        let take_inner_polled = || -> bool {
260            let mut guard = inner_polled.lock();
261            std::mem::replace(&mut guard, false)
262        };
263
264        let assert_inner_polled = || assert!(take_inner_polled());
265        let assert_inner_not_polled = || assert!(!take_inner_polled());
266
267        let first_inner = MockStream {
268            inner: stream::once(future::ready(1)),
269            polled: Arc::clone(&inner_polled),
270        };
271
272        // Outer stream consists of first_inner which emits one value and then completes, but never yields any further streams and is permanently
273        // pending for the 2nd stream.
274        let outer_stream =
275            stream::once(future::ready(first_inner)).chain(future::pending().into_stream());
276
277        let mut stream = FlattenSwitch::new(outer_stream);
278
279        let waker = futures::task::noop_waker_ref();
280        let mut cx = std::task::Context::from_waker(waker);
281
282        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));
283        assert_inner_polled();
284        assert_pending!(stream.poll_next_unpin(&mut cx));
285        assert_inner_polled();
286        assert_pending!(stream.poll_next_unpin(&mut cx));
287        assert_inner_not_polled();
288    }
289}