vortex_array/scalar_fn/fns/cast/
kernel.rs1use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::ExecutionCtx;
8use crate::IntoArray;
9use crate::arrays::scalar_fn::ExactScalarFn;
10use crate::arrays::scalar_fn::ScalarFnArrayView;
11use crate::dtype::DType;
12use crate::kernel::ExecuteParentKernel;
13use crate::matcher::Matcher;
14use crate::optimizer::rules::ArrayParentReduceRule;
15use crate::scalar_fn::fns::cast::Cast;
16use crate::vtable::VTable;
17
18pub trait CastReduce: VTable {
26 fn cast(array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>>;
27}
28
29pub trait CastKernel: VTable {
36 fn cast(
37 array: &Self::Array,
38 dtype: &DType,
39 ctx: &mut ExecutionCtx,
40 ) -> VortexResult<Option<ArrayRef>>;
41}
42
/// Wraps a [`CastReduce`] implementation `V` so it can act as an [`ArrayParentReduceRule`]
/// whose parent matcher is an exact [`Cast`] scalar-function array.
#[derive(Debug, Default)]
pub struct CastReduceAdaptor<V>(pub V);
46
47impl<V> ArrayParentReduceRule<V> for CastReduceAdaptor<V>
48where
49 V: CastReduce,
50{
51 type Parent = ExactScalarFn<Cast>;
52
53 fn reduce_parent(
54 &self,
55 array: &V::Array,
56 parent: ScalarFnArrayView<'_, Cast>,
57 _child_idx: usize,
58 ) -> VortexResult<Option<ArrayRef>> {
59 let dtype = parent.options;
60 if array.dtype() == dtype {
61 return Ok(Some(array.clone().into_array()));
62 }
63 <V as CastReduce>::cast(array, dtype)
64 }
65}
66
/// Wraps a [`CastKernel`] implementation `V` so it can act as an [`ExecuteParentKernel`]
/// whose parent matcher is an exact [`Cast`] scalar-function array.
#[derive(Debug, Default)]
pub struct CastExecuteAdaptor<V>(pub V);
70
71impl<V> ExecuteParentKernel<V> for CastExecuteAdaptor<V>
72where
73 V: CastKernel,
74{
75 type Parent = ExactScalarFn<Cast>;
76
77 fn execute_parent(
78 &self,
79 array: &V::Array,
80 parent: <Self::Parent as Matcher>::Match<'_>,
81 _child_idx: usize,
82 ctx: &mut ExecutionCtx,
83 ) -> VortexResult<Option<ArrayRef>> {
84 let dtype = parent.options;
85 if array.dtype() == dtype {
86 return Ok(Some(array.clone().into_array()));
87 }
88 <V as CastKernel>::cast(array, dtype, ctx)
89 }
90}