vortex_array/expr/exprs/mask/kernel.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_err;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::arrays::BoolVTable;
10use crate::arrays::ExactScalarFn;
11use crate::arrays::ScalarFnArrayView;
12use crate::expr::Mask as MaskExpr;
13use crate::kernel::ExecuteParentKernel;
14use crate::optimizer::rules::ArrayParentReduceRule;
15use crate::vtable::VTable;
16
17/// Mask an array without reading buffers.
18///
19/// This trait is for mask implementations that can operate purely on array metadata and
20/// structure without needing to read or execute on the underlying buffers. Implementations
21/// should return `None` if masking requires buffer access.
22///
23/// The `mask` parameter is a boolean array where true=keep/valid, false=null-out.
24///
25/// # Preconditions
26///
27/// The mask is guaranteed to have the same length as the array. Trivial cases
28/// (`AllValid`, `AllInvalid`, `NonNullable`) are handled by the caller before dispatch.
29pub trait MaskReduce: VTable {
30 fn mask(array: &Self::Array, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
31}
32
33/// Mask an array, potentially reading buffers.
34///
35/// Unlike [`MaskReduce`], this trait is for mask implementations that may need to read
36/// and execute on the underlying buffers to produce the masked result.
37///
38/// The `mask` parameter is a boolean array where true=keep/valid, false=null-out.
39///
40/// # Preconditions
41///
42/// The mask is guaranteed to have the same length as the array. Trivial cases
43/// (`AllValid`, `AllInvalid`, `NonNullable`) are handled by the caller before dispatch.
44pub trait MaskKernel: VTable {
45 fn mask(
46 array: &Self::Array,
47 mask: &ArrayRef,
48 ctx: &mut ExecutionCtx,
49 ) -> VortexResult<Option<ArrayRef>>;
50}
51
52/// Adaptor that wraps a [`MaskReduce`] impl as an [`ArrayParentReduceRule`].
53#[derive(Default, Debug)]
54pub struct MaskReduceAdaptor<V>(pub V);
55
56impl<V> ArrayParentReduceRule<V> for MaskReduceAdaptor<V>
57where
58 V: MaskReduce,
59{
60 type Parent = ExactScalarFn<MaskExpr>;
61
62 fn reduce_parent(
63 &self,
64 array: &V::Array,
65 parent: ScalarFnArrayView<'_, MaskExpr>,
66 child_idx: usize,
67 ) -> VortexResult<Option<ArrayRef>> {
68 // Only reduce the input child (index 0), not the mask child (index 1).
69 if child_idx != 0 {
70 return Ok(None);
71 }
72 // The mask child (child 1) is a non-nullable BoolArray where true=keep.
73 // If it's not yet a BoolArray, we can't reduce without execution.
74 let mask_child = parent
75 .nth_child(1)
76 .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
77 if mask_child.as_opt::<BoolVTable>().is_none() {
78 return Ok(None);
79 };
80 <V as MaskReduce>::mask(array, &mask_child)
81 }
82}
83
84/// Adaptor that wraps a [`MaskKernel`] impl as an [`ExecuteParentKernel`].
85#[derive(Default, Debug)]
86pub struct MaskExecuteAdaptor<V>(pub V);
87
88impl<V> ExecuteParentKernel<V> for MaskExecuteAdaptor<V>
89where
90 V: MaskKernel,
91{
92 type Parent = ExactScalarFn<MaskExpr>;
93
94 fn execute_parent(
95 &self,
96 array: &V::Array,
97 parent: ScalarFnArrayView<'_, MaskExpr>,
98 child_idx: usize,
99 ctx: &mut ExecutionCtx,
100 ) -> VortexResult<Option<ArrayRef>> {
101 // Only execute the input child (index 0), not the mask child (index 1).
102 if child_idx != 0 {
103 return Ok(None);
104 }
105 let mask_child = parent
106 .nth_child(1)
107 .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
108 <V as MaskKernel>::mask(array, &mask_child, ctx)
109 }
110}