Skip to main content

vortex_array/expr/exprs/mask/
kernel.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_err;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::arrays::BoolVTable;
10use crate::arrays::ExactScalarFn;
11use crate::arrays::ScalarFnArrayView;
12use crate::expr::Mask as MaskExpr;
13use crate::kernel::ExecuteParentKernel;
14use crate::optimizer::rules::ArrayParentReduceRule;
15use crate::vtable::VTable;
16
17/// Mask an array without reading buffers.
18///
19/// This trait is for mask implementations that can operate purely on array metadata and
20/// structure without needing to read or execute on the underlying buffers. Implementations
21/// should return `None` if masking requires buffer access.
22///
23/// The `mask` parameter is a boolean array where true=keep/valid, false=null-out.
24///
25/// # Preconditions
26///
27/// The mask is guaranteed to have the same length as the array. Trivial cases
28/// (`AllValid`, `AllInvalid`, `NonNullable`) are handled by the caller before dispatch.
29pub trait MaskReduce: VTable {
30    fn mask(array: &Self::Array, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
31}
32
33/// Mask an array, potentially reading buffers.
34///
35/// Unlike [`MaskReduce`], this trait is for mask implementations that may need to read
36/// and execute on the underlying buffers to produce the masked result.
37///
38/// The `mask` parameter is a boolean array where true=keep/valid, false=null-out.
39///
40/// # Preconditions
41///
42/// The mask is guaranteed to have the same length as the array. Trivial cases
43/// (`AllValid`, `AllInvalid`, `NonNullable`) are handled by the caller before dispatch.
44pub trait MaskKernel: VTable {
45    fn mask(
46        array: &Self::Array,
47        mask: &ArrayRef,
48        ctx: &mut ExecutionCtx,
49    ) -> VortexResult<Option<ArrayRef>>;
50}
51
/// Bridges a [`MaskReduce`] implementation into the optimizer as an
/// [`ArrayParentReduceRule`] whose parent is an exact `Mask` scalar-function array.
#[derive(Debug, Default)]
pub struct MaskReduceAdaptor<V>(pub V);
55
56impl<V> ArrayParentReduceRule<V> for MaskReduceAdaptor<V>
57where
58    V: MaskReduce,
59{
60    type Parent = ExactScalarFn<MaskExpr>;
61
62    fn reduce_parent(
63        &self,
64        array: &V::Array,
65        parent: ScalarFnArrayView<'_, MaskExpr>,
66        child_idx: usize,
67    ) -> VortexResult<Option<ArrayRef>> {
68        // Only reduce the input child (index 0), not the mask child (index 1).
69        if child_idx != 0 {
70            return Ok(None);
71        }
72        // The mask child (child 1) is a non-nullable BoolArray where true=keep.
73        // If it's not yet a BoolArray, we can't reduce without execution.
74        let mask_child = parent
75            .nth_child(1)
76            .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
77        if mask_child.as_opt::<BoolVTable>().is_none() {
78            return Ok(None);
79        };
80        <V as MaskReduce>::mask(array, &mask_child)
81    }
82}
83
/// Bridges a [`MaskKernel`] implementation into execution as an [`ExecuteParentKernel`]
/// whose parent is an exact `Mask` scalar-function array.
#[derive(Debug, Default)]
pub struct MaskExecuteAdaptor<V>(pub V);
87
88impl<V> ExecuteParentKernel<V> for MaskExecuteAdaptor<V>
89where
90    V: MaskKernel,
91{
92    type Parent = ExactScalarFn<MaskExpr>;
93
94    fn execute_parent(
95        &self,
96        array: &V::Array,
97        parent: ScalarFnArrayView<'_, MaskExpr>,
98        child_idx: usize,
99        ctx: &mut ExecutionCtx,
100    ) -> VortexResult<Option<ArrayRef>> {
101        // Only execute the input child (index 0), not the mask child (index 1).
102        if child_idx != 0 {
103            return Ok(None);
104        }
105        let mask_child = parent
106            .nth_child(1)
107            .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
108        <V as MaskKernel>::mask(array, &mask_child, ctx)
109    }
110}