Skip to main content

vortex_mask/
intersect_by_rank.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::iter::Chain;
5use std::iter::Once;
6use std::iter::once;
7use std::sync::Arc;
8
9use vortex_buffer::BitBuffer;
10use vortex_buffer::BitChunkIterator;
11use vortex_buffer::BufferMut;
12use vortex_error::VortexExpect;
13
14use crate::Mask;
15use crate::MaskValues;
16
/// Strategy for scattering the low bits of a word into the set-bit positions of a mask
/// (the operation the x86 `PDEP` instruction performs).
trait DepositBits {
    /// Whether the implementation benefits from short-circuiting on `rank_bits == 0`
    /// and `self_chunk == u64::MAX`. The portable path loops `popcount(mask)` times,
    /// so an all-ones mask is genuinely expensive; BMI2 PDEP is constant-time and
    /// the branches just add mispredict cost.
    const PREFER_BRANCHES: bool;

    /// Places bit `i` of `source` at the position of the `i`-th set bit of `mask`
    /// (counting from the least significant bit) and returns the resulting word.
    ///
    /// `mask_count` is the number of set bits in `mask`, precomputed by the caller.
    /// Implementations may use it as a heuristic input or ignore it.
    fn deposit_bits(source: u64, mask: u64, mask_count: usize) -> u64;
}
26
/// Strategy for locating the set bit of a given rank inside a 64-bit word.
trait SelectBit {
    /// Position (`0..=63`) of the `rank`-th set bit in `word`, where rank 0 is the least
    /// significant set bit. Caller ensures `rank < word.count_ones()`.
    fn select_bit_position(word: u64, rank: usize) -> usize;
}
32
33struct Portable;
34
35impl DepositBits for Portable {
36    const PREFER_BRANCHES: bool = true;
37
38    #[inline]
39    fn deposit_bits(source: u64, mask: u64, mask_count: usize) -> u64 {
40        if mask_count >= 16 && source.count_ones() as usize * 8 < mask_count {
41            return deposit_sparse_source(source, mask);
42        }
43
44        deposit_by_mask(source, mask)
45    }
46}
47
impl SelectBit for Portable {
    /// Forwards to the byte-wise software select, [`select_bit_position_portable`].
    #[inline]
    fn select_bit_position(word: u64, rank: usize) -> usize {
        select_bit_position_portable(word, rank)
    }
}
54
/// Software `PDEP`: walks the set bits of `mask` from least to most significant and copies
/// successive low bits of `source` into those positions.
///
/// Runs one iteration per set bit of `mask`; bits of `source` beyond `mask.count_ones()`
/// are ignored.
#[inline]
fn deposit_by_mask(mut source: u64, mut mask: u64) -> u64 {
    let mut deposited = 0u64;
    while mask != 0 {
        // Isolate the lowest remaining target position.
        let target = 1u64 << mask.trailing_zeros();
        if source & 1 == 1 {
            deposited |= target;
        }
        // Consume one source bit and one mask position.
        mask ^= target;
        source >>= 1;
    }
    deposited
}
68
69#[inline]
70fn deposit_sparse_source(mut source: u64, mask: u64) -> u64 {
71    let mut result = 0u64;
72    while source != 0 {
73        result |= select_set_bit(mask, source.trailing_zeros() as usize);
74        source &= source - 1;
75    }
76    result
77}
78
/// Returns a word with exactly one bit set: the `rank`-th set bit of `word`
/// (rank 0 is the least significant set bit). Requires `rank < word.count_ones()`.
#[inline]
fn select_set_bit(word: u64, rank: usize) -> u64 {
    1u64 << select_bit_position_portable(word, rank)
}
83
/// Byte-wise software select: returns the position of the `rank`-th set bit of `word`.
///
/// Skips whole bytes by popcount until the byte containing the target bit is found, then
/// strips `rank` low set bits from that byte. Requires `rank < word.count_ones()`; this is
/// checked in debug builds, and release builds return 0 when violated.
#[inline]
fn select_bit_position_portable(word: u64, mut rank: usize) -> usize {
    debug_assert!(rank < word.count_ones() as usize);

    for (byte_idx, byte) in word.to_le_bytes().into_iter().enumerate() {
        let ones = byte.count_ones() as usize;
        if rank >= ones {
            // Target lies in a later byte; account for the bits skipped here.
            rank -= ones;
            continue;
        }

        // Clear the `rank` lowest set bits so the target becomes the lowest one.
        let mut remaining = byte;
        for _ in 0..rank {
            remaining &= remaining - 1;
        }
        return byte_idx * 8 + remaining.trailing_zeros() as usize;
    }

    debug_assert!(false, "rank out of bounds");
    0
}
106
/// Marker type selecting the hardware BMI2 (`PDEP`) implementations. Only compiled for
/// x86_64 targets.
#[cfg(target_arch = "x86_64")]
struct Bmi2;

