vortex_compute/take/vector/
bool.rs1use std::ops::BitAnd;
10use std::ops::Not;
11
12use vortex_buffer::BitBuffer;
13use vortex_dtype::UnsignedPType;
14use vortex_mask::Mask;
15use vortex_vector::VectorOps;
16use vortex_vector::bool::BoolVector;
17use vortex_vector::primitive::PVector;
18
19use crate::take::Take;
20
/// Largest values length for which the optimized take is always chosen, no matter how many
/// indices are supplied.
const OPTIMIZED_TAKE_MAX_VALUES_LEN: usize = 8;

/// For longer values, the optimized take is chosen once the number of indices is at least this
/// multiple of the values length.
const OPTIMIZED_TAKE_MIN_RATIO: usize = 2;

/// Decides whether a take over `values_len` values with `indices_len` indices should use the
/// optimized implementation.
///
/// Returns `true` when the values are short (at most [`OPTIMIZED_TAKE_MAX_VALUES_LEN`]) or when
/// there are at least [`OPTIMIZED_TAKE_MIN_RATIO`] indices per value.
fn should_use_optimized_take(values_len: usize, indices_len: usize) -> bool {
    if values_len <= OPTIMIZED_TAKE_MAX_VALUES_LEN {
        return true;
    }

    // A plain `RATIO * values_len` can overflow `usize` (panic in debug builds, wrap-around in
    // release builds). If the product does not fit in a `usize`, no `indices_len` can reach it,
    // so the answer is simply `false`.
    values_len
        .checked_mul(OPTIMIZED_TAKE_MIN_RATIO)
        .is_some_and(|min_indices| indices_len >= min_indices)
}
34
35impl<I: UnsignedPType> Take<PVector<I>> for &BoolVector {
36 type Output = BoolVector;
37
38 fn take(self, indices: &PVector<I>) -> BoolVector {
39 if indices.validity().all_true() {
40 self.take(indices.elements().as_slice())
42 } else {
43 take_with_nullable_indices(self, indices)
45 }
46 }
47}
48
49impl<I: UnsignedPType> Take<[I]> for &BoolVector {
50 type Output = BoolVector;
51
52 fn take(self, indices: &[I]) -> BoolVector {
53 if should_use_optimized_take(self.len(), indices.len()) {
54 optimized_take(self, indices, || self.validity().take(indices))
55 } else {
56 default_take(self, indices)
57 }
58 }
59}
60
61pub fn default_take<I: UnsignedPType>(values: &BoolVector, indices: &[I]) -> BoolVector {
63 let taken_bits = values.bits().take(indices);
64 let taken_validity = values.validity().take(indices);
65
66 debug_assert_eq!(taken_bits.len(), taken_validity.len());
67
68 unsafe { BoolVector::new_unchecked(taken_bits, taken_validity) }
70}
71
72fn take_with_nullable_indices<I: UnsignedPType>(
74 values: &BoolVector,
75 indices: &PVector<I>,
76) -> BoolVector {
77 let indices_slice = indices.elements().as_slice();
78 let indices_validity = indices.validity();
79
80 let compute_validity = || {
82 values
83 .validity()
84 .take(indices_slice)
85 .bitand(indices_validity)
86 };
87
88 if should_use_optimized_take(values.len(), indices.len()) {
89 optimized_take(values, indices_slice, compute_validity)
90 } else {
91 let taken_bits = values.bits().take(indices_slice);
93 let taken_validity = compute_validity();
94
95 debug_assert_eq!(taken_bits.len(), taken_validity.len());
96
97 unsafe { BoolVector::new_unchecked(taken_bits, taken_validity) }
99 }
100}
101
102fn broadcast_index_comparison<I: UnsignedPType>(indices: &[I], target: usize) -> BitBuffer {
106 BitBuffer::collect_bool(indices.len(), |i| {
107 let index: usize = unsafe { indices.get_unchecked(i).as_() };
109 index == target
110 })
111}
112
113pub fn optimized_take<I: UnsignedPType>(
127 values: &BoolVector,
128 indices: &[I],
129 compute_validity: impl FnOnce() -> Mask,
130) -> BoolVector {
131 let len = indices.len();
132 let (trues, falses) = count_true_and_false_positions(values);
133
134 let (taken_bits, taken_validity) = match (trues, falses) {
135 (Count::None, Count::None) => (BitBuffer::new_unset(len), Mask::new_false(len)),
137
138 (Count::None, _) => (BitBuffer::new_unset(len), compute_validity()),
140
141 (_, Count::None) => (BitBuffer::new_set(len), compute_validity()),
143
144 (Count::One(true_idx), _) => {
146 let bits = broadcast_index_comparison(indices, true_idx);
147 (bits, compute_validity())
148 }
149
150 (_, Count::One(false_idx)) => {
152 let bits = broadcast_index_comparison(indices, false_idx).not();
153 (bits, compute_validity())
154 }
155
156 (Count::More, Count::More) => {
158 let taken_bits = values.bits().take(indices);
159 (taken_bits, compute_validity())
160 }
161 };
162
163 debug_assert_eq!(taken_bits.len(), taken_validity.len());
164
165 unsafe { BoolVector::new_unchecked(taken_bits, taken_validity) }
167}
168
/// How many times a particular value was seen, tracked only up to "more than one".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Count {
    /// The value was never seen.
    None,
    /// The value was seen exactly once, at the contained position.
    One(usize),
    /// The value was seen at least twice.
    More,
}
178
179fn count_true_and_false_positions(values: &BoolVector) -> (Count, Count) {
185 let bits = values.bits();
186 let validity = values.validity();
187
188 let mut first_true: Option<usize> = None;
189 let mut found_second_true = false;
190 let mut first_false: Option<usize> = None;
191 let mut found_second_false = false;
192
193 for idx in 0..values.len() {
194 if !validity.value(idx) {
195 continue;
196 }
197
198 if bits.value(idx) {
199 if first_true.is_none() {
200 first_true = Some(idx);
201 } else {
202 found_second_true = true;
203 }
204 } else if first_false.is_none() {
205 first_false = Some(idx);
206 } else {
207 found_second_false = true;
208 }
209
210 if found_second_true && found_second_false {
211 break;
212 }
213 }
214
215 let true_count = match (first_true, found_second_true) {
216 (None, _) => Count::None,
217 (Some(idx), false) => Count::One(idx),
218 (Some(_), true) => Count::More,
219 };
220
221 let false_count = match (first_false, found_second_false) {
222 (None, _) => Count::None,
223 (Some(idx), false) => Count::One(idx),
224 (Some(_), true) => Count::More,
225 };
226
227 (true_count, false_count)
228}