// splinter_rs/splinter/intersection.rs

1use crate::{cow::CowSplinter, ops::Intersection, relational::Relation};
2
3use super::{Splinter, SplinterRef};
4
5// Splinter <> Splinter
6impl Intersection for Splinter {
7    type Output = Splinter;
8
9    fn intersection(&self, rhs: &Self) -> Self::Output {
10        let mut out = Splinter::default();
11        for (a, left, right) in self.partitions.inner_join(&rhs.partitions) {
12            for (b, left, right) in left.inner_join(&right) {
13                for (c, left, right) in left.inner_join(&right) {
14                    out.insert_block(a, b, c, left.intersection(right));
15                }
16            }
17        }
18        out
19    }
20}
21
22// Splinter <> SplinterRef
23impl<T: AsRef<[u8]>> Intersection<SplinterRef<T>> for Splinter {
24    type Output = Splinter;
25
26    fn intersection(&self, rhs: &SplinterRef<T>) -> Self::Output {
27        let mut out = Splinter::default();
28        let rhs = rhs.load_partitions();
29        for (a, left, right) in self.partitions.inner_join(&rhs) {
30            for (b, left, right) in left.inner_join(&right) {
31                for (c, left, right) in left.inner_join(&right) {
32                    out.insert_block(a, b, c, left.intersection(&right));
33                }
34            }
35        }
36        out
37    }
38}
39
40// SplinterRef <> Splinter
41impl<T: AsRef<[u8]>> Intersection<Splinter> for SplinterRef<T> {
42    type Output = Splinter;
43
44    fn intersection(&self, rhs: &Splinter) -> Self::Output {
45        rhs.intersection(self)
46    }
47}
48
49// SplinterRef <> SplinterRef
50impl<T1: AsRef<[u8]>, T2: AsRef<[u8]>> Intersection<SplinterRef<T2>> for SplinterRef<T1> {
51    type Output = Splinter;
52
53    fn intersection(&self, rhs: &SplinterRef<T2>) -> Self::Output {
54        let mut out = Splinter::default();
55        let rhs = rhs.load_partitions();
56        for (a, left, right) in self.load_partitions().inner_join(&rhs) {
57            for (b, left, right) in left.inner_join(&right) {
58                for (c, left, right) in left.inner_join(&right) {
59                    out.insert_block(a, b, c, left.intersection(&right));
60                }
61            }
62        }
63        out
64    }
65}
66
67// CowSplinter <> Splinter
68impl<T: AsRef<[u8]>> Intersection<Splinter> for CowSplinter<T> {
69    type Output = Splinter;
70
71    fn intersection(&self, rhs: &Splinter) -> Self::Output {
72        match self {
73            CowSplinter::Owned(splinter) => splinter.intersection(rhs),
74            CowSplinter::Ref(splinter_ref) => rhs.intersection(splinter_ref),
75        }
76    }
77}
78
79// CowSplinter <> SplinterRef
80impl<T1: AsRef<[u8]>, T2: AsRef<[u8]>> Intersection<SplinterRef<T2>> for CowSplinter<T1> {
81    type Output = Splinter;
82
83    fn intersection(&self, rhs: &SplinterRef<T2>) -> Self::Output {
84        match self {
85            CowSplinter::Owned(splinter) => splinter.intersection(rhs),
86            CowSplinter::Ref(splinter_ref) => splinter_ref.intersection(rhs),
87        }
88    }
89}
90
#[cfg(test)]
mod tests {
    use crate::{
        Splinter,
        ops::Intersection,
        testutil::{TestSplinter, check_combinations},
    };

    /// Routes each pairing of `TestSplinter` variants (owned `Splinter` or
    /// `SplinterRef`) to the matching concrete `Intersection` impl, so a single
    /// closure can cover all four owned/borrowed combinations.
    impl Intersection for TestSplinter {
        type Output = Splinter;

        fn intersection(&self, rhs: &Self) -> Self::Output {
            use TestSplinter::*;
            match (self, rhs) {
                (Splinter(lhs), Splinter(rhs)) => lhs.intersection(rhs),
                (Splinter(lhs), SplinterRef(rhs)) => lhs.intersection(rhs),
                (SplinterRef(lhs), Splinter(rhs)) => lhs.intersection(rhs),
                (SplinterRef(lhs), SplinterRef(rhs)) => lhs.intersection(rhs),
            }
        }
    }

    #[test]
    fn test_sanity() {
        // both operands empty
        check_combinations(0..0, 0..0, 0..0, |lhs, rhs| lhs.intersection(&rhs));
        // partially overlapping ranges
        check_combinations(0..100, 30..150, 30..100, |lhs, rhs| lhs.intersection(&rhs));

        // 9 values (multiples of 128 in 0..=1024) spread thinly over several
        // blocks; only the first block overlaps with the right-hand side
        let set = (0..=1024).step_by(128).collect::<Vec<_>>();
        check_combinations(set.clone(), vec![0, 128], vec![0, 128], |lhs, rhs| {
            lhs.intersection(&rhs)
        });
    }

    #[test]
    fn test_disjoint() {
        // non-empty operands with no values in common yield an empty result
        check_combinations(0..100, 1000..1100, 0..0, |lhs, rhs| lhs.intersection(&rhs));
    }

    #[test]
    fn test_multi_level() {
        // values that differ in higher-order bytes exercise every partition level,
        // not just the leaf blocks
        check_combinations(
            vec![1, 1 << 8, 1 << 16, 1 << 24],
            vec![1 << 8, 1 << 24, 2 << 24],
            vec![1 << 8, 1 << 24],
            |lhs, rhs| lhs.intersection(&rhs),
        );
    }
}