Skip to main content

vortex_array/arrays/decimal/compute/
rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_dtype::match_each_decimal_value_type;
7use vortex_error::VortexResult;
8
9use crate::ArrayRef;
10use crate::IntoArray;
11use crate::arrays::DecimalArray;
12use crate::arrays::DecimalVTable;
13use crate::arrays::MaskedArray;
14use crate::arrays::MaskedVTable;
15use crate::arrays::SliceReduce;
16use crate::arrays::SliceReduceAdaptor;
17use crate::expr::MaskReduceAdaptor;
18use crate::optimizer::rules::ArrayParentReduceRule;
19use crate::optimizer::rules::ParentRuleSet;
20use crate::vtable::ValidityHelper;
21
/// Parent-reduce rules registered for [`DecimalVTable`], i.e. rules that may rewrite the
/// parent of a [`DecimalArray`] child.
///
/// - [`DecimalMaskedValidityRule`]: folds a [`MaskedArray`] parent into the decimal child by
///   merging the parent's validity into the child's (defined later in this file).
/// - `MaskReduceAdaptor(DecimalVTable)`: presumably adapts a mask-reduce implementation for
///   `DecimalVTable` that lives elsewhere in the crate — TODO confirm.
/// - `SliceReduceAdaptor(DecimalVTable)`: adapts slice parents using the [`SliceReduce`]
///   implementation for `DecimalVTable` (presumably the one defined in this file — confirm).
///
/// NOTE(review): whether the order of this list decides which rule fires first is not visible
/// here; check `ParentRuleSet` before reordering entries.
pub(crate) static RULES: ParentRuleSet<DecimalVTable> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&DecimalMaskedValidityRule),
    ParentRuleSet::lift(&MaskReduceAdaptor(DecimalVTable)),
    ParentRuleSet::lift(&SliceReduceAdaptor(DecimalVTable)),
]);
27
/// Parent-reduce rule that collapses a [`MaskedArray`] wrapper into its [`DecimalArray`] child.
///
/// The rewrite yields a single [`DecimalArray`] whose validity is the child's own validity
/// `and`-ed with the validity held by the `MaskedArray` parent, so the wrapper is dropped.
#[derive(Debug, Default)]
pub struct DecimalMaskedValidityRule;
34
35impl ArrayParentReduceRule<DecimalVTable> for DecimalMaskedValidityRule {
36    type Parent = MaskedVTable;
37
38    fn reduce_parent(
39        &self,
40        array: &DecimalArray,
41        parent: &MaskedArray,
42        _child_idx: usize,
43    ) -> VortexResult<Option<ArrayRef>> {
44        // Merge the parent's validity mask into the child's validity
45        // TODO(joe): make this lazy
46        let masked_array = match_each_decimal_value_type!(array.values_type(), |D| {
47            // SAFETY: Since we are only flipping some bits in the validity, all invariants that
48            // were upheld are still upheld.
49            unsafe {
50                DecimalArray::new_unchecked(
51                    array.buffer::<D>(),
52                    array.decimal_dtype(),
53                    array.validity().clone().and(parent.validity().clone())?,
54                )
55            }
56            .into_array()
57        });
58
59        Ok(Some(masked_array))
60    }
61}
62
63impl SliceReduce for DecimalVTable {
64    fn slice(array: &Self::Array, range: Range<usize>) -> VortexResult<Option<ArrayRef>> {
65        let result = match_each_decimal_value_type!(array.values_type(), |D| {
66            let sliced = array.buffer::<D>().slice(range.clone());
67            let validity = array.validity().clone().slice(range)?;
68            // SAFETY: Slicing preserves all DecimalArray invariants
69            unsafe { DecimalArray::new_unchecked(sliced, array.decimal_dtype(), validity) }
70                .into_array()
71        });
72        Ok(Some(result))
73    }
74}