vortex_mask/
intersect_by_rank.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use crate::AllOr;
5use crate::Mask;
6
7impl Mask {
8    /// Take the intersection of the `mask` with the set of true values in `self`.
9    ///
10    /// We are more interested in low selectivity `self` (as indices) with a boolean buffer mask,
11    /// so we don't optimize for other cases, yet.
12    ///
13    /// Note: we might be able to accelerate this function on x86 with BMI, see:
14    /// <https://www.microsoft.com/en-us/research/uploads/prod/2023/06/parquet-select-sigmod23.pdf>
15    ///
16    /// # Examples
17    ///
18    /// Keep the third and fifth set values from mask `m1`:
19    /// ```
20    /// use vortex_mask::Mask;
21    ///
22    /// let m1 = Mask::from_iter([true, false, false, true, true, true, false, true]);
23    /// let m2 = Mask::from_iter([false, false, true, false, true]);
24    /// assert_eq!(
25    ///     m1.intersect_by_rank(&m2),
26    ///     Mask::from_iter([false, false, false, false, true, false, false, true])
27    /// );
28    /// ```
29    pub fn intersect_by_rank(&self, mask: &Mask) -> Mask {
30        assert_eq!(self.true_count(), mask.len());
31
32        match (self.indices(), mask.indices()) {
33            (AllOr::All, _) => mask.clone(),
34            (_, AllOr::All) => self.clone(),
35            (AllOr::None, _) | (_, AllOr::None) => Self::new_false(self.len()),
36
37            (AllOr::Some(self_indices), AllOr::Some(mask_indices)) => {
38                Self::from_indices(
39                    self.len(),
40                    mask_indices
41                        .iter()
42                        .map(|idx|
43                            // This is verified as safe because we know that the indices are less than the
44                            // mask.len() and we known mask.len() <= self.len(),
45                            // implied by `self.true_count() == mask.len()`.
46                            unsafe{*self_indices.get_unchecked(*idx)})
47                        .collect(),
48                )
49            }
50        }
51    }
52}
53
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;

    use crate::AllOr;
    use crate::Mask;

    /// Builds a mask from a bit buffer holding exactly `bits`.
    fn buffer_mask(bits: Vec<bool>) -> Mask {
        Mask::from_buffer(BitBuffer::from_iter(bits))
    }

    /// When every bit of the base is set, the selector is reproduced position for position.
    #[test]
    fn mask_bitand_all_as_bit_and() {
        let base = buffer_mask(vec![true, true, true, true, true]);
        let selector = buffer_mask(vec![false, true, false, true, true]);
        let expected = Mask::from_indices(5, vec![1, 3, 4]);
        assert_eq!(base.intersect_by_rank(&selector), expected);
    }

    /// A selector with every bit set keeps all true positions of the base.
    #[test]
    fn mask_bitand_all_true() {
        let base = buffer_mask(vec![false, false, true, true, true]);
        let selector = buffer_mask(vec![true, true, true]);
        let expected = Mask::from_indices(5, vec![2, 3, 4]);
        assert_eq!(base.intersect_by_rank(&selector), expected);
    }

    /// A mixed selector keeps only the true positions whose rank is set.
    #[test]
    fn mask_bitand_true() {
        let base = buffer_mask(vec![true, false, false, true, true]);
        let selector = buffer_mask(vec![true, false, true]);
        let expected = Mask::from_indices(5, vec![0, 4]);
        assert_eq!(base.intersect_by_rank(&selector), expected);
    }

    /// A selector with no bit set drops every true position of the base.
    #[test]
    fn mask_bitand_false() {
        let base = buffer_mask(vec![true, false, false, true, true]);
        let selector = buffer_mask(vec![false, false, false]);
        let expected = Mask::from_indices(5, vec![]);
        assert_eq!(base.intersect_by_rank(&selector), expected);
    }

    /// An all-false base paired with an empty selector stays all-false at the base length.
    #[test]
    fn mask_intersect_by_rank_all_false() {
        let base = Mask::AllFalse(10);
        let selector = Mask::AllFalse(0);
        assert_eq!(base.intersect_by_rank(&selector), Mask::AllFalse(10));
    }

    /// Exercises the combinations where at least one side is uniformly true or false.
    #[rstest]
    #[case::all_true_with_all_true(
        Mask::new_true(5),
        Mask::new_true(5),
        vec![0, 1, 2, 3, 4]
    )]
    #[case::all_true_with_all_false(
        Mask::new_true(5),
        Mask::new_false(5),
        vec![]
    )]
    #[case::all_false_with_any(
        Mask::new_false(10),
        Mask::new_true(0),
        vec![]
    )]
    #[case::indices_with_all_true(
        Mask::from_indices(10, vec![2, 5, 7, 9]),
        Mask::new_true(4),
        vec![2, 5, 7, 9]
    )]
    #[case::indices_with_all_false(
        Mask::from_indices(10, vec![2, 5, 7, 9]),
        Mask::new_false(4),
        vec![]
    )]
    fn test_intersect_by_rank_special_cases(
        #[case] base_mask: Mask,
        #[case] rank_mask: Mask,
        #[case] expected_indices: Vec<usize>,
    ) {
        let selected = base_mask.intersect_by_rank(&rank_mask);
        let expected = expected_indices.as_slice();

        match selected.indices() {
            AllOr::Some(indices) => assert_eq!(indices, expected),
            AllOr::None => assert!(expected.is_empty()),
            // An all-true result carries no index list, so only the count can be compared.
            AllOr::All => assert_eq!(expected.len(), selected.len()),
        }
    }

    /// Mirrors the example given in the `intersect_by_rank` documentation.
    #[test]
    fn test_intersect_by_rank_example() {
        let m1 = Mask::from_iter([true, false, false, true, true, true, false, true]);
        let m2 = Mask::from_iter([false, false, true, false, true]);
        let expected = Mask::from_iter([false, false, false, false, true, false, false, true]);
        assert_eq!(m1.intersect_by_rank(&m2), expected);
    }

    /// The selector length must match the number of true values in the base.
    #[test]
    #[should_panic]
    fn test_intersect_by_rank_wrong_length() {
        // `m1` has 3 true values while `m2` has length 5, so the precondition is violated.
        let m1 = Mask::from_indices(10, vec![2, 5, 7]);
        let m2 = Mask::new_true(5);
        m1.intersect_by_rank(&m2);
    }

    /// Index-backed bases of length 10 combined with assorted selector patterns.
    #[rstest]
    #[case::single_element(
        vec![3],
        vec![true],
        vec![3]
    )]
    #[case::single_element_masked(
        vec![3],
        vec![false],
        vec![]
    )]
    #[case::alternating(
        vec![0, 2, 4, 6, 8],
        vec![true, false, true, false, true],
        vec![0, 4, 8]
    )]
    #[case::consecutive(
        vec![5, 6, 7, 8, 9],
        vec![false, true, true, true, false],
        vec![6, 7, 8]
    )]
    fn test_intersect_by_rank_patterns(
        #[case] base_indices: Vec<usize>,
        #[case] rank_pattern: Vec<bool>,
        #[case] expected_indices: Vec<usize>,
    ) {
        let selected =
            Mask::from_indices(10, base_indices).intersect_by_rank(&Mask::from_iter(rank_pattern));
        let expected = expected_indices.as_slice();

        match selected.indices() {
            AllOr::None => assert!(expected.is_empty()),
            AllOr::Some(indices) => assert_eq!(indices, expected),
            // None of the cases above may produce an all-true result.
            AllOr::All => panic!("Unexpected result"),
        }
    }
}