vortex_mask/
intersect_by_rank.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use crate::{AllOr, Mask};
5
6impl Mask {
7    /// Take the intersection of the `mask` with the set of true values in `self`.
8    ///
9    /// We are more interested in low selectivity `self` (as indices) with a boolean buffer mask,
10    /// so we don't optimize for other cases, yet.
11    ///
12    /// Note: we might be able to accelerate this function on x86 with BMI, see:
13    /// <https://www.microsoft.com/en-us/research/uploads/prod/2023/06/parquet-select-sigmod23.pdf>
14    ///
15    /// # Examples
16    ///
17    /// Keep the third and fifth set values from mask `m1`:
18    /// ```
19    /// use vortex_mask::Mask;
20    ///
21    /// let m1 = Mask::from_iter([true, false, false, true, true, true, false, true]);
22    /// let m2 = Mask::from_iter([false, false, true, false, true]);
23    /// assert_eq!(
24    ///     m1.intersect_by_rank(&m2),
25    ///     Mask::from_iter([false, false, false, false, true, false, false, true])
26    /// );
27    /// ```
28    pub fn intersect_by_rank(&self, mask: &Mask) -> Mask {
29        assert_eq!(self.true_count(), mask.len());
30
31        match (self.indices(), mask.indices()) {
32            (AllOr::All, _) => mask.clone(),
33            (_, AllOr::All) => self.clone(),
34            (AllOr::None, _) => Self::new_false(0),
35            (_, AllOr::None) => Self::new_false(self.len()),
36            (AllOr::Some(self_indices), AllOr::Some(mask_indices)) => {
37                Self::from_indices(
38                    self.len(),
39                    mask_indices
40                        .iter()
41                        .map(|idx|
42                            // This is verified as safe because we know that the indices are less than the
43                            // mask.len() and we known mask.len() <= self.len(),
44                            // implied by `self.true_count() == mask.len()`.
45                            unsafe{*self_indices.get_unchecked(*idx)})
46                        .collect(),
47                )
48            }
49        }
50    }
51}
52
#[cfg(test)]
// `test_intersect_by_rank_patterns` panics explicitly when the result has an unexpected shape.
#[allow(clippy::panic)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use crate::{AllOr, Mask};

    /// With every value of `self` set, the result equals the rank mask itself.
    #[test]
    fn mask_bitand_all_as_bit_and() {
        let this = Mask::from_buffer(BooleanBuffer::from_iter(vec![true, true, true, true, true]));
        let mask = Mask::from_buffer(BooleanBuffer::from_iter(vec![
            false, true, false, true, true,
        ]));
        assert_eq!(
            this.intersect_by_rank(&mask),
            Mask::from_indices(5, vec![1, 3, 4])
        );
    }

    /// An all-true rank mask keeps every set value of `self`.
    #[test]
    fn mask_bitand_all_true() {
        let this = Mask::from_buffer(BooleanBuffer::from_iter(vec![
            false, false, true, true, true,
        ]));
        let mask = Mask::from_buffer(BooleanBuffer::from_iter(vec![true, true, true]));
        assert_eq!(
            this.intersect_by_rank(&mask),
            Mask::from_indices(5, vec![2, 3, 4])
        );
    }

    /// A mixed rank mask keeps only the first and third set values of `self`.
    #[test]
    fn mask_bitand_true() {
        let this = Mask::from_buffer(BooleanBuffer::from_iter(vec![
            true, false, false, true, true,
        ]));
        let mask = Mask::from_buffer(BooleanBuffer::from_iter(vec![true, false, true]));
        assert_eq!(
            this.intersect_by_rank(&mask),
            Mask::from_indices(5, vec![0, 4])
        );
    }

    /// An all-false rank mask drops every set value of `self`.
    #[test]
    fn mask_bitand_false() {
        let this = Mask::from_buffer(BooleanBuffer::from_iter(vec![
            true, false, false, true, true,
        ]));
        let mask = Mask::from_buffer(BooleanBuffer::from_iter(vec![false, false, false]));
        assert_eq!(this.intersect_by_rank(&mask), Mask::from_indices(5, vec![]));
    }

    /// Combinations where at least one side is an all-true or all-false mask.
    #[rstest]
    #[case::all_true_with_all_true(
        Mask::new_true(5),
        Mask::new_true(5),
        vec![0, 1, 2, 3, 4]
    )]
    #[case::all_true_with_all_false(
        Mask::new_true(5),
        Mask::new_false(5),
        vec![]
    )]
    #[case::all_false_with_any(
        Mask::new_false(10),
        Mask::new_true(0),
        vec![]
    )]
    #[case::indices_with_all_true(
        Mask::from_indices(10, vec![2, 5, 7, 9]),
        Mask::new_true(4),
        vec![2, 5, 7, 9]
    )]
    #[case::indices_with_all_false(
        Mask::from_indices(10, vec![2, 5, 7, 9]),
        Mask::new_false(4),
        vec![]
    )]
    fn test_intersect_by_rank_special_cases(
        #[case] base_mask: Mask,
        #[case] rank_mask: Mask,
        #[case] expected_indices: Vec<usize>,
    ) {
        let result = base_mask.intersect_by_rank(&rank_mask);

        // The result must always span the base mask, whichever fast path produced it.
        assert_eq!(result.len(), base_mask.len());

        match result.indices() {
            AllOr::All => assert_eq!(expected_indices.len(), result.len()),
            AllOr::None => assert!(expected_indices.is_empty()),
            AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]),
        }
    }

    /// Mirrors the example in the `intersect_by_rank` documentation.
    #[test]
    fn test_intersect_by_rank_example() {
        let m1 = Mask::from_iter([true, false, false, true, true, true, false, true]);
        let m2 = Mask::from_iter([false, false, true, false, true]);
        let result = m1.intersect_by_rank(&m2);
        let expected = Mask::from_iter([false, false, false, false, true, false, false, true]);
        assert_eq!(result, expected);
    }

    /// The rank mask must have exactly one position per true value of the base mask.
    #[test]
    #[should_panic]
    fn test_intersect_by_rank_wrong_length() {
        let m1 = Mask::from_indices(10, vec![2, 5, 7]); // 3 true values
        let m2 = Mask::new_true(5); // length 5, which does not match the 3 true values of `m1`
        let _ = m1.intersect_by_rank(&m2);
    }

    /// Index-based base masks of length 10 combined with various boolean rank patterns.
    #[rstest]
    #[case::single_element(
        vec![3],
        vec![true],
        vec![3]
    )]
    #[case::single_element_masked(
        vec![3],
        vec![false],
        vec![]
    )]
    #[case::alternating(
        vec![0, 2, 4, 6, 8],
        vec![true, false, true, false, true],
        vec![0, 4, 8]
    )]
    #[case::consecutive(
        vec![5, 6, 7, 8, 9],
        vec![false, true, true, true, false],
        vec![6, 7, 8]
    )]
    fn test_intersect_by_rank_patterns(
        #[case] base_indices: Vec<usize>,
        #[case] rank_pattern: Vec<bool>,
        #[case] expected_indices: Vec<usize>,
    ) {
        let base = Mask::from_indices(10, base_indices);
        let rank = Mask::from_iter(rank_pattern);
        let result = base.intersect_by_rank(&rank);

        // The result must always span the base mask, whichever fast path produced it.
        assert_eq!(result.len(), base.len());

        match result.indices() {
            AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]),
            AllOr::None => assert!(expected_indices.is_empty()),
            _ => panic!("Unexpected result"),
        }
    }
}