sea_streamer_fuse/
lib.rs

1use pin_project::pin_project;
2use sea_streamer_types::{export::futures::Stream, Message, StreamKey};
3use std::{
4    collections::{BTreeMap, VecDeque},
5    pin::Pin,
6    task::Poll,
7};
8
/// Per-stream-key FIFO buffers of messages that have been received but not yet yielded.
/// A `BTreeMap` is used so that iteration visits the stream keys in sorted order.
type Keys<M> = BTreeMap<StreamKey, VecDeque<M>>;
10
11/// Join multiple streams, but reorder messages by timestamp.
12/// Since a stream can potentially infinite and the keys in the stream cannot be known priori,
13/// the internal buffer can potentially grow infinite.
14///
15/// `align()` must be called manually to specify which streams to be aligned. Otherwise messages will be out
16/// of order until the first message of each key arrives. Imagine a really stuck stream sending the first message
17/// one day later, it will invalidate everything before it. But itself is the problem, not the others.
18///
19/// Messages within each stream key are assumed to be causal.
20///
21/// A typical use would be to join two streams from different sources, each with a different update frequency.
22/// Messages from the fast stream will be buffered, until a message from the slow stream arrives.
23///
24/// ```ignore
25/// fast | (1) (2) (3) (4) (5)
26/// slow |         (2)         (6)
27/// ```
28///
29/// In the example above, messages 1, 2 from fast will be buffered, until 2 from the slow stream arrives.
30/// Likewise, messages 3, 4, 5 will be buffered until 6 arrives.
31///
32/// If two messages have the same timestamp, the order will be determined by the alphabetic order of the stream keys.
33#[pin_project]
34pub struct StreamJoin<S, M, E>
35where
36    S: Stream<Item = Result<M, E>>,
37    M: Message,
38    E: std::error::Error,
39{
40    #[pin]
41    muxed: S,
42    keys: Keys<M>,
43    key_keys: Vec<StreamKey>,
44    ended: bool,
45    err: Option<E>,
46}
47
48impl<S, M, E> StreamJoin<S, M, E>
49where
50    S: Stream<Item = Result<M, E>>,
51    M: Message,
52    E: std::error::Error,
53{
54    /// Takes an already multiplexed stream. This can typically be achieved by `futures_concurrency::stream::Merge`.
55    pub fn muxed(muxed: S) -> Self {
56        Self {
57            muxed,
58            keys: Default::default(),
59            key_keys: Default::default(),
60            ended: false,
61            err: None,
62        }
63    }
64
65    /// Add a stream key that needs to be joined. You can call this multiple times.
66    pub fn align(&mut self, stream_key: StreamKey) {
67        self.keys.insert(stream_key.clone(), Default::default());
68        self.key_keys.push(stream_key);
69    }
70
71    fn next(keys: &mut Keys<M>) -> Option<M> {
72        let mut min_key = None;
73        let mut min_ts = None;
74        for (k, ms) in keys.iter() {
75            if let Some(m) = ms.front() {
76                let m_ts = m.timestamp();
77                if min_ts.is_none() || m_ts < min_ts.unwrap() {
78                    min_ts = Some(m_ts);
79                    min_key = Some(k.clone());
80                }
81            }
82        }
83        if let Some(min_key) = min_key {
84            Some(
85                keys.get_mut(&min_key)
86                    .unwrap()
87                    .pop_front()
88                    .expect("Checked above"),
89            )
90        } else {
91            // all streams ended
92            None
93        }
94    }
95
96    fn check(keys: &Keys<M>, key_keys: &[StreamKey]) -> bool {
97        // if none of the key streams are empty
98        for kk in key_keys {
99            if keys.get(kk).expect("Already inserted").is_empty() {
100                return false;
101            }
102        }
103        // if anyone got anything
104        keys.values().any(|ms| !ms.is_empty())
105    }
106}
107
108impl<S, M, E> Stream for StreamJoin<S, M, E>
109where
110    S: Stream<Item = Result<M, E>>,
111    M: Message,
112    E: std::error::Error,
113{
114    type Item = Result<M, E>;
115
116    fn poll_next(
117        self: Pin<&mut Self>,
118        cx: &mut std::task::Context<'_>,
119    ) -> Poll<Option<Self::Item>> {
120        let mut this = self.project();
121        while !*this.ended {
122            match this.muxed.as_mut().poll_next(cx) {
123                Poll::Ready(Some(Ok(mes))) => {
124                    let key = mes.stream_key();
125                    this.keys.entry(key).or_default().push_back(mes);
126                    if Self::check(&this.keys, &this.key_keys) {
127                        // if we can yield
128                        break;
129                    }
130                    // keep polling
131                }
132                Poll::Ready(Some(Err(err))) => {
133                    *this.ended = true;
134                    *this.err = Some(err);
135                    break;
136                }
137                Poll::Ready(None) => {
138                    *this.ended = true;
139                    break;
140                }
141                Poll::Pending => {
142                    // take a break
143                    break;
144                }
145            }
146        }
147        if *this.ended || Self::check(&this.keys, &this.key_keys) {
148            Poll::Ready(match Self::next(this.keys) {
149                Some(item) => Some(Ok(item)),
150                None => this.err.take().map(Err),
151            })
152        } else {
153            Poll::Pending
154        }
155    }
156}
157
#[cfg(test)]
mod test {
    use super::*;
    use sea_streamer_socket::{BackendErr, SeaMessage, SeaMessageStream};
    use sea_streamer_types::{
        export::futures::{self, TryStreamExt},
        MessageHeader, OwnedMessage, StreamErr, Timestamp,
    };

