vortex_array/scalar_fn/fns/mask/
kernel.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_err;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::arrays::BoolVTable;
10use crate::arrays::ExactScalarFn;
11use crate::arrays::ScalarFnArrayView;
12use crate::kernel::ExecuteParentKernel;
13use crate::optimizer::rules::ArrayParentReduceRule;
14use crate::scalar_fn::fns::mask::Mask as MaskExpr;
15use crate::vtable::VTable;
16
17pub trait MaskReduce: VTable {
30 fn mask(array: &Self::Array, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
31}
32
33pub trait MaskKernel: VTable {
45 fn mask(
46 array: &Self::Array,
47 mask: &ArrayRef,
48 ctx: &mut ExecutionCtx,
49 ) -> VortexResult<Option<ArrayRef>>;
50}
51
/// Newtype wrapper that exposes a [`MaskReduce`] implementation as an
/// [`ArrayParentReduceRule`] whose parent is the mask scalar function.
#[derive(Debug, Default)]
pub struct MaskReduceAdaptor<V>(pub V);
55
56impl<V> ArrayParentReduceRule<V> for MaskReduceAdaptor<V>
57where
58 V: MaskReduce,
59{
60 type Parent = ExactScalarFn<MaskExpr>;
61
62 fn reduce_parent(
63 &self,
64 array: &V::Array,
65 parent: ScalarFnArrayView<'_, MaskExpr>,
66 child_idx: usize,
67 ) -> VortexResult<Option<ArrayRef>> {
68 if child_idx != 0 {
70 return Ok(None);
71 }
72 let mask_child = parent
75 .nth_child(1)
76 .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
77 if mask_child.as_opt::<BoolVTable>().is_none() {
78 return Ok(None);
79 };
80 <V as MaskReduce>::mask(array, &mask_child)
81 }
82}
83
/// Newtype wrapper that exposes a [`MaskKernel`] implementation as an
/// [`ExecuteParentKernel`] whose parent is the mask scalar function.
#[derive(Debug, Default)]
pub struct MaskExecuteAdaptor<V>(pub V);
87
88impl<V> ExecuteParentKernel<V> for MaskExecuteAdaptor<V>
89where
90 V: MaskKernel,
91{
92 type Parent = ExactScalarFn<MaskExpr>;
93
94 fn execute_parent(
95 &self,
96 array: &V::Array,
97 parent: ScalarFnArrayView<'_, MaskExpr>,
98 child_idx: usize,
99 ctx: &mut ExecutionCtx,
100 ) -> VortexResult<Option<ArrayRef>> {
101 if child_idx != 0 {
103 return Ok(None);
104 }
105 let mask_child = parent
106 .nth_child(1)
107 .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
108 <V as MaskKernel>::mask(array, &mask_child, ctx)
109 }
110}