1use std::iter::Iterator;
2
3use crate::{merge_entry::MergeEntry, merge_tree::MergeTree};
4
5pub struct KWayMergeIterator<'a, T, I: Iterator<Item = T>, M: MergeTree<T>> {
6 heap: M,
7 list: &'a mut [I],
8}
9
10impl<'a, T: std::cmp::Ord, I: Iterator<Item = T>, M: MergeTree<T>> KWayMergeIterator<'a, T, I, M> {
11 pub fn new(list: &'a mut [I], mut heap: M) -> Self {
12 for (i, iterator) in list.iter_mut().enumerate() {
13 if let Some(first) = iterator.next() {
14 heap.push(MergeEntry {
15 item: first,
16 index: i,
17 });
18 }
19 }
20 Self { heap, list }
21 }
22}
23
24impl<T: std::cmp::Ord, I: Iterator<Item = T>, M: MergeTree<T>> Iterator
25 for KWayMergeIterator<'_, T, I, M>
26{
27 type Item = T;
28
29 fn next(&mut self) -> Option<Self::Item> {
30 let MergeEntry {
31 item: value,
32 index: list,
33 } = self.heap.pop()?;
34 if let Some(next) = self.list[list].next() {
35 self.heap.push(MergeEntry {
36 item: next,
37 index: list,
38 });
39 }
40 Some(value)
41 }
42}
43
#[cfg(test)]
mod test {
    use std::collections::BinaryHeap;

    use super::{KWayMergeIterator, MergeEntry};

    #[test]
    fn four_way_merge() {
        let mut list = vec![
            vec![1, 3, 5, 7, 9].into_iter(),
            vec![2, 4, 6, 8, 10].into_iter(),
            vec![11, 13, 15, 17, 19].into_iter(),
            vec![12, 14, 16, 18, 20].into_iter(),
        ];
        let heap = BinaryHeap::new();
        let k_way_merge = KWayMergeIterator::new(&mut list, heap);
        let result: Vec<_> = k_way_merge.collect();
        assert_eq!(
            result,
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
        );
    }

    #[test]
    fn three_way_merge_of_differently_sized_sequences() {
        let mut list = vec![
            vec![1, 3, 5, 6, 7, 12].into_iter(),
            vec![2, 4, 8, 11].into_iter(),
            vec![9, 10].into_iter(),
        ];
        let heap = BinaryHeap::new();
        let k_way_merge = KWayMergeIterator::new(&mut list, heap);
        let result: Vec<_> = k_way_merge.collect();
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
    }

    #[test]
    fn merge_of_empty_sequences() {
        let mut list: Vec<std::vec::IntoIter<u64>> = vec![vec![].into_iter(), vec![].into_iter()];
        let heap = BinaryHeap::new();
        let k_way_merge = KWayMergeIterator::new(&mut list, heap);
        let result: Vec<_> = k_way_merge.collect();
        assert_eq!(result, Vec::<u64>::new());
    }

    #[test]
    fn merge_of_no_sequences() {
        let mut list: Vec<std::vec::IntoIter<u64>> = Vec::new();
        let k_way_merge = KWayMergeIterator::new(&mut list, BinaryHeap::new());
        let result: Vec<_> = k_way_merge.collect();
        assert!(result.is_empty());
    }

    #[test]
    fn merge_keeps_duplicates_across_sequences() {
        let mut list = vec![vec![1, 2, 2, 5].into_iter(), vec![2, 3, 5].into_iter()];
        let k_way_merge = KWayMergeIterator::new(&mut list, BinaryHeap::new());
        let result: Vec<_> = k_way_merge.collect();
        assert_eq!(result, vec![1, 2, 2, 2, 3, 5, 5]);
    }

    #[test]
    fn test_merge_entry_ordering() {
        let entry1 = MergeEntry { item: 2, index: 0 };
        let entry2 = MergeEntry { item: 1, index: 1 };
        let entry3 = MergeEntry { item: 3, index: 1 };

        // A smaller `item` compares as greater, so `BinaryHeap` (a max-heap)
        // pops the smallest item first.
        assert!(entry1 < entry2);
        assert!(entry1 > entry3);
        assert!(entry2 > entry3);

        let mut heap = BinaryHeap::new();
        heap.push(entry1);
        heap.push(entry2);
        heap.push(entry3);

        assert_eq!(heap.pop().unwrap().item, 1);
        assert_eq!(heap.pop().unwrap().item, 2);
        assert_eq!(heap.pop().unwrap().item, 3);
        assert!(heap.pop().is_none());
    }
}