Skip to main content

vortex_array/arrays/primitive/compute/rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::arrays::MaskedArray;
9use crate::arrays::MaskedVTable;
10use crate::arrays::PrimitiveArray;
11use crate::arrays::PrimitiveVTable;
12use crate::arrays::SliceReduceAdaptor;
13use crate::optimizer::rules::ArrayParentReduceRule;
14use crate::optimizer::rules::ParentRuleSet;
15use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
16use crate::vtable::ValidityHelper;
17
/// Parent-reduce rules registered for `PrimitiveVTable` arrays.
///
/// The set contains, in this order:
/// 1. [`PrimitiveMaskedValidityRule`], defined in this file, which handles a `MaskedArray` parent.
/// 2. `MaskReduceAdaptor` instantiated for `PrimitiveVTable`.
/// 3. `SliceReduceAdaptor` instantiated for `PrimitiveVTable`.
///
/// NOTE(review): whether the entry order determines rule priority depends on `ParentRuleSet`,
/// which is defined elsewhere — confirm before reordering entries.
pub(crate) const RULES: ParentRuleSet<PrimitiveVTable> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&PrimitiveMaskedValidityRule),
    ParentRuleSet::lift(&MaskReduceAdaptor(PrimitiveVTable)),
    ParentRuleSet::lift(&SliceReduceAdaptor(PrimitiveVTable)),
]);
23
/// Rule that pushes validity masking from a `MaskedArray` parent down into its
/// `PrimitiveArray` child.
///
/// When a `PrimitiveArray` is wrapped by a `MaskedArray`, this rule combines the parent's
/// validity with the child's existing validity (via `and`) and produces a new `PrimitiveArray`
/// built from the child's buffer handle and ptype plus the combined validity. The resulting
/// array replaces the `MaskedArray` wrapper.
///
/// The rule carries no state; it is a unit struct registered in this module's `RULES` set.
/// The validity merge is currently performed eagerly when the rule fires.
#[derive(Default, Debug)]
pub struct PrimitiveMaskedValidityRule;
30
31impl ArrayParentReduceRule<PrimitiveVTable> for PrimitiveMaskedValidityRule {
32    type Parent = MaskedVTable;
33
34    fn reduce_parent(
35        &self,
36        array: &PrimitiveArray,
37        parent: &MaskedArray,
38        _child_idx: usize,
39    ) -> VortexResult<Option<ArrayRef>> {
40        // TODO(joe): make this lazy
41        // Merge the parent's validity mask into the child's validity
42        let new_validity = array.validity().clone().and(parent.validity().clone())?;
43
44        // SAFETY: masking validity does not change PrimitiveArray invariants
45        let masked_array = unsafe {
46            PrimitiveArray::new_unchecked_from_handle(
47                array.buffer_handle().clone(),
48                array.ptype(),
49                new_validity,
50            )
51        };
52
53        Ok(Some(masked_array.into_array()))
54    }
55}