1#![warn(rust_2018_idioms)]
2use std::cmp::Ordering;
3use std::task::Poll;
4
5use binary_heap_plus::BinaryHeap;
6use compare::Compare;
7use futures::future::{join_all, JoinAll};
8use futures::{ready, stream::StreamFuture, FutureExt, Stream, StreamExt};
9use pin_project_lite::pin_project;
10
11#[derive(Debug)]
13pub struct HeadTail<S>
14where
15 S: Stream,
16{
17 head: S::Item,
18 tail: S,
19}
20
pin_project! {
    /// Stream returned by [`kmerge`], [`kmerge_by`] and [`kmerge_by_key`] that
    /// merges several input streams into one, ordered by the comparator `C`.
    ///
    /// Each yielded item is the greatest (according to `C`) of the current
    /// heads of all inputs that have not ended. Nothing is yielded until every
    /// input has produced its first item or ended, and after each item the
    /// merge waits for the stream that item came from to produce its next one.
    #[must_use = "stream adaptors are lazy and do nothing unless consumed"]
    pub struct KWayMergeBy<S, C>
    where
        S: Stream,
        S: Unpin,
        C: Compare<HeadTail<S>>
    {
        // Future resolving to the first item (if any) of every input stream;
        // set to `None` once it has completed and its results are in `heap`.
        initial: Option<JoinAll<StreamFuture<S>>>,
        // The stream whose head was yielded most recently; its next item has
        // to be fetched before another item can be chosen.
        next: Option<S>,
        // Inputs keyed by their buffered head; `pop` yields the greatest per `C`.
        heap: BinaryHeap<HeadTail<S>, C>,
    }
}
41
42#[must_use = "stream adaptors are lazy and do nothing unless consumed"]
49pub type KWayMerge<I> = KWayMergeBy<I, OrdComparator>;
50
/// Comparator used by [`kmerge`]: orders merge entries by their heads'
/// [`Ord`] implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct OrdComparator;
52
53impl<S> Compare<HeadTail<S>> for OrdComparator
54where
55 S: Stream,
56 S::Item: Ord,
57{
58 fn compare(&self, l: &HeadTail<S>, r: &HeadTail<S>) -> std::cmp::Ordering {
59 l.head.cmp(&r.head)
60 }
61}
62
/// Comparator used by [`kmerge_by`]: orders merge entries by calling `f` on
/// their heads.
///
/// It is `Clone`/`Copy` whenever `f` is (e.g. non-capturing closures).
#[derive(Clone, Copy)]
pub struct FnComparator<F> {
    f: F,
}
66
67impl<S, F> Compare<HeadTail<S>> for FnComparator<F>
68where
69 S: Stream,
70 F: Fn(&S::Item, &S::Item) -> Ordering,
71{
72 fn compare(&self, l: &HeadTail<S>, r: &HeadTail<S>) -> std::cmp::Ordering {
73 (self.f)(&l.head, &r.head)
74 }
75}
76
/// Comparator used by [`kmerge_by_key`]: orders merge entries by the key that
/// `f` extracts from their heads.
///
/// It is `Clone`/`Copy` whenever `f` is (e.g. non-capturing closures).
#[derive(Clone, Copy)]
pub struct KeyComparator<F> {
    f: F,
}
80
81impl<S, F, O> Compare<HeadTail<S>> for KeyComparator<F>
82where
83 S: Stream,
84 F: Fn(&S::Item) -> O,
85 O: Ord,
86{
87 fn compare(&self, l: &HeadTail<S>, r: &HeadTail<S>) -> std::cmp::Ordering {
88 (self.f)(&l.head).cmp(&(self.f)(&r.head))
89 }
90}
91
92pub fn kmerge<S>(xs: impl IntoIterator<Item = S>) -> KWayMerge<S>
109where
110 S: Stream + Unpin,
111 S::Item: Ord,
112{
113 assert_stream::<S::Item, _>(kmerge_generic(xs, OrdComparator))
114}
115
116pub fn kmerge_by<S, F>(xs: impl IntoIterator<Item = S>, f: F) -> KWayMergeBy<S, FnComparator<F>>
132where
133 S: Stream + Unpin,
134 F: Fn(&S::Item, &S::Item) -> Ordering,
135{
136 kmerge_generic(xs, FnComparator { f })
137}
138
139pub fn kmerge_by_key<S, F, O>(
155 xs: impl IntoIterator<Item = S>,
156 f: F,
157) -> KWayMergeBy<S, KeyComparator<F>>
158where
159 S: Stream + Unpin,
160 F: Fn(&S::Item) -> O,
161 O: Ord,
162{
163 kmerge_generic(xs, KeyComparator { f })
164}
165
166fn kmerge_generic<S, C>(xs: impl IntoIterator<Item = S>, cmp: C) -> KWayMergeBy<S, C>
216where
217 S: Stream + Unpin,
218 C: Compare<HeadTail<S>>,
219{
220 let iter = xs.into_iter();
221 let (min_size, _) = iter.size_hint();
222 assert_stream::<S::Item, _>(KWayMergeBy {
223 initial: Some(join_all(iter.map(|x| x.into_future()))),
224 next: None,
225 heap: BinaryHeap::from_vec_cmp(Vec::with_capacity(min_size), cmp),
226 })
227}
228
229impl<S, C> Stream for KWayMergeBy<S, C>
230where
231 S: Stream + Unpin,
232 C: Compare<HeadTail<S>>,
233{
234 type Item = S::Item;
235
236 fn poll_next(
237 self: std::pin::Pin<&mut Self>,
238 cx: &mut std::task::Context<'_>,
239 ) -> std::task::Poll<Option<Self::Item>> {
240 let this = self.project();
241 if let Some(init_fut) = this.initial.as_mut() {
242 let xs = ready!(init_fut.poll_unpin(cx));
243 *this.initial = None;
244 this.heap.extend(
245 xs.into_iter().filter_map(|(head_option, tail)| {
246 head_option.map(|head| HeadTail { head, tail })
247 }),
248 );
249 }
250
251 if let Some(ref mut next_stream) = this.next {
252 if let Some(item) = ready!(next_stream.next().poll_unpin(cx)) {
253 this.heap.push(HeadTail {
254 head: item,
255 tail: this.next.take().unwrap(),
256 });
257 }
258 }
259
260 match this.heap.pop() {
261 None => Poll::Ready(None),
262 Some(HeadTail { head, tail }) => {
263 this.next.replace(tail);
264
265 Poll::Ready(Some(head))
266 }
267 }
268 }
269}
270
271fn assert_stream<T, S>(stream: S) -> S
274where
275 S: Stream<Item = T>,
276{
277 stream
278}
279
#[cfg(test)]
mod test {
    use std::pin::Pin;
    use std::time::Duration;

