vortex_array/scalar_fn/fns/zip/
kernel.rs1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::array::ArrayView;
10use crate::array::VTable;
11use crate::arrays::ScalarFnVTable;
12use crate::arrays::scalar_fn::ExactScalarFn;
13use crate::arrays::scalar_fn::ScalarFnArrayExt;
14use crate::arrays::scalar_fn::ScalarFnArrayView;
15use crate::kernel::ExecuteParentKernel;
16use crate::optimizer::rules::ArrayParentReduceRule;
17use crate::scalar_fn::fns::zip::Zip as ZipExpr;
18
19pub trait ZipReduce: VTable {
28 fn zip(
29 array: ArrayView<'_, Self>,
30 if_false: &ArrayRef,
31 mask: &ArrayRef,
32 ) -> VortexResult<Option<ArrayRef>>;
33}
34
35pub trait ZipKernel: VTable {
43 fn zip(
44 array: ArrayView<'_, Self>,
45 if_false: &ArrayRef,
46 mask: &ArrayRef,
47 ctx: &mut ExecutionCtx,
48 ) -> VortexResult<Option<ArrayRef>>;
49}
50
51#[derive(Default, Debug)]
53pub struct ZipReduceAdaptor<V>(pub V);
54
55impl<V> ArrayParentReduceRule<V> for ZipReduceAdaptor<V>
56where
57 V: ZipReduce,
58{
59 type Parent = ExactScalarFn<ZipExpr>;
60
61 fn reduce_parent(
62 &self,
63 array: ArrayView<'_, V>,
64 parent: ScalarFnArrayView<'_, ZipExpr>,
65 child_idx: usize,
66 ) -> VortexResult<Option<ArrayRef>> {
67 if child_idx != 0 {
68 return Ok(None);
69 }
70 let scalar_fn_array = parent
71 .as_opt::<ScalarFnVTable>()
72 .vortex_expect("ExactScalarFn matcher confirmed ScalarFnArray");
73 let if_false = scalar_fn_array.get_child(1);
74 let mask_array = scalar_fn_array.get_child(2);
75 <V as ZipReduce>::zip(array, if_false, mask_array)
76 }
77}
78
79#[derive(Default, Debug)]
81pub struct ZipExecuteAdaptor<V>(pub V);
82
83impl<V> ExecuteParentKernel<V> for ZipExecuteAdaptor<V>
84where
85 V: ZipKernel,
86{
87 type Parent = ExactScalarFn<ZipExpr>;
88
89 fn execute_parent(
90 &self,
91 array: ArrayView<'_, V>,
92 parent: ScalarFnArrayView<'_, ZipExpr>,
93 child_idx: usize,
94 ctx: &mut ExecutionCtx,
95 ) -> VortexResult<Option<ArrayRef>> {
96 if child_idx != 0 {
97 return Ok(None);
98 }
99 let scalar_fn_array = parent
100 .as_opt::<ScalarFnVTable>()
101 .vortex_expect("ExactScalarFn matcher confirmed ScalarFnArray");
102 let if_false = scalar_fn_array.get_child(1);
103 let mask_array = scalar_fn_array.get_child(2);
104 <V as ZipKernel>::zip(array, if_false, mask_array, ctx)
105 }
106}