Skip to main content

vortex_array/arrays/decimal/compute/rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::Decimal;
12use crate::arrays::DecimalArray;
13use crate::arrays::Masked;
14use crate::arrays::slice::SliceReduce;
15use crate::arrays::slice::SliceReduceAdaptor;
16use crate::match_each_decimal_value_type;
17use crate::optimizer::rules::ArrayParentReduceRule;
18use crate::optimizer::rules::ParentRuleSet;
19use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
20
/// Parent-reduction rules registered for `Decimal` arrays.
///
/// Each entry lifts a rule that may rewrite a parent array whose child is a `Decimal` array:
/// - `DecimalMaskedValidityRule`: folds a `Masked` parent into the decimal child by merging
///   validities.
/// - `MaskReduceAdaptor(Decimal)`: adapter for the mask scalar function.
/// - `SliceReduceAdaptor(Decimal)`: adapter that dispatches to the `SliceReduce` impl for
///   `Decimal` below.
///
/// NOTE(review): entries are kept in their original order in case the optimizer tries them
/// sequentially — confirm against `ParentRuleSet` before reordering.
pub(crate) static RULES: ParentRuleSet<Decimal> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&DecimalMaskedValidityRule),
    ParentRuleSet::lift(&MaskReduceAdaptor(Decimal)),
    ParentRuleSet::lift(&SliceReduceAdaptor(Decimal)),
]);
26
/// Parent-reduce rule that folds a `MaskedArray` parent into its `DecimalArray` child.
///
/// Instead of keeping the `MaskedArray` wrapper, the decimal array is rebuilt with a validity
/// equal to its own validity AND-ed with the mask's validity, so the wrapper is no longer
/// needed.
#[derive(Debug, Default)]
pub struct DecimalMaskedValidityRule;
33
34impl ArrayParentReduceRule<Decimal> for DecimalMaskedValidityRule {
35    type Parent = Masked;
36
37    fn reduce_parent(
38        &self,
39        array: ArrayView<'_, Decimal>,
40        parent: ArrayView<'_, Masked>,
41        _child_idx: usize,
42    ) -> VortexResult<Option<ArrayRef>> {
43        // Merge the parent's validity mask into the child's validity
44        // TODO(joe): make this lazy
45        let masked_array = match_each_decimal_value_type!(array.values_type(), |D| {
46            // SAFETY: Since we are only flipping some bits in the validity, all invariants that
47            // were upheld are still upheld.
48            unsafe {
49                DecimalArray::new_unchecked(
50                    array.buffer::<D>(),
51                    array.decimal_dtype(),
52                    array.validity()?.and(parent.validity()?)?,
53                )
54            }
55            .into_array()
56        });
57
58        Ok(Some(masked_array))
59    }
60}
61
62impl SliceReduce for Decimal {
63    fn slice(array: ArrayView<'_, Self>, range: Range<usize>) -> VortexResult<Option<ArrayRef>> {
64        let result = match_each_decimal_value_type!(array.values_type(), |D| {
65            let sliced = array.buffer::<D>().slice(range.clone());
66            let validity = array.validity()?.slice(range)?;
67            // SAFETY: Slicing preserves all DecimalArray invariants
68            unsafe { DecimalArray::new_unchecked(sliced, array.decimal_dtype(), validity) }
69                .into_array()
70        });
71        Ok(Some(result))
72    }
73}