sorted_iter/
sorted_iterator.rs

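//! Set operations (union, multiway union, intersection, difference and
//! symmetric difference) on iterators whose items are sorted, plus the
//! `SortedByItem` marker trait and its impls for common iterator types.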
2use super::*;
3use core::cmp::Ordering::*;
4use core::cmp::{max, min, Ordering, Reverse};
5use core::iter::Peekable;
6use core::{iter, ops, option, result};
7use std::collections;
8use std::collections::BinaryHeap;
9use std::iter::FusedIterator;
10
/// Marker trait for iterators whose items are yielded in ascending order,
/// as defined by the item type's `Ord` implementation.
///
/// The trait has no methods: implementing it is a promise made by the
/// implementor, and nothing checks it at runtime. The merge-based set
/// operations in this module ([`Union`], [`Intersection`], [`Difference`],
/// [`SymmetricDifference`], [`MultiwayUnion`]) walk their inputs front to
/// back comparing heads, so they produce wrong results if an input that
/// claims this trait is not actually sorted.
pub trait SortedByItem {}
13
/// Lazy union of two sorted iterators.
///
/// Both inputs are walked in lockstep and the smaller head is yielded first.
/// When the two heads compare equal, `a`'s item is yielded and `b`'s is
/// discarded, so an item present in both inputs is produced once.
pub struct Union<I: Iterator, J: Iterator> {
    pub(crate) a: Peekable<I>,
    pub(crate) b: Peekable<J>,
}

impl<I: Iterator + Clone, J: Iterator + Clone> Clone for Union<I, J>
where
    I::Item: Clone,
    J::Item: Clone,
{
    fn clone(&self) -> Self {
        // A `Peekable` may be holding a buffered item, hence `Item: Clone`.
        let a = self.a.clone();
        let b = self.b.clone();
        Union { a, b }
    }
}

impl<K: Ord, I: Iterator<Item = K>, J: Iterator<Item = K>> Iterator for Union<I, J> {
    type Item = K;

    fn next(&mut self) -> Option<Self::Item> {
        let order = match (self.a.peek(), self.b.peek()) {
            (Some(x), Some(y)) => x.cmp(y),
            // At least one side is used up: drain whatever is left.
            _ => return self.a.next().or_else(|| self.b.next()),
        };
        match order {
            Less => self.a.next(),
            Greater => self.b.next(),
            Equal => {
                // Same item on both sides: drop b's copy, keep a's.
                self.b.next();
                self.a.next()
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let (a_lo, a_hi) = self.a.size_hint();
        let (b_lo, b_hi) = self.b.size_hint();
        // Lower bound: the inputs overlap completely.
        let lo = max(a_lo, b_lo);
        // Upper bound: the inputs are disjoint (unknown if the sum overflows).
        let hi = match (a_hi, b_hi) {
            (Some(x), Some(y)) => x.checked_add(y),
            _ => None,
        };
        (lo, hi)
    }
}
60
61// An iterator with the first item pulled out.
62pub(crate) struct Peeked<I: Iterator> {
63    h: Reverse<I::Item>,
64    t: I,
65}
66
67impl<I: Iterator> Peeked<I> {
68    fn new(mut i: I) -> Option<Peeked<I>> {
69        i.next().map(|x| Peeked {
70            h: Reverse(x),
71            t: i,
72        })
73    }
74
75    fn size_hint(&self) -> (usize, Option<usize>) {
76        let (lo, hi) = self.t.size_hint();
77        (lo + 1, hi.map(|hi| hi + 1))
78    }
79}
80
81// Delegate comparisons to the head element.
82impl<I: Iterator> PartialEq for Peeked<I>
83where
84    I::Item: PartialEq,
85{
86    fn eq(&self, that: &Self) -> bool {
87        self.h.eq(&that.h)
88    }
89}
90
91impl<I: Iterator> Eq for Peeked<I> where I::Item: Eq {}
92
93impl<I: Iterator> PartialOrd for Peeked<I>
94where
95    I::Item: PartialOrd,
96{
97    fn partial_cmp(&self, that: &Self) -> Option<Ordering> {
98        self.h.partial_cmp(&that.h)
99    }
100}
101
102impl<I: Iterator> Ord for Peeked<I>
103where
104    I::Item: Ord,
105{
106    fn cmp(&self, that: &Self) -> core::cmp::Ordering {
107        self.h.cmp(&that.h)
108    }
109}
110
111impl<I: Iterator + Clone> Clone for Peeked<I>
112where
113    I::Item: Clone,
114{
115    fn clone(&self) -> Self {
116        Self {
117            h: self.h.clone(),
118            t: self.t.clone(),
119        }
120    }
121}
122
123pub struct MultiwayUnion<I: Iterator> {
124    pub(crate) bh: BinaryHeap<Peeked<I>>,
125}
126
127impl<I: Iterator> MultiwayUnion<I>
128where
129    I::Item: Ord,
130{
131    pub(crate) fn from_iter<T: IntoIterator<Item = I>>(x: T) -> MultiwayUnion<I> {
132        MultiwayUnion {
133            bh: x.into_iter().filter_map(Peeked::new).collect(),
134        }
135    }
136}
137
138impl<I: Iterator + Clone> Clone for MultiwayUnion<I>
139where
140    I::Item: Clone,
141{
142    fn clone(&self) -> Self {
143        Self {
144            bh: self.bh.clone(),
145        }
146    }
147}
148
149impl<I: Iterator> Iterator for MultiwayUnion<I>
150where
151    I::Item: Ord,
152{
153    type Item = I::Item;
154
155    fn next(&mut self) -> Option<Self::Item> {
156        // Extract the current minimum element.
157        self.bh.pop().map(
158            |Peeked {
159                 h: Reverse(item),
160                 t: top_tail,
161             }| {
162                // Advance the iterator and re-insert it into the heap.
163                Peeked::new(top_tail).map(|i| self.bh.push(i));
164                // Remove equivalent elements and advance corresponding iterators.
165                while self.bh.peek().filter(|x| x.h.0 == item).is_some() {
166                    let tail = self.bh.pop().unwrap().t;
167                    Peeked::new(tail).map(|i| self.bh.push(i));
168                }
169                item
170            },
171        )
172    }
173
174    fn size_hint(&self) -> (usize, Option<usize>) {
175        self.bh.iter().fold((0, Some(0)), |(lo, hi), it| {
176            let (ilo, ihi) = it.size_hint();
177            (
178                max(lo, ilo),
179                hi.and_then(|hi| ihi.and_then(|ihi| hi.checked_add(ihi))),
180            )
181        })
182    }
183}
184
/// Lazy intersection of two sorted iterators.
///
/// Only `b` needs to be peekable: each item of `a` is compared against the
/// current head of `b`, and `b` is advanced past every smaller item.
pub struct Intersection<I: Iterator, J: Iterator> {
    pub(crate) a: I,
    pub(crate) b: Peekable<J>,
}

impl<I: Iterator + Clone, J: Iterator + Clone> Clone for Intersection<I, J>
where
    I::Item: Clone,
    J::Item: Clone,
{
    fn clone(&self) -> Self {
        Self {
            a: self.a.clone(),
            b: self.b.clone(),
        }
    }
}

impl<K: Ord, I: Iterator<Item = K>, J: Iterator<Item = K>> Iterator for Intersection<I, J> {
    type Item = K;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Once `b` is exhausted nothing further can be in the intersection.
            // Stopping here, instead of draining `a`, keeps iteration finite
            // when `a` is unbounded (e.g. `(0..).intersection(finite)`).
            if self.b.peek().is_none() {
                return None;
            }
            let a = self.a.next()?;
            while let Some(b) = self.b.peek() {
                match a.cmp(b) {
                    // `a` is smaller than everything left in `b`: discard it.
                    Less => break,
                    Equal => {
                        self.b.next();
                        return Some(a);
                    }
                    // `b`'s head is smaller than `a`, so it cannot match.
                    Greater => {
                        self.b.next();
                    }
                }
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let (_, amax) = self.a.size_hint();
        let (_, bmax) = self.b.size_hint();
        // Lower bound: the inputs may be disjoint.
        // Upper bound: every yielded item consumes one item from each side,
        // so whichever side is bounded bounds the result.
        let rmax = match (amax, bmax) {
            (Some(amax), Some(bmax)) => Some(min(amax, bmax)),
            (Some(known), None) | (None, Some(known)) => Some(known),
            (None, None) => None,
        };
        (0, rmax)
    }
}
232
/// Lazy difference `a - b` of two sorted iterators: the items of `a` that do
/// not occur in `b`.
pub struct Difference<I: Iterator, J: Iterator> {
    pub(crate) a: I,
    pub(crate) b: Peekable<J>,
}

impl<I: Iterator + Clone, J: Iterator + Clone> Clone for Difference<I, J>
where
    I::Item: Clone,
    J::Item: Clone,
{
    fn clone(&self) -> Self {
        let a = self.a.clone();
        let b = self.b.clone();
        Difference { a, b }
    }
}

impl<K: Ord, I: Iterator<Item = K>, J: Iterator<Item = K>> Iterator for Difference<I, J> {
    type Item = K;

    fn next(&mut self) -> Option<Self::Item> {
        'candidates: loop {
            let candidate = self.a.next()?;
            // Advance `b` past everything that sorts before `candidate`.
            loop {
                let order = match self.b.peek() {
                    Some(head) => candidate.cmp(head),
                    // `b` is used up: nothing left to subtract.
                    None => return Some(candidate),
                };
                match order {
                    Less => return Some(candidate),
                    Equal => {
                        // Present in `b`: drop this candidate.
                        self.b.next();
                        continue 'candidates;
                    }
                    Greater => {
                        self.b.next();
                    }
                }
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let (a_lo, a_hi) = self.a.size_hint();
        let (_, b_hi) = self.b.size_hint();
        // At worst every item of `b` removes one item of `a`; at best none do.
        let lo = match b_hi {
            Some(b_hi) => a_lo.saturating_sub(b_hi),
            None => 0,
        };
        (lo, a_hi)
    }
}
281
/// Lazy symmetric difference of two sorted iterators: the items that occur
/// in exactly one of the two inputs.
pub struct SymmetricDifference<I: Iterator, J: Iterator> {
    pub(crate) a: Peekable<I>,
    pub(crate) b: Peekable<J>,
}

impl<I: Iterator + Clone, J: Iterator + Clone> Clone for SymmetricDifference<I, J>
where
    I::Item: Clone,
    J::Item: Clone,
{
    fn clone(&self) -> Self {
        let a = self.a.clone();
        let b = self.b.clone();
        SymmetricDifference { a, b }
    }
}

impl<K: Ord, I: Iterator<Item = K>, J: Iterator<Item = K>> Iterator for SymmetricDifference<I, J> {
    type Item = K;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let order = match (self.a.peek(), self.b.peek()) {
                (Some(x), Some(y)) => x.cmp(y),
                _ => break,
            };
            match order {
                Less => return self.a.next(),
                Greater => return self.b.next(),
                Equal => {
                    // Present on both sides: drop it from each.
                    self.b.next();
                    self.a.next();
                }
            }
        }
        // One side is exhausted; everything left on the other is unmatched.
        self.a.next().or_else(|| self.b.next())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let (a_lo, a_hi) = self.a.size_hint();
        let (b_lo, b_hi) = self.b.size_hint();
        // Lower bound: the shorter input cancels out entirely against the longer.
        let lo = if let Some(a_hi) = a_hi.filter(|&a_hi| b_lo >= a_hi) {
            b_lo - a_hi
        } else if let Some(b_hi) = b_hi.filter(|&b_hi| a_lo >= b_hi) {
            a_lo - b_hi
        } else {
            0
        };
        // Upper bound: the inputs are disjoint (unknown if the sum overflows).
        let hi = match (a_hi, b_hi) {
            (Some(x), Some(y)) => x.checked_add(y),
            _ => None,
        };
        (lo, hi)
    }
}
331
/// Adapter that pairs every item of the inner iterator with `()`.
#[derive(Clone, Debug)]
pub struct Pairs<I: Iterator> {
    pub(crate) i: I,
}

impl<I: Iterator> Iterator for Pairs<I> {
    type Item = (I::Item, ());

    fn next(&mut self) -> Option<Self::Item> {
        let key = self.i.next()?;
        Some((key, ()))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exactly one pair per inner item.
        Iterator::size_hint(&self.i)
    }
}
348
/// Transparent wrapper used to declare, without checking, that the wrapped
/// iterator yields its items in sorted order.
///
/// `next`, `next_back` and `size_hint` are forwarded unchanged to the inner
/// iterator; exact-size and fused guarantees carry over when the inner
/// iterator provides them.
#[derive(Clone, Debug)]
pub struct AssumeSortedByItem<I: Iterator> {
    pub(crate) i: I,
}

impl<I: Iterator> Iterator for AssumeSortedByItem<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        Iterator::next(&mut self.i)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        Iterator::size_hint(&self.i)
    }
}

impl<I: ExactSizeIterator> ExactSizeIterator for AssumeSortedByItem<I> {}

impl<I: FusedIterator> FusedIterator for AssumeSortedByItem<I> {}

impl<I: DoubleEndedIterator> DoubleEndedIterator for AssumeSortedByItem<I> {
    fn next_back(&mut self) -> Option<Self::Item> {
        DoubleEndedIterator::next_back(&mut self.i)
    }
}
375
376// mark common std traits
377impl<I> SortedByItem for iter::Empty<I> {}
378impl<I> SortedByItem for iter::Once<I> {}
379impl<'a, T> SortedByItem for option::Iter<'a, T> {}
380impl<'a, T> SortedByItem for result::Iter<'a, T> {}
381impl<T> SortedByItem for option::IntoIter<T> {}
382impl<T> SortedByItem for result::IntoIter<T> {}
383
384impl<I: SortedByItem> SortedByItem for iter::Take<I> {}
385impl<I: SortedByItem> SortedByItem for iter::Skip<I> {}
386impl<I: SortedByItem> SortedByItem for iter::StepBy<I> {}
387impl<I: SortedByItem> SortedByItem for iter::Cloned<I> {}
388impl<I: SortedByItem> SortedByItem for iter::Copied<I> {}
389impl<I: SortedByItem> SortedByItem for iter::Fuse<I> {}
390impl<I: SortedByItem, F> SortedByItem for iter::Inspect<I, F> {}
391impl<I: SortedByItem, P> SortedByItem for iter::TakeWhile<I, P> {}
392impl<I: SortedByItem, P> SortedByItem for iter::SkipWhile<I, P> {}
393impl<I: SortedByItem, P> SortedByItem for iter::Filter<I, P> {}
394impl<I: SortedByItem + Iterator> SortedByItem for iter::Peekable<I> {}
395
396impl<T> SortedByItem for collections::btree_set::IntoIter<T> {}
397impl<'a, T> SortedByItem for collections::btree_set::Iter<'a, T> {}
398impl<'a, T> SortedByItem for collections::btree_set::Intersection<'a, T> {}
399impl<'a, T> SortedByItem for collections::btree_set::Union<'a, T> {}
400impl<'a, T> SortedByItem for collections::btree_set::Difference<'a, T> {}
401impl<'a, T> SortedByItem for collections::btree_set::SymmetricDifference<'a, T> {}
402impl<'a, T> SortedByItem for collections::btree_set::Range<'a, T> {}
403
404impl<'a, K, V> SortedByItem for collections::btree_map::Keys<'a, K, V> {}
405
406impl<T> SortedByItem for ops::Range<T> {}
407impl<T> SortedByItem for ops::RangeInclusive<T> {}
408impl<T> SortedByItem for ops::RangeFrom<T> {}
409
410impl<I: Iterator> SortedByItem for Keys<I> {}
411impl<I: Iterator> SortedByItem for AssumeSortedByItem<I> {}
412impl<I: Iterator, J: Iterator> SortedByItem for Union<I, J> {}
413impl<I: Iterator, J: Iterator> SortedByItem for Intersection<I, J> {}
414impl<I: Iterator, J: Iterator> SortedByItem for Difference<I, J> {}
415impl<I: Iterator, J: Iterator> SortedByItem for SymmetricDifference<I, J> {}
416impl<I: Iterator> SortedByItem for MultiwayUnion<I> {}
417
418impl<I: SortedByItem> SortedByItem for Box<I> {}
419
420impl<I: OneOrLess, F> SortedByItem for iter::Map<I, F> {}
421impl<Iin, J, Iout, F> SortedByItem for iter::FlatMap<Iin, J, F>
422where
423    Iin: OneOrLess,
424    J: IntoIterator<IntoIter = Iout>,
425    Iout: SortedByItem,
426{
427}
428impl<Iin, J, Iout> SortedByItem for iter::Flatten<Iin>
429where
430    Iin: OneOrLess + Iterator<Item = J>,
431    J: IntoIterator<IntoIter = Iout>,
432    Iout: SortedByItem,
433{
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use core::fmt::Debug;
    use std::collections::BTreeMap;

    // Each set operation is checked against the matching `BTreeSet` method,
    // which serves as the reference implementation. `#[quickcheck]` is an
    // attribute brought into scope elsewhere in the crate (presumably from
    // `quickcheck_macros` — confirm); each such fn is a property that must
    // return `true`.
    type Element = i64;
    type Reference = collections::BTreeSet<Element>;

    /// Returns whether `expected == actual`; on mismatch, prints the operands
    /// and both results so a failing property is easy to diagnose.
    fn binary_op<E: Debug, R: Eq + Debug>(a: E, b: E, expected: R, actual: R) -> bool {
        if expected == actual {
            return true;
        }
        println!(
            "a:{:?} b:{:?} expected:{:?} actual:{:?}",
            a, b, expected, actual
        );
        false
    }

    #[quickcheck]
    fn intersection(a: Reference, b: Reference) -> bool {
        let expected: Reference = a.intersection(&b).cloned().collect();
        let (lhs, rhs) = (a.clone().into_iter(), b.clone().into_iter());
        let actual: Reference = lhs.intersection(rhs).collect();
        binary_op(a, b, expected, actual)
    }

    #[quickcheck]
    fn union(a: Reference, b: Reference) -> bool {
        let expected: Reference = a.union(&b).cloned().collect();
        let (lhs, rhs) = (a.clone().into_iter(), b.clone().into_iter());
        let actual: Reference = lhs.union(rhs).collect();
        binary_op(a, b, expected, actual)
    }

    #[quickcheck]
    fn multi_union(inputs: Vec<Reference>) -> bool {
        let expected: Reference = inputs.iter().flatten().copied().collect();
        let actual = MultiwayUnion::from_iter(inputs.iter().map(|set| set.iter()));
        if actual.clone().eq(expected.iter()) {
            return true;
        }
        let actual: Reference = actual.copied().collect();
        println!("in:{:?} expected:{:?} out:{:?}", inputs, expected, actual);
        false
    }

    #[quickcheck]
    fn difference(a: Reference, b: Reference) -> bool {
        let expected: Reference = a.difference(&b).cloned().collect();
        let (lhs, rhs) = (a.clone().into_iter(), b.clone().into_iter());
        let actual: Reference = lhs.difference(rhs).collect();
        binary_op(a, b, expected, actual)
    }

    #[quickcheck]
    fn symmetric_difference(a: Reference, b: Reference) -> bool {
        let expected: Reference = a.symmetric_difference(&b).cloned().collect();
        let (lhs, rhs) = (a.clone().into_iter(), b.clone().into_iter());
        let actual: Reference = lhs.symmetric_difference(rhs).collect();
        binary_op(a, b, expected, actual)
    }

    /// Returns whether `(min, max)` is a valid size hint for a sequence of
    /// exactly `expected` items (and is itself consistent); prints the
    /// offending input otherwise.
    fn check_size_hint<E: Debug>(
        input: E,
        expected: usize,
        (min, max): (usize, Option<usize>),
    ) -> bool {
        let upper_ok = match max {
            Some(max) => expected <= max && min <= max,
            None => true,
        };
        let ok = min <= expected && upper_ok;
        if !ok {
            println!(
                "input:{:?} expected:{:?} min:{:?} max:{:?}",
                input, expected, min, max
            );
        }
        ok
    }

    #[quickcheck]
    fn intersection_size_hint(a: Reference, b: Reference) -> bool {
        let hint = a.iter().intersection(b.iter()).size_hint();
        let expected = a.intersection(&b).count();
        check_size_hint((a, b), expected, hint)
    }

    #[quickcheck]
    fn union_size_hint(a: Reference, b: Reference) -> bool {
        let hint = a.iter().union(b.iter()).size_hint();
        let expected = a.union(&b).count();
        check_size_hint((a, b), expected, hint)
    }

    #[quickcheck]
    fn multi_union_size_hint(inputs: Vec<Reference>) -> bool {
        let hint = MultiwayUnion::from_iter(inputs.iter().map(|set| set.iter())).size_hint();
        let expected: Reference = inputs.iter().flatten().copied().collect();
        check_size_hint(inputs, expected.len(), hint)
    }

    #[quickcheck]
    fn difference_size_hint(a: Reference, b: Reference) -> bool {
        let hint = a.iter().difference(b.iter()).size_hint();
        let expected = a.difference(&b).count();
        check_size_hint((a, b), expected, hint)
    }

    #[quickcheck]
    fn symmetric_difference_size_hint(a: Reference, b: Reference) -> bool {
        let hint = a.iter().symmetric_difference(b.iter()).size_hint();
        let expected = a.symmetric_difference(&b).count();
        check_size_hint((a, b), expected, hint)
    }

    /// A sorted iterator of owned `i64`s with its concrete type hidden.
    fn s() -> impl Iterator<Item = i64> + SortedByItem {
        0i64..10
    }

    /// A sorted iterator of borrowed `i64`s with its concrete type hidden.
    fn r<'a>() -> impl Iterator<Item = &'a i64> + SortedByItem {
        iter::empty()
    }

    /// Compiles only when `I` implements `SortedByItem`; does nothing at runtime.
    fn is_s<K, I>(_v: I)
    where
        I: Iterator<Item = K> + SortedByItem,
    {
    }

    /// Compile-time check that the expected iterator types carry `SortedByItem`.
    #[test]
    fn instances() {
        // trivially sorted sources
        is_s(iter::empty::<i64>());
        is_s(iter::once(0u64));
        // ranges
        is_s(0i64..10);
        is_s(0i64..=10);
        is_s(0i64..);
        // wrappers
        is_s(Box::new(0i64..10));
        // order-preserving adapters
        is_s(s().fuse());
        is_s(r().cloned());
        is_s(r().copied());
        is_s(r().peekable());
        is_s(s().inspect(|_| {}));
        // adapters that only drop items
        is_s(s().step_by(2));
        is_s(s().take(1));
        is_s(s().take_while(|_| true));
        is_s(s().skip(1));
        is_s(s().skip_while(|_| true));
        is_s(s().filter(|_| true));
        // set operations
        is_s(s().union(s()));
        is_s(s().intersection(s()));
        is_s(s().difference(s()));
        is_s(s().symmetric_difference(s()));
        is_s(multiway_union(vec![s(), s(), s()]));
        is_s(multiway_union(iter::once(s())));
        // flat_map over a source yielding at most one item
        let a_btree = BTreeMap::<i64, f32>::new();
        is_s(
            Some(())
                .iter()
                .map(|_| &a_btree)
                .filter(|b| !b.is_empty())
                .flat_map(|m| m.keys()),
        );
        is_s(
            iter::once(Some(()))
                .flatten()
                .map(|_| &a_btree)
                .filter(|b| !b.is_empty())
                .flat_map(|m| m.keys()),
        );
    }
}