Skip to main content

vortex_array/scalar_fn/fns/mask/
kernel.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_err;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::array::ArrayView;
10use crate::array::VTable;
11use crate::arrays::Bool;
12use crate::arrays::scalar_fn::ExactScalarFn;
13use crate::arrays::scalar_fn::ScalarFnArrayView;
14use crate::kernel::ExecuteParentKernel;
15use crate::optimizer::rules::ArrayParentReduceRule;
16use crate::scalar_fn::fns::mask::Mask as MaskExpr;
17
18/// Mask an array without reading buffers.
19///
20/// This trait is for mask implementations that can operate purely on array metadata and
21/// structure without needing to read or execute on the underlying buffers. Implementations
22/// should return `None` if masking requires buffer access.
23///
24/// The `mask` parameter is a boolean array where true=keep/valid, false=null-out.
25///
26/// # Preconditions
27///
28/// The mask is guaranteed to have the same length as the array. Trivial cases
29/// (`AllValid`, `AllInvalid`, `NonNullable`) are handled by the caller before dispatch.
30pub trait MaskReduce: VTable {
31    fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
32}
33
34/// Mask an array, potentially reading buffers.
35///
36/// Unlike [`MaskReduce`], this trait is for mask implementations that may need to read
37/// and execute on the underlying buffers to produce the masked result.
38///
39/// The `mask` parameter is a boolean array where true=keep/valid, false=null-out.
40///
41/// # Preconditions
42///
43/// The mask is guaranteed to have the same length as the array. Trivial cases
44/// (`AllValid`, `AllInvalid`, `NonNullable`) are handled by the caller before dispatch.
45pub trait MaskKernel: VTable {
46    fn mask(
47        array: ArrayView<'_, Self>,
48        mask: &ArrayRef,
49        ctx: &mut ExecutionCtx,
50    ) -> VortexResult<Option<ArrayRef>>;
51}
52
/// Exposes a [`MaskReduce`] implementation to the optimizer as an
/// [`ArrayParentReduceRule`] whose parent is the [`MaskExpr`] scalar function.
#[derive(Debug, Default)]
pub struct MaskReduceAdaptor<V>(pub V);
56
57impl<V> ArrayParentReduceRule<V> for MaskReduceAdaptor<V>
58where
59    V: MaskReduce,
60{
61    type Parent = ExactScalarFn<MaskExpr>;
62
63    fn reduce_parent(
64        &self,
65        array: ArrayView<'_, V>,
66        parent: ScalarFnArrayView<'_, MaskExpr>,
67        child_idx: usize,
68    ) -> VortexResult<Option<ArrayRef>> {
69        // Only reduce the input child (index 0), not the mask child (index 1).
70        if child_idx != 0 {
71            return Ok(None);
72        }
73        // The mask child (child 1) is a non-nullable BoolArray where true=keep.
74        // If it's not yet a BoolArray, we can't reduce without execution.
75        let parent_ref: ArrayRef = (*parent).clone();
76        let mask_child = parent_ref
77            .nth_child(1)
78            .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
79        if mask_child.as_opt::<Bool>().is_none() {
80            return Ok(None);
81        };
82        <V as MaskReduce>::mask(array, &mask_child)
83    }
84}
85
/// Exposes a [`MaskKernel`] implementation to the executor as an
/// [`ExecuteParentKernel`] whose parent is the [`MaskExpr`] scalar function.
#[derive(Debug, Default)]
pub struct MaskExecuteAdaptor<V>(pub V);
89
90impl<V> ExecuteParentKernel<V> for MaskExecuteAdaptor<V>
91where
92    V: MaskKernel,
93{
94    type Parent = ExactScalarFn<MaskExpr>;
95
96    fn execute_parent(
97        &self,
98        array: ArrayView<'_, V>,
99        parent: ScalarFnArrayView<'_, MaskExpr>,
100        child_idx: usize,
101        ctx: &mut ExecutionCtx,
102    ) -> VortexResult<Option<ArrayRef>> {
103        // Only execute the input child (index 0), not the mask child (index 1).
104        if child_idx != 0 {
105            return Ok(None);
106        }
107        let mask_child = parent
108            .nth_child(1)
109            .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
110        <V as MaskKernel>::mask(array, &mask_child, ctx)
111    }
112}