// vortex_array/expr/exprs/cast/kernel.rs

use vortex_dtype::DType;
use vortex_error::VortexResult;

use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::arrays::ExactScalarFn;
use crate::arrays::ScalarFnArrayView;
use crate::expr::Cast;
use crate::kernel::ExecuteParentKernel;
use crate::matcher::Matcher;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::vtable::VTable;
16
17pub trait CastReduce: VTable {
25 fn cast(array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>>;
26}
27
28pub trait CastKernel: VTable {
35 fn cast(
36 array: &Self::Array,
37 dtype: &DType,
38 ctx: &mut ExecutionCtx,
39 ) -> VortexResult<Option<ArrayRef>>;
40}
41
42#[derive(Default, Debug)]
44pub struct CastReduceAdaptor<V>(pub V);
45
46impl<V> ArrayParentReduceRule<V> for CastReduceAdaptor<V>
47where
48 V: CastReduce,
49{
50 type Parent = ExactScalarFn<Cast>;
51
52 fn reduce_parent(
53 &self,
54 array: &V::Array,
55 parent: ScalarFnArrayView<'_, Cast>,
56 _child_idx: usize,
57 ) -> VortexResult<Option<ArrayRef>> {
58 let dtype = parent.options;
59 if array.dtype() == dtype {
60 return Ok(Some(array.to_array()));
61 }
62 <V as CastReduce>::cast(array, dtype)
63 }
64}
65
66#[derive(Default, Debug)]
68pub struct CastExecuteAdaptor<V>(pub V);
69
70impl<V> ExecuteParentKernel<V> for CastExecuteAdaptor<V>
71where
72 V: CastKernel,
73{
74 type Parent = ExactScalarFn<Cast>;
75
76 fn execute_parent(
77 &self,
78 array: &V::Array,
79 parent: <Self::Parent as Matcher>::Match<'_>,
80 _child_idx: usize,
81 ctx: &mut ExecutionCtx,
82 ) -> VortexResult<Option<ArrayRef>> {
83 let dtype = parent.options;
84 if array.dtype() == dtype {
85 return Ok(Some(array.to_array()));
86 }
87 <V as CastKernel>::cast(array, dtype, ctx)
88 }
89}