vortex_compute/take/vector/
bool.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Take implementations for [`BoolVector`].
5//!
6//! This module includes an optimization for small boolean value arrays (typical of dictionary
7//! encoding) that avoids element-wise indexing when possible.
8
9use std::ops::BitAnd;
10use std::ops::Not;
11
12use vortex_buffer::BitBuffer;
13use vortex_dtype::UnsignedPType;
14use vortex_mask::Mask;
15use vortex_vector::VectorOps;
16use vortex_vector::bool::BoolVector;
17use vortex_vector::primitive::PVector;
18
19use crate::take::Take;
20
21// TODO(connor): Figure out good numbers for these heuristics.
22
/// Largest `values` length for which the optimized take is always chosen, regardless of how
/// many indices there are.
const OPTIMIZED_TAKE_MAX_VALUES_LEN: usize = 8;

/// Smallest `indices.len() / values.len()` ratio at which the optimized take is chosen for
/// larger `values`.
const OPTIMIZED_TAKE_MIN_RATIO: usize = 2;

/// Decides whether the optimized take path should be used.
///
/// Returns `true` when either:
///
/// - `values_len` is at most [`OPTIMIZED_TAKE_MAX_VALUES_LEN`], or
/// - `indices_len` is at least [`OPTIMIZED_TAKE_MIN_RATIO`] times `values_len`.
///
/// The threshold is computed with a checked multiplication, so extremely large `values_len`
/// inputs neither panic (debug builds) nor wrap around (release builds).
fn should_use_optimized_take(values_len: usize, indices_len: usize) -> bool {
    if values_len <= OPTIMIZED_TAKE_MAX_VALUES_LEN {
        return true;
    }

    match OPTIMIZED_TAKE_MIN_RATIO.checked_mul(values_len) {
        Some(min_indices_len) => indices_len >= min_indices_len,
        // The threshold does not fit in a `usize`, so no `indices_len` can reach it.
        None => false,
    }
}
34
35impl<I: UnsignedPType> Take<PVector<I>> for &BoolVector {
36    type Output = BoolVector;
37
38    fn take(self, indices: &PVector<I>) -> BoolVector {
39        if indices.validity().all_true() {
40            // No null indices, delegate to slice implementation.
41            self.take(indices.elements().as_slice())
42        } else {
43            // Has null indices, need to propagate nulls.
44            take_with_nullable_indices(self, indices)
45        }
46    }
47}
48
49impl<I: UnsignedPType> Take<[I]> for &BoolVector {
50    type Output = BoolVector;
51
52    fn take(self, indices: &[I]) -> BoolVector {
53        if should_use_optimized_take(self.len(), indices.len()) {
54            optimized_take(self, indices, || self.validity().take(indices))
55        } else {
56            default_take(self, indices)
57        }
58    }
59}
60
61/// Default element-wise take from a slice of indices.
62pub fn default_take<I: UnsignedPType>(values: &BoolVector, indices: &[I]) -> BoolVector {
63    let taken_bits = values.bits().take(indices);
64    let taken_validity = values.validity().take(indices);
65
66    debug_assert_eq!(taken_bits.len(), taken_validity.len());
67
68    // SAFETY: Both components were taken with the same indices, so they have the same length.
69    unsafe { BoolVector::new_unchecked(taken_bits, taken_validity) }
70}
71
72/// Take with nullable indices, propagating nulls from both values and indices.
73fn take_with_nullable_indices<I: UnsignedPType>(
74    values: &BoolVector,
75    indices: &PVector<I>,
76) -> BoolVector {
77    let indices_slice = indices.elements().as_slice();
78    let indices_validity = indices.validity();
79
80    // Validity must combine value validity with index validity.
81    let compute_validity = || {
82        values
83            .validity()
84            .take(indices_slice)
85            .bitand(indices_validity)
86    };
87
88    if should_use_optimized_take(values.len(), indices.len()) {
89        optimized_take(values, indices_slice, compute_validity)
90    } else {
91        // We ignore index nullability when taking the bits since the validity mask handles nulls.
92        let taken_bits = values.bits().take(indices_slice);
93        let taken_validity = compute_validity();
94
95        debug_assert_eq!(taken_bits.len(), taken_validity.len());
96
97        // SAFETY: Both components were taken with the same indices, so they have the same length.
98        unsafe { BoolVector::new_unchecked(taken_bits, taken_validity) }
99    }
100}
101
102// TODO(connor): Use the generic `compare` implementation when that gets implemented.
103
104/// Creates a [`BitBuffer`] where each bit is set iff the corresponding index equals `target`.
105fn broadcast_index_comparison<I: UnsignedPType>(indices: &[I], target: usize) -> BitBuffer {
106    BitBuffer::collect_bool(indices.len(), |i| {
107        // SAFETY: `i` is in bounds since `collect_bool` iterates from 0..len.
108        let index: usize = unsafe { indices.get_unchecked(i).as_() };
109        index == target
110    })
111}
112
113/// Optimized take for boolean vectors with small value arrays.
114///
115/// Since booleans can only be `true` or `false`, we can optimize these specific cases:
116///
117/// - All of the values are `true`, so create a [`BoolVector`] with `n` `true`s.
118/// - All of the values are `false`, so create a [`BoolVector`] with `n` `false`s.
119/// - There is a single `true` value, so compare indices against that index.
120/// - There is a single `false` value, so compare indices against that index and negate.
121/// - Otherwise, there are multiple `true`s and `false`s in the `values` vector and we must do a
122///   normal `take` on it.
123///
124/// The `compute_validity` closure computes the output validity mask, allowing callers to handle
125/// nullable vs non-nullable indices differently.
126pub fn optimized_take<I: UnsignedPType>(
127    values: &BoolVector,
128    indices: &[I],
129    compute_validity: impl FnOnce() -> Mask,
130) -> BoolVector {
131    let len = indices.len();
132    let (trues, falses) = count_true_and_false_positions(values);
133
134    let (taken_bits, taken_validity) = match (trues, falses) {
135        // All values are null.
136        (Count::None, Count::None) => (BitBuffer::new_unset(len), Mask::new_false(len)),
137
138        // No true values exist, so all output bits are false.
139        (Count::None, _) => (BitBuffer::new_unset(len), compute_validity()),
140
141        // No false values exist, so all output bits are true.
142        (_, Count::None) => (BitBuffer::new_set(len), compute_validity()),
143
144        // Single true value: output bit is set iff index equals the true position.
145        (Count::One(true_idx), _) => {
146            let bits = broadcast_index_comparison(indices, true_idx);
147            (bits, compute_validity())
148        }
149
150        // Single false value: output bit is set iff index does NOT equal the false position.
151        (_, Count::One(false_idx)) => {
152            let bits = broadcast_index_comparison(indices, false_idx).not();
153            (bits, compute_validity())
154        }
155
156        // Multiple true and false values, so fall back to the default `take`.
157        (Count::More, Count::More) => {
158            let taken_bits = values.bits().take(indices);
159            (taken_bits, compute_validity())
160        }
161    };
162
163    debug_assert_eq!(taken_bits.len(), taken_validity.len());
164
165    // SAFETY: Both components have length `len` (the length of `indices`).
166    unsafe { BoolVector::new_unchecked(taken_bits, taken_validity) }
167}
168
/// How many values of one kind (`true` or `false`) were found in a boolean vector.
///
/// The exact number is only tracked up to "two or more"; a single occurrence also records
/// where it was found.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Count {
    /// No value of this kind was found.
    None,
    /// Exactly one value of this kind was found, at the given index.
    One(usize),
    /// Two or more values of this kind were found.
    More,
}
178
179/// Scans a boolean vector to determine how many true and false values exist.
180///
181/// Returns `(true_count, false_count)` where each is a [`Count`] indicating none, one (with
182/// position), or more than one. Null values are skipped. The scan exits early once both counts
183/// reach "more than one".
184fn count_true_and_false_positions(values: &BoolVector) -> (Count, Count) {
185    let bits = values.bits();
186    let validity = values.validity();
187
188    let mut first_true: Option<usize> = None;
189    let mut found_second_true = false;
190    let mut first_false: Option<usize> = None;
191    let mut found_second_false = false;
192
193    for idx in 0..values.len() {
194        if !validity.value(idx) {
195            continue;
196        }
197
198        if bits.value(idx) {
199            if first_true.is_none() {
200                first_true = Some(idx);
201            } else {
202                found_second_true = true;
203            }
204        } else if first_false.is_none() {
205            first_false = Some(idx);
206        } else {
207            found_second_false = true;
208        }
209
210        if found_second_true && found_second_false {
211            break;
212        }
213    }
214
215    let true_count = match (first_true, found_second_true) {
216        (None, _) => Count::None,
217        (Some(idx), false) => Count::One(idx),
218        (Some(_), true) => Count::More,
219    };
220
221    let false_count = match (first_false, found_second_false) {
222        (None, _) => Count::None,
223        (Some(idx), false) => Count::One(idx),
224        (Some(_), true) => Count::More,
225    };
226
227    (true_count, false_count)
228}