vortex_array/arrays/decimal/vtable/operator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_compute::filter::Filter;
5use vortex_dtype::{PrecisionScale, match_each_decimal_value_type};
6use vortex_error::VortexResult;
7use vortex_vector::decimal::DVector;
8
9use crate::arrays::{DecimalArray, DecimalVTable, MaskedVTable};
10use crate::execution::{BatchKernelRef, BindCtx, kernel};
11use crate::vtable::{OperatorVTable, ValidityHelper};
12use crate::{ArrayRef, IntoArray};
13
14impl OperatorVTable<DecimalVTable> for DecimalVTable {
15    fn bind(
16        array: &DecimalArray,
17        selection: Option<&ArrayRef>,
18        ctx: &mut dyn BindCtx,
19    ) -> VortexResult<BatchKernelRef> {
20        let mask = ctx.bind_selection(array.len(), selection)?;
21        let validity = ctx.bind_validity(array.validity(), array.len(), selection)?;
22
23        match_each_decimal_value_type!(array.values_type(), |D| {
24            let elements = array.buffer::<D>();
25            let ps = PrecisionScale::<D>::try_from(&array.decimal_dtype())?;
26
27            Ok(kernel(move || {
28                let mask = mask.execute()?;
29                let validity = validity.execute()?;
30
31                // Note that validity already has the mask applied so we only need to apply it to
32                // the elements.
33                let elements = elements.filter(&mask);
34
35                Ok(DVector::<D>::try_new(ps, elements, validity)?.into())
36            }))
37        })
38    }
39
40    fn reduce_parent(
41        array: &DecimalArray,
42        parent: &ArrayRef,
43        _child_idx: usize,
44    ) -> VortexResult<Option<ArrayRef>> {
45        // Push-down masking of `validity` from the parent `MaskedArray`.
46        if let Some(masked) = parent.as_opt::<MaskedVTable>() {
47            let masked_array = match_each_decimal_value_type!(array.values_type(), |D| {
48                // SAFETY: Since we are only flipping some bits in the validity, all invariants that
49                // were upheld are still upheld.
50                unsafe {
51                    DecimalArray::new_unchecked(
52                        array.buffer::<D>(),
53                        array.decimal_dtype(),
54                        array.validity().clone().and(masked.validity().clone()),
55                    )
56                }
57                .into_array()
58            });
59
60            return Ok(Some(masked_array));
61        }
62
63        Ok(None)
64    }
65}