vortex_array/scalar_fn/fns/mask/
kernel.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_err;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::array::ArrayView;
10use crate::array::VTable;
11use crate::arrays::Bool;
12use crate::arrays::scalar_fn::ExactScalarFn;
13use crate::arrays::scalar_fn::ScalarFnArrayView;
14use crate::kernel::ExecuteParentKernel;
15use crate::optimizer::rules::ArrayParentReduceRule;
16use crate::scalar_fn::fns::mask::Mask as MaskExpr;
17
18pub trait MaskReduce: VTable {
31 fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
32}
33
34pub trait MaskKernel: VTable {
46 fn mask(
47 array: ArrayView<'_, Self>,
48 mask: &ArrayRef,
49 ctx: &mut ExecutionCtx,
50 ) -> VortexResult<Option<ArrayRef>>;
51}
52
53#[derive(Default, Debug)]
55pub struct MaskReduceAdaptor<V>(pub V);
56
57impl<V> ArrayParentReduceRule<V> for MaskReduceAdaptor<V>
58where
59 V: MaskReduce,
60{
61 type Parent = ExactScalarFn<MaskExpr>;
62
63 fn reduce_parent(
64 &self,
65 array: ArrayView<'_, V>,
66 parent: ScalarFnArrayView<'_, MaskExpr>,
67 child_idx: usize,
68 ) -> VortexResult<Option<ArrayRef>> {
69 if child_idx != 0 {
71 return Ok(None);
72 }
73 let parent_ref: ArrayRef = (*parent).clone();
76 let mask_child = parent_ref
77 .nth_child(1)
78 .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
79 if mask_child.as_opt::<Bool>().is_none() {
80 return Ok(None);
81 };
82 <V as MaskReduce>::mask(array, &mask_child)
83 }
84}
85
86#[derive(Default, Debug)]
88pub struct MaskExecuteAdaptor<V>(pub V);
89
90impl<V> ExecuteParentKernel<V> for MaskExecuteAdaptor<V>
91where
92 V: MaskKernel,
93{
94 type Parent = ExactScalarFn<MaskExpr>;
95
96 fn execute_parent(
97 &self,
98 array: ArrayView<'_, V>,
99 parent: ScalarFnArrayView<'_, MaskExpr>,
100 child_idx: usize,
101 ctx: &mut ExecutionCtx,
102 ) -> VortexResult<Option<ArrayRef>> {
103 if child_idx != 0 {
105 return Ok(None);
106 }
107 let mask_child = parent
108 .nth_child(1)
109 .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
110 <V as MaskKernel>::mask(array, &mask_child, ctx)
111 }
112}