    /// A fallible sequence of messages, as fed into the multiplexed stream.
    type Seq = Vec<Result<OwnedMessage, BackendErr>>;

    // Never called: exists only to prove that `StreamJoin` can wrap a `SeaMessageStream`.
    #[allow(dead_code)]
    fn wrap<'a>(
        s: SeaMessageStream<'a>,
    ) -> StreamJoin<SeaMessageStream<'a>, SeaMessage<'a>, StreamErr<BackendErr>> {
        StreamJoin::muxed(s)
    }

    /// Shorthand for building a stream key from a literal.
    fn key(name: &str) -> StreamKey {
        StreamKey::new(name).unwrap()
    }

    /// Builds one message per item for `key`, using the item both as the sequence number
    /// and as the unix timestamp (in seconds).
    fn make_seq(key: StreamKey, items: &[u64]) -> Seq {
        let mut seq = Vec::with_capacity(items.len());
        for &i in items {
            let header = MessageHeader::new(
                key.clone(),
                Default::default(),
                i,
                Timestamp::from_unix_timestamp(i as i64).unwrap(),
            );
            seq.push(Ok(OwnedMessage::new(header, Vec::new())));
        }
        seq
    }

    /// Asserts that `messages` matches `expected` pairwise by (stream key name, sequence number).
    fn compare(messages: Vec<OwnedMessage>, expected: &[(&str, u64)]) {
        assert_eq!(messages.len(), expected.len());
        for (m, (name, seq)) in messages.iter().zip(expected) {
            assert_eq!(m.stream_key().name(), *name);
            assert_eq!(m.sequence(), *seq);
        }
    }

    /// Concatenates `parts` into a single stream, aligns `aligned` in the given order,
    /// and collects everything the join yields.
    async fn run_join(parts: Vec<Seq>, aligned: Vec<StreamKey>) -> Vec<OwnedMessage> {
        let stream = futures::stream::iter(parts.into_iter().flatten());
        let mut join = StreamJoin::muxed(stream);
        for stream_key in aligned {
            join.align(stream_key);
        }
        join.try_collect().await.unwrap()
    }

    #[tokio::test]
    async fn test_mux_streams_2() {
        let (a, b) = (key("a"), key("b"));
        let parts = vec![
            make_seq(a.clone(), &[1, 3, 5, 7, 9]),
            make_seq(b.clone(), &[2, 4, 6, 8, 10]),
        ];
        let messages = run_join(parts, vec![a, b]).await;
        compare(
            messages,
            &[
                ("a", 1),
                ("b", 2),
                ("a", 3),
                ("b", 4),
                ("a", 5),
                ("b", 6),
                ("a", 7),
                ("b", 8),
                ("a", 9),
                ("b", 10),
            ],
        );
    }

    #[tokio::test]
    async fn test_mux_streams_2_2() {
        let (a, b) = (key("a"), key("b"));
        let parts = vec![
            make_seq(a.clone(), &[1, 2, 5, 8, 9]),
            make_seq(b.clone(), &[3, 4, 6, 7, 10]),
        ];
        let messages = run_join(parts, vec![a, b]).await;
        compare(
            messages,
            &[
                ("a", 1),
                ("a", 2),
                ("b", 3),
                ("b", 4),
                ("a", 5),
                ("b", 6),
                ("b", 7),
                ("a", 8),
                ("a", 9),
                ("b", 10),
            ],
        );
    }

    #[tokio::test]
    async fn test_mux_streams_3() {
        let (a, b, c) = (key("a"), key("b"), key("c"));
        let parts = vec![
            make_seq(a.clone(), &[1, 3, 5, 7, 9]),
            make_seq(c.clone(), &[5]),
            make_seq(b.clone(), &[2, 4, 6, 8, 10]),
        ];
        let messages = run_join(parts, vec![a, b, c]).await;
        compare(
            messages,
            &[
                ("a", 1),
                ("b", 2),
                ("a", 3),
                ("b", 4),
                ("a", 5),
                ("c", 5),
                ("b", 6),
                ("a", 7),
                ("b", 8),
                ("a", 9),
                ("b", 10),
            ],
        );
    }

    #[tokio::test]
    async fn test_mux_streams_4() {
        let (a, b, c, d) = (key("a"), key("b"), key("c"), key("d"));
        let parts = vec![
            make_seq(a.clone(), &[1, 3]),
            make_seq(d.clone(), &[5]),
            make_seq(b.clone(), &[2, 4]),
            make_seq(c.clone(), &[3]),
        ];
        let messages = run_join(parts, vec![a, b, c, d]).await;
        compare(
            messages,
            &[("a", 1), ("b", 2), ("a", 3), ("c", 3), ("b", 4), ("d", 5)],
        );
    }
}
319}