vortex_array/pipeline/
canonical.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::BooleanBuffer;
5use vortex_buffer::BufferMut;
6use vortex_dtype::{DType, NativePType, Nullability, match_each_native_ptype};
7use vortex_error::{VortexResult, vortex_bail};
8use vortex_mask::Mask;
9
10use crate::Canonical;
11use crate::arrays::{BoolArray, PrimitiveArray};
12use crate::pipeline::bits::{BitVector, BitView, BitViewMut};
13use crate::pipeline::operators::Operator;
14use crate::pipeline::query::QueryPlan;
15use crate::pipeline::types::Element;
16use crate::pipeline::vec::Vector;
17use crate::pipeline::view::ViewMut;
18use crate::pipeline::{Kernel, KernelContext, N, N_WORDS};
19use crate::validity::Validity;
20
21/// Export canonical data from a pipeline kernel with the given mask.
22pub fn export_canonical_pipeline(
23    dtype: &DType,
24    len: usize,
25    pipeline: &mut dyn Kernel,
26    mask: &Mask,
27) -> VortexResult<Canonical> {
28    match dtype {
29        DType::Bool(Nullability::NonNullable) => {
30            export_bool_nonnull_masked(mask, pipeline).map(Canonical::Bool)
31        }
32        DType::Primitive(ptype, Nullability::NonNullable) => {
33            if mask.all_true() {
34                match_each_native_ptype!(ptype, |T| {
35                    export_primitive_nonnull::<T>(len, pipeline).map(Canonical::Primitive)
36                })
37            } else {
38                match_each_native_ptype!(ptype, |T| {
39                    export_primitive_nonnull_masked::<T>(mask, pipeline).map(Canonical::Primitive)
40                })
41            }
42        }
43        _ => vortex_bail!("Expected a primitive array, got: {}", dtype),
44    }
45}
46
47/// Export canonical data from an operator expression with a starting offset and mask.
48pub fn export_canonical_pipeline_expr_offset(
49    dtype: &DType,
50    offset: usize,
51    len: usize,
52    expression: &dyn Operator,
53    mask: &Mask,
54) -> VortexResult<Canonical> {
55    let plan = QueryPlan::new(expression)?;
56    let mut pipeline = plan.executable_plan()?;
57    pipeline.seek(offset)?;
58    export_canonical_pipeline(dtype, len, &mut pipeline, mask)
59}
60
61/// Export canonical data from an operator expression with the given mask.
62pub fn export_canonical_pipeline_expr(
63    dtype: &DType,
64    len: usize,
65    expression: &dyn Operator,
66    mask: &Mask,
67) -> VortexResult<Canonical> {
68    let plan = QueryPlan::new(expression)?;
69    let mut pipeline = plan.executable_plan()?;
70    export_canonical_pipeline(dtype, len, &mut pipeline, mask)
71}
72
73fn export_primitive_nonnull<T: Element + NativePType>(
74    len: usize,
75    pipeline: &mut dyn Kernel,
76) -> VortexResult<PrimitiveArray> {
77    let capacity = len.next_multiple_of(N) + N;
78
79    let mut elements = BufferMut::<T>::with_capacity(capacity);
80    unsafe { elements.set_len(capacity) };
81
82    let mut remaining = len;
83    while remaining >= N {
84        let mut elements_view = ViewMut::new(&mut elements[len - remaining..][..N], None);
85        let dummy_ctx = KernelContext::default();
86        pipeline.step(&dummy_ctx, BitView::all_true(), &mut elements_view)?;
87        remaining -= N;
88    }
89
90    if remaining > 0 {
91        let mut elements_view = ViewMut::new(&mut elements[len - remaining..][..N], None);
92        let mask = BitVector::true_until(remaining);
93        let dummy_ctx = KernelContext::default();
94        pipeline.step(&dummy_ctx, mask.as_view(), &mut elements_view)?;
95    }
96
97    unsafe { elements.set_len(len) };
98
99    Ok(PrimitiveArray::new(
100        elements.freeze(),
101        Validity::NonNullable,
102    ))
103}
104
105fn export_primitive_nonnull_masked<T: Element + NativePType>(
106    mask: &Mask,
107    pipeline: &mut dyn Kernel,
108) -> VortexResult<PrimitiveArray> {
109    let len = mask.len();
110    let capacity = mask.true_count().next_multiple_of(N) + N;
111
112    let mut elements = BufferMut::<T>::with_capacity(capacity);
113    unsafe { elements.set_len(capacity) };
114
115    let mask_buffer = mask.to_boolean_buffer();
116    let mut mask_iter = mask_buffer.bit_chunks().iter_padded();
117
118    let mut mask = [0usize; N_WORDS];
119    let mut mask_view = BitViewMut::new(&mut mask);
120
121    let mut offset = 0;
122    let mut remaining = len;
123    while remaining > 0 {
124        let mut elements_view = ViewMut::new(&mut elements[offset..][..N], None);
125
126        mask_view.clear();
127        mask_view.fill_with_words(&mut mask_iter);
128
129        let dummy_ctx = KernelContext::default();
130        pipeline.step(&dummy_ctx, mask_view.as_view(), &mut elements_view)?;
131        offset += mask_view.true_count();
132
133        remaining = remaining.saturating_sub(N);
134    }
135
136    unsafe { elements.set_len(offset) };
137
138    Ok(PrimitiveArray::new(
139        elements.freeze(),
140        Validity::NonNullable,
141    ))
142}
143
144fn export_bool_nonnull_masked(mask: &Mask, pipeline: &mut dyn Kernel) -> VortexResult<BoolArray> {
145    let len = mask.len();
146    let true_count = mask.true_count();
147
148    let mut elements_buffer = Vector::new::<bool>();
149    let mut elements_buffer_mut = elements_buffer.as_view_mut();
150
151    let mask_buffer = mask.to_boolean_buffer();
152    let mut mask_iter = mask_buffer.bit_chunks().iter_padded();
153
154    let mut mask = [0usize; N_WORDS];
155    let mut mask_view = BitViewMut::new(&mut mask);
156
157    // Fast path: collect all bools first, then use collect_bool for optimal packing
158    let mut all_bools: Vec<bool> = Vec::with_capacity(true_count);
159    let mut remaining = len;
160
161    while remaining > 0 {
162        mask_view.clear();
163        mask_view.fill_with_words(&mut mask_iter);
164
165        // Handle partial iteration on the last chunk
166        let current_len = remaining.min(N);
167        if current_len < N {
168            mask_view.intersect_prefix(current_len);
169        }
170
171        let dummy_ctx = KernelContext::default();
172        pipeline.step(&dummy_ctx, mask_view.as_view(), &mut elements_buffer_mut)?;
173
174        // Collect bools efficiently with unsafe for better performance
175        let bool_slice = elements_buffer_mut.as_slice::<bool>();
176        let count = mask_view.true_count();
177
178        // Unsafe version to avoid bounds checking in hot path
179        let old_len = all_bools.len();
180        unsafe {
181            all_bools.set_len(old_len + count);
182            std::ptr::copy_nonoverlapping(
183                bool_slice.as_ptr(),
184                all_bools.as_mut_ptr().add(old_len),
185                count,
186            );
187        }
188
189        remaining = remaining.saturating_sub(N);
190    }
191
192    // Use collect_bool for optimal bit packing - avoid closure overhead
193    let values = BooleanBuffer::collect_bool(all_bools.len(), |idx| unsafe {
194        *all_bools.get_unchecked(idx)
195    });
196
197    Ok(BoolArray::new(values, Validity::NonNullable))
198}