tokio_stream_ext/
combine_latest.rs

1use core::task::Context;
2use futures::StreamExt;
3use pin_project_lite::pin_project;
4use std::pin::Pin;
5use std::task::Poll;
6
7pin_project! {
8    pub struct CombineLatest<S, I>
9    {
10        #[pin]
11        streams: Vec<S>,
12        #[pin]
13        last_state: Vec<Option<I>>,
14        #[pin]
15        live_mode: bool
16    }
17}
18
19#[allow(dead_code)]
20pub fn combine_latest<S, I>(streams: Vec<S>) -> CombineLatest<S, I>
21where
22    S: tokio_stream::Stream<Item = I>,
23    I: Clone,
24{
25    CombineLatest {
26        last_state: vec![None; streams.len()],
27        streams,
28        live_mode: false,
29    }
30}
31
32impl<S, I> tokio_stream::Stream for CombineLatest<S, I>
33where
34    S: tokio_stream::Stream<Item = I> + std::marker::Unpin,
35    I: Clone,
36{
37    type Item = Vec<I>;
38
39    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
40        let mut me = self.project();
41
42        if me.streams.len() == 0 {
43            return Poll::Ready(None);
44        }
45
46        let mut at_least_one_updated = false;
47
48        for (idx, stream) in me.streams.iter_mut().enumerate() {
49            'stateCollectLoop: while let Poll::Ready(p) = stream.poll_next_unpin(cx) {
50                if let Some(state) = p {
51                    let mut l = me.last_state.clone();
52                    l[idx] = Some(state);
53                    me.last_state.set(l);
54
55                    at_least_one_updated = true;
56                    if *me.live_mode == false {
57                        let all_defined = me.last_state.iter().all(|s| s.is_some());
58                        me.live_mode.set(all_defined);
59                    }
60                } else {
61                    break 'stateCollectLoop;
62                }
63            }
64        }
65
66        if *me.live_mode == true && at_least_one_updated {
67            Poll::Ready(Some(
68                me.last_state
69                    .iter()
70                    .filter(|s| s.is_some())
71                    .map(|s| s.clone().unwrap())
72                    .collect(),
73            ))
74        } else {
75            Poll::Pending
76        }
77    }
78}