#[cfg(target_arch = "x86_64")]
impl DepositBits for Bmi2 {
    // PDEP handles every input in one instruction, so the shortcut branches are skipped.
    const PREFER_BRANCHES: bool = false;

    /// Deposits via the `PDEP` instruction; the precomputed mask popcount is not needed.
    #[inline]
    fn deposit_bits(source: u64, mask: u64, _mask_count: usize) -> u64 {
        // SAFETY: callers only instantiate this implementation after checking BMI2 support.
        unsafe { pdep_bmi2(source, mask) }
    }
}
120
#[cfg(target_arch = "x86_64")]
impl SelectBit for Bmi2 {
    /// Selects the `rank`-th set bit of `word` using the `PDEP`-based helper.
    #[inline]
    fn select_bit_position(word: u64, rank: usize) -> usize {
        // SAFETY: callers only instantiate this implementation after checking BMI2 support.
        unsafe { select_bit_position_bmi2(word, rank) }
    }
}
129
/// Thin wrapper around the `_pdep_u64` intrinsic, compiled with the `bmi2` target feature
/// enabled.
///
/// # Safety
///
/// The executing CPU must support BMI2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi2")]
unsafe fn pdep_bmi2(source: u64, mask: u64) -> u64 {
    use std::arch::x86_64;
    x86_64::_pdep_u64(source, mask)
}
136
/// Returns the position of the `rank`-th set bit of `word` using a single `PDEP`.
///
/// Requires `rank < word.count_ones()` (checked in debug builds).
///
/// # Safety
///
/// The executing CPU must support BMI2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi2")]
unsafe fn select_bit_position_bmi2(word: u64, rank: usize) -> usize {
    use std::arch::x86_64::_pdep_u64;
    debug_assert!(rank < word.count_ones() as usize);
    // A source with only bit `rank` set is deposited onto the `rank`-th set bit of `word`,
    // so the result has exactly one bit set, at the position we are looking for.
    let rank_selector = 1u64 << rank;
    _pdep_u64(rank_selector, word).trailing_zeros() as usize
}
147
148/// Reader that pulls variable-length (0..=64 bit) groups from a [`BitBuffer`] sequentially.
149///
150/// Maintains a 128-bit window over two consecutive chunks (`current`, `next`) and uses a
151/// funnel shift via `u128` to extract bits at any offset without branching. The shift
152/// pattern compiles to a single funnel-shift / SHRD-style sequence on x86_64.
153struct RankBitReader<'a> {
154    chunk_iter: Chain<BitChunkIterator<'a>, Once<u64>>,
155    current: u64,
156    next: u64,
157    bit_offset: usize,
158}
159
160impl<'a> RankBitReader<'a> {
161    fn new(buffer: &'a BitBuffer) -> Self {
162        let chunks = buffer.chunks();
163        let mut chunk_iter = chunks.iter().chain(once(chunks.remainder_bits()));
164
165        let current = chunk_iter.next().unwrap_or(0);
166        let next = chunk_iter.next().unwrap_or(0);
167
168        Self {
169            chunk_iter,
170            current,
171            next,
172            bit_offset: 0,
173        }
174    }
175
176    #[inline]
177    fn fetch_next(&mut self) -> u64 {
178        self.chunk_iter.next().unwrap_or(0)
179    }
180
181    #[inline]
182    fn read(&mut self, bit_count: usize) -> u64 {
183        debug_assert!(bit_count <= 64);
184
185        // Funnel shift: extract `bit_count` bits at `bit_offset` from the (next:current)
186        // 128-bit window. For bit_offset in 0..=63 this is a single SHRD-style instruction
187        // on x86_64; the u128 cast keeps it well-defined when bit_offset == 0.
188        let combined = ((self.next as u128) << 64) | (self.current as u128);
189        // The truncation is intentional: we want the low 64 bits of the funnel-shifted
190        // window, which is exactly what `as u64` produces.
191        #[expect(clippy::cast_possible_truncation)]
192        let bits = (combined >> self.bit_offset) as u64 & low_bits(bit_count);
193
194        let new_offset = self.bit_offset + bit_count;
195        if new_offset >= 64 {
196            self.current = self.next;
197            self.next = self.fetch_next();
198            self.bit_offset = new_offset - 64;
199        } else {
200            self.bit_offset = new_offset;
201        }
202
203        bits
204    }
205}
206
/// Returns a word whose lowest `bit_count` bits are set. `bit_count` must be at most 64.
#[inline]
fn low_bits(bit_count: usize) -> u64 {
    debug_assert!(bit_count <= 64);
    match bit_count {
        // `1 << 64` would overflow, so the full-width mask is special-cased.
        64 => u64::MAX,
        width => (1u64 << width) - 1,
    }
}
216
217#[inline]
218fn mask_from_buffer(buffer: BitBuffer, true_count: usize) -> Mask {
219    let len = buffer.len();
220    if true_count == 0 {
221        return Mask::new_false(len);
222    }
223    if true_count == len {
224        return Mask::new_true(len);
225    }
226
227    Mask::Values(Arc::new(MaskValues {
228        buffer,
229        indices: Default::default(),
230        slices: Default::default(),
231        true_count,
232        density: true_count as f64 / len as f64,
233    }))
234}
235
236#[inline]
237fn push_result_chunk<D: DepositBits>(
238    result: &mut BufferMut<u64>,
239    self_chunk: u64,
240    self_count: usize,
241    rank_bits: u64,
242) {
243    let chunk = if D::PREFER_BRANCHES {
244        if rank_bits == 0 {
245            0
246        } else if self_chunk == u64::MAX {
247            rank_bits
248        } else {
249            D::deposit_bits(rank_bits, self_chunk, self_count)
250        }
251    } else {
252        D::deposit_bits(rank_bits, self_chunk, self_count)
253    };
254
255    // SAFETY: callers allocate enough capacity for every output chunk.
256    unsafe { result.push_unchecked(chunk) };
257}
258
259fn intersect_bit_buffers<D: DepositBits>(
260    self_buffer: &BitBuffer,
261    mask_buffer: &BitBuffer,
262    true_count: usize,
263) -> Mask {
264    let len = self_buffer.len();
265    let mut result = BufferMut::with_capacity(len.div_ceil(64));
266    let mut reader = RankBitReader::new(mask_buffer);
267    let self_chunks = self_buffer.chunks();
268
269    for self_chunk in self_chunks.iter() {
270        let self_count = self_chunk.count_ones() as usize;
271        let rank_bits = reader.read(self_count);
272        push_result_chunk::<D>(&mut result, self_chunk, self_count, rank_bits);
273    }
274
275    if self_chunks.remainder_len() != 0 {
276        let self_chunk = self_chunks.remainder_bits();
277        let self_count = self_chunk.count_ones() as usize;
278        let rank_bits = reader.read(self_count);
279        push_result_chunk::<D>(&mut result, self_chunk, self_count, rank_bits);
280    }
281
282    mask_from_buffer(
283        BitBuffer::new(result.freeze().into_byte_buffer(), len),
284        true_count,
285    )
286}
287
288fn intersect_bit_buffer_by_rank_indices<D: DepositBits>(
289    self_buffer: &BitBuffer,
290    mask_indices: &[usize],
291) -> Mask {
292    let len = self_buffer.len();
293    let mut result = BufferMut::with_capacity(len.div_ceil(64));
294    let self_chunks = self_buffer.chunks();
295    let mut rank_base = 0usize;
296    let mut rank_idx = 0usize;
297
298    for self_chunk in self_chunks.iter() {
299        let self_count = self_chunk.count_ones() as usize;
300        let next_rank_base = rank_base + self_count;
301        let rank_bits = rank_bits_for_chunk(mask_indices, &mut rank_idx, rank_base, next_rank_base);
302        push_result_chunk::<D>(&mut result, self_chunk, self_count, rank_bits);
303        rank_base = next_rank_base;
304    }
305
306    if self_chunks.remainder_len() != 0 {
307        let self_chunk = self_chunks.remainder_bits();
308        let self_count = self_chunk.count_ones() as usize;
309        let next_rank_base = rank_base + self_count;
310        let rank_bits = rank_bits_for_chunk(mask_indices, &mut rank_idx, rank_base, next_rank_base);
311        push_result_chunk::<D>(&mut result, self_chunk, self_count, rank_bits);
312    }
313
314    debug_assert_eq!(rank_idx, mask_indices.len());
315
316    mask_from_buffer(
317        BitBuffer::new(result.freeze().into_byte_buffer(), len),
318        mask_indices.len(),
319    )
320}
321
322/// Walks `mask_indices` (global ranks into `self_buffer.set_bits`) and emits the corresponding
323/// positions in `self_buffer`. For each rank, advances `self_buffer`'s chunks via popcount
324/// skip-while, then locates the bit inside the current chunk with rank-select.
325///
326/// This dominates the chunk-scan paths when the mask is very sparse: cost is
327/// `O(mask.true_count() + self.len() / 64)` rather than `O(self.len() / 64)` per chunk.
328fn intersect_mask_driven<S, I>(self_buffer: &BitBuffer, mask_indices: I, true_count: usize) -> Mask
329where
330    S: SelectBit,
331    I: Iterator<Item = usize>,
332{
333    let len = self_buffer.len();
334    if true_count == 0 {
335        return Mask::new_false(len);
336    }
337
338    let mut chunk_iter = self_buffer.chunks().iter_padded();
339
340    let mut current_chunk = chunk_iter.next().unwrap_or(0);
341    let mut current_count = current_chunk.count_ones() as usize;
342    let mut current_chunk_idx = 0usize;
343    let mut rank_before = 0usize;
344
345    let mut output = Vec::with_capacity(true_count);
346
347    for global_rank in mask_indices {
348        while rank_before + current_count <= global_rank {
349            rank_before += current_count;
350            current_chunk_idx += 1;
351            current_chunk = chunk_iter.next().vortex_expect("mask index out of bounds");
352            current_count = current_chunk.count_ones() as usize;
353        }
354
355        let local_rank = global_rank - rank_before;
356        let bit_pos = S::select_bit_position(current_chunk, local_rank);
357        output.push(current_chunk_idx * 64 + bit_pos);
358    }
359
360    debug_assert_eq!(output.len(), true_count);
361    Mask::from_indices(len, output)
362}
363
/// Collects the ranks belonging to one chunk into a bit pattern.
///
/// Starting at `mask_indices[*rank_idx]`, consumes every rank below `next_rank_base`, sets
/// bit `rank - rank_base` for each, and advances `*rank_idx` past the consumed entries.
/// Stops at the first rank at or beyond `next_rank_base`, or at the end of the slice.
#[inline]
fn rank_bits_for_chunk(
    mask_indices: &[usize],
    rank_idx: &mut usize,
    rank_base: usize,
    next_rank_base: usize,
) -> u64 {
    let mut chunk_bits = 0u64;
    for &rank in mask_indices.iter().skip(*rank_idx) {
        if rank >= next_rank_base {
            break;
        }
        chunk_bits |= 1u64 << (rank - rank_base);
        *rank_idx += 1;
    }
    chunk_bits
}
381
382fn intersect_by_rank_indices(len: usize, self_indices: &[usize], mask_indices: &[usize]) -> Mask {
383    Mask::from_indices(
384        len,
385        mask_indices.iter().map(|idx| {
386            // SAFETY: mask indices are ranks into self_indices, because
387            // mask.len() == self.true_count() == self_indices.len().
388            unsafe { *self_indices.get_unchecked(*idx) }
389        }),
390    )
391}
392
/// Runs [`intersect_bit_buffers`] with the BMI2 implementation when compiled for x86_64 and
/// the CPU reports BMI2 support at runtime; otherwise uses the portable implementation.
#[inline]
fn intersect_bit_buffers_dispatch(
    self_buffer: &BitBuffer,
    mask_buffer: &BitBuffer,
    true_count: usize,
) -> Mask {
    // Runtime feature check: this is what justifies instantiating the `Bmi2` strategy.
    #[cfg(target_arch = "x86_64")]
    if std::arch::is_x86_feature_detected!("bmi2") {
        return intersect_bit_buffers::<Bmi2>(self_buffer, mask_buffer, true_count);
    }

    intersect_bit_buffers::<Portable>(self_buffer, mask_buffer, true_count)
}
406
/// Runs [`intersect_bit_buffer_by_rank_indices`] with the BMI2 implementation when compiled
/// for x86_64 and the CPU reports BMI2 support at runtime; otherwise uses the portable
/// implementation.
#[inline]
fn intersect_rank_indices_dispatch(self_buffer: &BitBuffer, mask_indices: &[usize]) -> Mask {
    // Runtime feature check: this is what justifies instantiating the `Bmi2` strategy.
    #[cfg(target_arch = "x86_64")]
    if std::arch::is_x86_feature_detected!("bmi2") {
        return intersect_bit_buffer_by_rank_indices::<Bmi2>(self_buffer, mask_indices);
    }

    intersect_bit_buffer_by_rank_indices::<Portable>(self_buffer, mask_indices)
}
416
/// Runs [`intersect_mask_driven`] with the BMI2 select when compiled for x86_64 and the CPU
/// reports BMI2 support at runtime; otherwise uses the portable select.
#[inline]
fn intersect_mask_driven_dispatch<I>(
    self_buffer: &BitBuffer,
    mask_indices: I,
    true_count: usize,
) -> Mask
where
    I: Iterator<Item = usize>,
{
    // Runtime feature check: this is what justifies instantiating the `Bmi2` strategy.
    #[cfg(target_arch = "x86_64")]
    if std::arch::is_x86_feature_detected!("bmi2") {
        return intersect_mask_driven::<Bmi2, _>(self_buffer, mask_indices, true_count);
    }

    intersect_mask_driven::<Portable, _>(self_buffer, mask_indices, true_count)
}
433
434/// Check if a mask is sparse.
435///
436/// BitBuffer traversal uses u64, hence we conclude that one or fewer values per u64 is sparse
437fn mask_is_sparse(values: &Arc<MaskValues>) -> bool {
438    values.true_count().saturating_mul(64) < values.len()
439}
440
441/// Check if a rank mask is sparse
442///
443/// The mask-driven path becomes worthwhile around ~3% mask density: each set
444/// bit costs a select and push, but we save a per-self-chunk popcount + deposit.
445fn rank_mask_is_sparse(values: &Arc<MaskValues>) -> bool {
446    values.true_count().saturating_mul(32) < values.len()
447}
448
449impl Mask {
450    /// Take the intersection of the `mask` with the set of true values in `self`.
451    ///
452    /// The hot path keeps bit-buffer-backed masks as bit buffers. It scans the set bits of `self`
453    /// by rank and deposits selected rank bits into their original positions.
454    ///
455    /// # Examples
456    ///
457    /// Keep the third and fifth set values from mask `m1`:
458    /// ```
459    /// use vortex_mask::Mask;
460    ///
461    /// let m1 = Mask::from_iter([true, false, false, true, true, true, false, true]);
462    /// let m2 = Mask::from_iter([false, false, true, false, true]);
463    /// assert_eq!(
464    ///     m1.intersect_by_rank(&m2),
465    ///     Mask::from_iter([false, false, false, false, true, false, false, true])
466    /// );
467    /// ```
468    pub fn intersect_by_rank(&self, mask: &Mask) -> Mask {
469        assert_eq!(self.true_count(), mask.len());
470
471        match (self, mask) {
472            (Self::AllTrue(_), _) => mask.clone(),
473            (_, Self::AllTrue(_)) => self.clone(),
474            (Self::AllFalse(_), _) | (_, Self::AllFalse(_)) => Self::new_false(self.len()),
475            (Self::Values(self_values), Self::Values(mask_values)) => {
476                // Four dispatch cases keyed by (self density, mask density):
477                //
478                //              | mask sparse | mask dense
479                // -------------+-------------+------------
480                // self sparse  | indices     | indices
481                // self dense   | mask-driven | bit-buffer
482                if let Some(mask_indices) = mask_values.indices.get() {
483                    if let Some(self_indices) = self_values.indices.get()
484                        && mask_indices.len() < self.len().div_ceil(64)
485                    {
486                        return intersect_by_rank_indices(self.len(), self_indices, mask_indices);
487                    }
488
489                    let self_is_very_sparse = mask_is_sparse(self_values);
490                    let mask_is_very_sparse = rank_mask_is_sparse(mask_values);
491
492                    if self_is_very_sparse {
493                        return intersect_by_rank_indices(
494                            self.len(),
495                            self_values.indices(),
496                            mask_indices,
497                        );
498                    }
499
500                    if mask_is_very_sparse {
501                        return intersect_mask_driven_dispatch(
502                            self_values.bit_buffer(),
503                            mask_indices.iter().copied(),
504                            mask_values.true_count(),
505                        );
506                    }
507
508                    if mask_indices.len().saturating_mul(4) > mask.len() {
509                        return intersect_bit_buffers_dispatch(
510                            self_values.bit_buffer(),
511                            mask_values.bit_buffer(),
512                            mask_values.true_count(),
513                        );
514                    }
515
516                    return intersect_rank_indices_dispatch(self_values.bit_buffer(), mask_indices);
517                }
518
519                let self_is_very_sparse = mask_is_sparse(self_values);
520                let mask_is_very_sparse = rank_mask_is_sparse(mask_values);
521
522                if self_is_very_sparse {
523                    return intersect_by_rank_indices(
524                        self.len(),
525                        self_values.indices(),
526                        mask_values.indices(),
527                    );
528                }
529
530                if mask_is_very_sparse {
531                    return intersect_mask_driven_dispatch(
532                        self_values.bit_buffer(),
533                        mask_values.bit_buffer().set_indices(),
534                        mask_values.true_count(),
535                    );
536                }
537
538                intersect_bit_buffers_dispatch(
539                    self_values.bit_buffer(),
540                    mask_values.bit_buffer(),
541                    mask_values.true_count(),
542                )
543            }
544        }
545    }
546}
547
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;

    use crate::Mask;

    // A base mask with every bit set: the result is the rank mask itself.
    #[test]
    fn mask_bitand_all_as_bit_and() {
        let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, true, true, true, true]));
        let mask = Mask::from_buffer(BitBuffer::from_iter(vec![false, true, false, true, true]));
        assert_eq!(
            this.intersect_by_rank(&mask),
            Mask::from_indices(5, vec![1, 3, 4])
        );
    }

    // A rank mask with every bit set: the result is the base mask itself.
    #[test]
    fn mask_bitand_all_true() {
        let this = Mask::from_buffer(BitBuffer::from_iter(vec![false, false, true, true, true]));
        let mask = Mask::from_buffer(BitBuffer::from_iter(vec![true, true, true]));
        assert_eq!(
            this.intersect_by_rank(&mask),
            Mask::from_indices(5, vec![2, 3, 4])
        );
    }

    #[test]
    fn mask_bitand_true() {
        let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, false, true, true]));
        let mask = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, true]));
        assert_eq!(
            this.intersect_by_rank(&mask),
            Mask::from_indices(5, vec![0, 4])
        );
    }

    #[test]
    fn mask_bitand_false() {
        let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, false, true, true]));
        let mask = Mask::from_buffer(BitBuffer::from_iter(vec![false, false, false]));
        assert_eq!(this.intersect_by_rank(&mask), Mask::from_indices(5, vec![]));
    }

    // An all-false base has zero true values, so the rank mask must be empty.
    #[test]
    fn mask_intersect_by_rank_all_false() {
        let this = Mask::AllFalse(10);
        let mask = Mask::AllFalse(0);
        assert_eq!(this.intersect_by_rank(&mask), Mask::AllFalse(10));
    }

    #[rstest]
    #[case::all_true_with_all_true(
        Mask::new_true(5),
        Mask::new_true(5),
        vec![0, 1, 2, 3, 4]
    )]
    #[case::all_true_with_all_false(
        Mask::new_true(5),
        Mask::new_false(5),
        vec![]
    )]
    #[case::all_false_with_any(
        Mask::new_false(10),
        Mask::new_true(0),
        vec![]
    )]
    #[case::indices_with_all_true(
        Mask::from_indices(10, vec![2, 5, 7, 9]),
        Mask::new_true(4),
        vec![2, 5, 7, 9]
    )]
    #[case::indices_with_all_false(
        Mask::from_indices(10, vec![2, 5, 7, 9]),
        Mask::new_false(4),
        vec![]
    )]
    fn test_intersect_by_rank_special_cases(
        #[case] base_mask: Mask,
        #[case] rank_mask: Mask,
        #[case] expected_indices: Vec<usize>,
    ) {
        let result = base_mask.intersect_by_rank(&rank_mask);

        match result.indices() {
            crate::AllOr::All => assert_eq!(expected_indices.len(), result.len()),
            crate::AllOr::None => assert!(expected_indices.is_empty()),
            crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]),
        }
    }

    #[test]
    fn test_intersect_by_rank_example() {
        // Example from the documentation
        let m1 = Mask::from_iter([true, false, false, true, true, true, false, true]);
        let m2 = Mask::from_iter([false, false, true, false, true]);
        let result = m1.intersect_by_rank(&m2);
        let expected = Mask::from_iter([false, false, false, false, true, false, false, true]);
        assert_eq!(result, expected);
    }

    #[test]
    #[should_panic]
    fn test_intersect_by_rank_wrong_length() {
        let m1 = Mask::from_indices(10, vec![2, 5, 7]); // 3 true values
        let m2 = Mask::new_true(5); // length 5, but the rank mask must have length 3
        m1.intersect_by_rank(&m2);
    }

    #[rstest]
    #[case::single_element(
        vec![3],
        vec![true],
        vec![3]
    )]
    #[case::single_element_masked(
        vec![3],
        vec![false],
        vec![]
    )]
    #[case::alternating(
        vec![0, 2, 4, 6, 8],
        vec![true, false, true, false, true],
        vec![0, 4, 8]
    )]
    #[case::consecutive(
        vec![5, 6, 7, 8, 9],
        vec![false, true, true, true, false],
        vec![6, 7, 8]
    )]
    fn test_intersect_by_rank_patterns(
        #[case] base_indices: Vec<usize>,
        #[case] rank_pattern: Vec<bool>,
        #[case] expected_indices: Vec<usize>,
    ) {
        let base = Mask::from_indices(10, base_indices);
        let rank = Mask::from_iter(rank_pattern);
        let result = base.intersect_by_rank(&rank);

        match result.indices() {
            crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]),
            crate::AllOr::None => assert!(expected_indices.is_empty()),
            _ => panic!("Unexpected result"),
        }
    }

    #[rstest]
    // Larger sizes to push the bench-shaped buffer paths through the unit tests too.
    #[case::dense_len_1024(1024, 31, 0.5, 0.5)]
    // Very sparse rank masks are meant to exercise the mask-driven dispatch path: each
    // nominal rank density below is under the 1/32 threshold used by `rank_mask_is_sparse`.
    #[case::sparse_mask_1pct(1024, 17, 0.5, 0.01)]
    #[case::sparse_mask_2pct(2048, 0, 0.5, 0.02)]
    #[case::very_sparse_mask_with_offsets(513, 5, 0.5, 0.005)]
    fn test_intersect_by_rank_density_matrix(
        #[case] base_len: usize,
        #[case] base_offset: usize,
        #[case] base_density: f64,
        #[case] rank_density: f64,
    ) {
        // Densities are turned into thresholds on a period-1024 pattern.
        #[expect(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let base_threshold = (base_density * 1024.0) as usize;
        #[expect(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let rank_threshold = (rank_density * 1024.0) as usize;

        // Build a longer source and slice into it so the base buffer starts at a bit offset.
        let base_source: Vec<bool> = (0..base_len + base_offset + 16)
            .map(|i| (i * 7 + 13) % 1024 < base_threshold)
            .collect();
        let base_bits = base_source[base_offset..base_offset + base_len].to_vec();
        let base = Mask::from_buffer(
            BitBuffer::from(base_source).slice(base_offset..base_offset + base_len),
        );

        // The same rank mask is built twice: once from a bit buffer and once from indices.
        let rank_len = base.true_count();
        let rank_bits: Vec<bool> = (0..rank_len)
            .map(|i| (i * 11 + 7) % 1024 < rank_threshold)
            .collect();
        let rank_from_buffer = Mask::from_buffer(BitBuffer::from(rank_bits.clone()));
        let rank_indices_vec = rank_bits
            .iter()
            .enumerate()
            .filter_map(|(idx, &v)| v.then_some(idx))
            .collect::<Vec<_>>();
        let rank_from_indices = Mask::from_indices(rank_len, rank_indices_vec);

        let expected = expected_intersect_by_rank(&base_bits, &rank_bits);

        assert_eq!(
            base.intersect_by_rank(&rank_from_buffer),
            expected,
            "uncached rank"
        );
        assert_eq!(
            base.intersect_by_rank(&rank_from_indices),
            expected,
            "cached rank"
        );
    }

    #[rstest]
    #[case::short(37, 0, 0)]
    #[case::base_offset(257, 5, 0)]
    #[case::rank_offset(257, 0, 3)]
    #[case::both_offsets(513, 6, 5)]
    fn test_intersect_by_rank_bitbuffer_paths_with_offsets(
        #[case] base_len: usize,
        #[case] base_offset: usize,
        #[case] rank_offset: usize,
    ) {
        let base_source: Vec<bool> = (0..base_len + base_offset + 16)
            .map(|i| (i % 3 == 0) ^ (i % 11 == 0) ^ (i % 17 == 0))
            .collect();
        let base_bits = base_source[base_offset..base_offset + base_len].to_vec();
        let base = Mask::from_buffer(
            BitBuffer::from(base_source).slice(base_offset..base_offset + base_len),
        );

        let rank_len = base.true_count();
        let rank_bits: Vec<bool> = (0..rank_len)
            .map(|i| (i % 5 == 0) || (i % 13 == 3))
            .collect();
        // Pad the rank source before and after the slice so the sliced buffer has a
        // non-zero offset and extra bits beyond its end.
        let mut rank_source = vec![false; rank_offset];
        rank_source.extend(rank_bits.iter().copied());
        rank_source.extend([true, false, true, false, true, false, true, false]);

        let rank_from_buffer = Mask::from_buffer(
            BitBuffer::from(rank_source).slice(rank_offset..rank_offset + rank_len),
        );
        let rank_indices = rank_bits
            .iter()
            .enumerate()
            .filter_map(|(idx, &value)| value.then_some(idx))
            .collect::<Vec<_>>();
        let rank_from_indices = Mask::from_indices(rank_len, rank_indices);

        let expected = expected_intersect_by_rank(&base_bits, &rank_bits);

        assert_eq!(base.intersect_by_rank(&rank_from_buffer), expected);
        assert_eq!(base.intersect_by_rank(&rank_from_indices), expected);
    }

    /// Reference implementation used as the test oracle: walks `base_bits` one position at
    /// a time and keeps a set position only if the next entry of `rank_bits` is true.
    fn expected_intersect_by_rank(base_bits: &[bool], rank_bits: &[bool]) -> Mask {
        let mut rank = 0usize;
        Mask::from_iter(base_bits.iter().map(|&is_set| {
            if is_set {
                let keep = rank_bits[rank];
                rank += 1;
                keep
            } else {
                false
            }
        }))
    }
}