Skip to main content

vortex_array/arrays/primitive/compute/
rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::array::ArrayView;
9use crate::arrays::Masked;
10use crate::arrays::Primitive;
11use crate::arrays::PrimitiveArray;
12use crate::arrays::slice::SliceReduceAdaptor;
13use crate::optimizer::rules::ArrayParentReduceRule;
14use crate::optimizer::rules::ParentRuleSet;
15use crate::scalar_fn::fns::cast::CastReduceAdaptor;
16use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
17
18pub(crate) const RULES: ParentRuleSet<Primitive> = ParentRuleSet::new(&[
19    ParentRuleSet::lift(&PrimitiveMaskedValidityRule),
20    ParentRuleSet::lift(&CastReduceAdaptor(Primitive)),
21    ParentRuleSet::lift(&MaskReduceAdaptor(Primitive)),
22    ParentRuleSet::lift(&SliceReduceAdaptor(Primitive)),
23]);
24
/// Parent-reduce rule that removes a `MaskedArray` sitting on top of a `PrimitiveArray`.
///
/// Rather than keeping the `MaskedArray` wrapper, the mask's validity is intersected with the
/// `PrimitiveArray`'s own validity. The primitive array is then rebuilt with the combined
/// validity, so the wrapper is no longer needed.
#[derive(Debug, Default)]
pub struct PrimitiveMaskedValidityRule;
31
32impl ArrayParentReduceRule<Primitive> for PrimitiveMaskedValidityRule {
33    type Parent = Masked;
34
35    fn reduce_parent(
36        &self,
37        array: ArrayView<'_, Primitive>,
38        parent: ArrayView<'_, Masked>,
39        _child_idx: usize,
40    ) -> VortexResult<Option<ArrayRef>> {
41        // TODO(joe): make this lazy
42        // Merge the parent's validity mask into the child's validity
43        let new_validity = array.validity()?.and(parent.validity()?)?;
44
45        // SAFETY: masking validity does not change PrimitiveArray invariants
46        let masked_array = unsafe {
47            PrimitiveArray::new_unchecked_from_handle(
48                array.buffer_handle().clone(),
49                array.ptype(),
50                new_validity,
51            )
52        };
53
54        Ok(Some(masked_array.into_array()))
55    }
56}