Skip to main content

vortex_array/arrays/decimal/compute/
rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::Decimal;
12use crate::arrays::DecimalArray;
13use crate::arrays::Masked;
14use crate::arrays::slice::SliceReduce;
15use crate::arrays::slice::SliceReduceAdaptor;
16use crate::match_each_decimal_value_type;
17use crate::optimizer::rules::ArrayParentReduceRule;
18use crate::optimizer::rules::ParentRuleSet;
19use crate::scalar_fn::fns::cast::CastReduceAdaptor;
20use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
21
22pub(crate) static RULES: ParentRuleSet<Decimal> = ParentRuleSet::new(&[
23    ParentRuleSet::lift(&DecimalMaskedValidityRule),
24    ParentRuleSet::lift(&CastReduceAdaptor(Decimal)),
25    ParentRuleSet::lift(&MaskReduceAdaptor(Decimal)),
26    ParentRuleSet::lift(&SliceReduceAdaptor(Decimal)),
27]);
28
/// Parent-reduce rule that folds a `Masked` parent into its `Decimal` child.
///
/// Rather than keeping the masking wrapper, the rule builds a new [`DecimalArray`] whose
/// validity is the child's validity combined (logical AND) with the parent's validity.
#[derive(Debug, Default)]
pub struct DecimalMaskedValidityRule;
35
36impl ArrayParentReduceRule<Decimal> for DecimalMaskedValidityRule {
37    type Parent = Masked;
38
39    fn reduce_parent(
40        &self,
41        array: ArrayView<'_, Decimal>,
42        parent: ArrayView<'_, Masked>,
43        _child_idx: usize,
44    ) -> VortexResult<Option<ArrayRef>> {
45        // Merge the parent's validity mask into the child's validity
46        // TODO(joe): make this lazy
47        let masked_array = match_each_decimal_value_type!(array.values_type(), |D| {
48            // SAFETY: Since we are only flipping some bits in the validity, all invariants that
49            // were upheld are still upheld.
50            unsafe {
51                DecimalArray::new_unchecked(
52                    array.buffer::<D>(),
53                    array.decimal_dtype(),
54                    array.validity()?.and(parent.validity()?)?,
55                )
56            }
57            .into_array()
58        });
59
60        Ok(Some(masked_array))
61    }
62}
63
64impl SliceReduce for Decimal {
65    fn slice(array: ArrayView<'_, Self>, range: Range<usize>) -> VortexResult<Option<ArrayRef>> {
66        let result = match_each_decimal_value_type!(array.values_type(), |D| {
67            let sliced = array.buffer::<D>().slice(range.clone());
68            let validity = array.validity()?.slice(range)?;
69            // SAFETY: Slicing preserves all DecimalArray invariants
70            unsafe { DecimalArray::new_unchecked(sliced, array.decimal_dtype(), validity) }
71                .into_array()
72        });
73        Ok(Some(result))
74    }
75}