vortex_array/compute/
cast.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::VortexError;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12
13use crate::Array;
14use crate::ArrayRef;
15use crate::compute::ComputeFn;
16use crate::compute::ComputeFnVTable;
17use crate::compute::InvocationArgs;
18use crate::compute::Kernel;
19use crate::compute::Output;
20use crate::vtable::VTable;
21
22static CAST_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
23 let compute = ComputeFn::new("cast".into(), ArcRef::new_ref(&Cast));
24 for kernel in inventory::iter::<CastKernelRef> {
25 compute.register_kernel(kernel.0.clone());
26 }
27 compute
28});
29
30pub(crate) fn warm_up_vtable() -> usize {
31 CAST_FN.kernels().len()
32}
33
34pub fn cast(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
38 CAST_FN
39 .invoke(&InvocationArgs {
40 inputs: &[array.into(), dtype.into()],
41 options: &(),
42 })?
43 .unwrap_array()
44}
45
46struct Cast;
47
48impl ComputeFnVTable for Cast {
49 fn invoke(
50 &self,
51 args: &InvocationArgs,
52 kernels: &[ArcRef<dyn Kernel>],
53 ) -> VortexResult<Output> {
54 let CastArgs { array, dtype } = CastArgs::try_from(args)?;
55
56 if array.dtype() == dtype {
57 return Ok(array.to_array().into());
58 }
59
60 for kernel in kernels {
63 if let Some(output) = kernel.invoke(args)? {
64 return Ok(output);
65 }
66 }
67 if let Some(output) = array.invoke(&CAST_FN, args)? {
68 return Ok(output);
69 }
70
71 log::debug!(
73 "Falling back to canonical cast for encoding {} and dtype {} to {}",
74 array.encoding_id(),
75 array.dtype(),
76 dtype
77 );
78
79 if array.is_canonical() {
80 vortex_bail!(
81 "No compute kernel to cast array {} with dtype {} to {}",
82 array.encoding_id(),
83 array.dtype(),
84 dtype
85 );
86 }
87
88 Ok(cast(array.to_canonical().as_ref(), dtype)?.into())
89 }
90
91 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
92 let CastArgs { dtype, .. } = CastArgs::try_from(args)?;
93 Ok(dtype.clone())
94 }
95
96 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
97 let CastArgs { array, .. } = CastArgs::try_from(args)?;
98 Ok(array.len())
99 }
100
101 fn is_elementwise(&self) -> bool {
102 true
103 }
104}
105
106struct CastArgs<'a> {
107 array: &'a dyn Array,
108 dtype: &'a DType,
109}
110
111impl<'a> TryFrom<&InvocationArgs<'a>> for CastArgs<'a> {
112 type Error = VortexError;
113
114 fn try_from(args: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
115 if args.inputs.len() != 2 {
116 vortex_bail!(
117 "Cast function requires 2 arguments, but got {}",
118 args.inputs.len()
119 );
120 }
121 let array = args.inputs[0]
122 .array()
123 .ok_or_else(|| vortex_err!("Missing array argument"))?;
124 let dtype = args.inputs[1]
125 .dtype()
126 .ok_or_else(|| vortex_err!("Missing dtype argument"))?;
127
128 Ok(CastArgs { array, dtype })
129 }
130}
131
/// Registration handle for a cast kernel.
///
/// Instances are gathered with `inventory`; each one's inner kernel is
/// registered on the `cast` compute function when that function is first
/// initialized.
pub struct CastKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(CastKernelRef);
134
/// Encoding-specific implementation of `cast` for the arrays of vtable `Self`.
///
/// Wrap an implementor in [`CastKernelAdapter`] to expose it as a generic
/// [`Kernel`].
pub trait CastKernel: VTable {
    /// Casts `array` to `dtype`.
    ///
    /// Returning `Ok(None)` declines the cast, and dispatch then moves on to the
    /// next candidate implementation.
    fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>>;
}
138
/// Adapts a [`CastKernel`] implementation for vtable `V` into a [`Kernel`].
#[derive(Debug)]
pub struct CastKernelAdapter<V: VTable>(pub V);

impl<V: VTable + CastKernel> CastKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`CastKernelRef`] without allocating, so it
    /// can be submitted to `inventory`.
    pub const fn lift(&'static self) -> CastKernelRef {
        CastKernelRef(ArcRef::new_ref(self))
    }
}
147
148impl<V: VTable + CastKernel> Kernel for CastKernelAdapter<V> {
149 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
150 let CastArgs { array, dtype } = CastArgs::try_from(args)?;
151 let Some(array) = array.as_opt::<V>() else {
152 return Ok(None);
153 };
154
155 Ok(V::cast(&self.0, array, dtype)?.map(|o| o.into()))
156 }
157}