    use futures::stream;
    use futures::FutureExt;
    use futures::Stream;
    use futures::StreamExt;
    use tokio::sync::oneshot;
    use tokio::time;
    use tokio_stream::wrappers::IntervalStream;

    use super::*;

    // The merge pops the greatest head first, so descending inputs produce a
    // descending output.
    #[tokio::test]
    async fn sync() {
        let streams = vec![stream::iter(vec![5, 3, 1]), stream::iter(vec![4, 3, 2])];

        assert_eq!(
            kmerge(streams).collect::<Vec<usize>>().await,
            vec![5, 4, 3, 3, 2, 1],
        );
    }

    // A comparator equivalent to `Ord::cmp` must give the same result as `kmerge`.
    #[tokio::test]
    async fn by() {
        let streams = vec![stream::iter(vec![5, 3, 1]), stream::iter(vec![4, 3, 2])];
        let stream = kmerge_by(streams, |x: &usize, y: &usize| x.cmp(y));

        assert_eq!(stream.collect::<Vec<usize>>().await, vec![5, 4, 3, 3, 2, 1]);
    }

    // Only the numeric key takes part in the comparison; the labels show which
    // input each item came from.
    #[tokio::test]
    async fn by_key() {
        let streams = vec![
            stream::iter(vec![("a", 5), ("a", 3)]),
            stream::iter(vec![("b", 4), ("b", 4)]),
        ];
        let stream = kmerge_by_key(streams, |x: &(&'static str, usize)| x.1);

        assert_eq!(
            stream.collect::<Vec<_>>().await,
            vec![("a", 5), ("b", 4), ("b", 4), ("a", 3)]
        );
    }

    // Both interval streams are infinite and are pending between ticks; the
    // merge must keep making progress across those pending polls.
    #[tokio::test]
    async fn kmerge_async() {
        let streams = vec![
            IntervalStream::new(time::interval(Duration::from_nanos(1))),
            IntervalStream::new(time::interval(Duration::from_nanos(2))),
        ];

        let result = kmerge(streams).take(10).collect::<Vec<_>>().await;

        assert_eq!(result.len(), 10);
    }

    // Each stream's only item is sent by the *other* stream when it first
    // runs, so this completes only if the inputs' first items are awaited
    // concurrently rather than one stream at a time.
    #[tokio::test]
    async fn concurrent_initialization() {
        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();

        let s1 = async move {
            tx1.send(1).unwrap();
            rx2.await.unwrap()
        }
        .into_stream();
        let s2 = async move {
            tx2.send(2).unwrap();
            rx1.await.unwrap()
        }
        .into_stream();

        let streams: Vec<Pin<Box<dyn Stream<Item = i32>>>> = vec![Box::pin(s1), Box::pin(s2)];

        let result = kmerge(streams).collect::<Vec<_>>().await;
        assert_eq!(result, vec![2, 1]);
    }
}
361
// Pull README.md in as documentation when building doctests, so the README's
// code examples are compiled and run against the current API.
#[cfg(doctest)]
doc_comment::doctest!("../README.md");