vortex_array/compute/
cast.rs1use std::sync::LazyLock;
2
3use vortex_dtype::DType;
4use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
5
6use crate::arcref::ArcRef;
7use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
8use crate::encoding::Encoding;
9use crate::{Array, ArrayRef};
10
11pub fn cast(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
15 CAST_FN
16 .invoke(&InvocationArgs {
17 inputs: &[array.into(), dtype.into()],
18 options: &(),
19 })?
20 .unwrap_array()
21}
22
23pub static CAST_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
24 let compute = ComputeFn::new("cast".into(), ArcRef::new_ref(&Cast));
25 for kernel in inventory::iter::<CastKernelRef> {
26 compute.register_kernel(kernel.0.clone());
27 }
28 compute
29});
30
31struct Cast;
32
33impl ComputeFnVTable for Cast {
34 fn invoke(
35 &self,
36 args: &InvocationArgs,
37 kernels: &[ArcRef<dyn Kernel>],
38 ) -> VortexResult<Output> {
39 let CastArgs { array, dtype } = CastArgs::try_from(args)?;
40
41 if array.dtype() == dtype {
42 return Ok(array.to_array().into());
43 }
44
45 for kernel in kernels {
48 if let Some(output) = kernel.invoke(args)? {
49 return Ok(output);
50 }
51 }
52 if let Some(output) = array.invoke(&CAST_FN, args)? {
53 return Ok(output);
54 }
55
56 log::debug!(
58 "Falling back to canonical cast for encoding {} and dtype {} to {}",
59 array.encoding(),
60 array.dtype(),
61 dtype
62 );
63 if array.is_canonical() {
64 vortex_bail!(
65 "No compute kernel to cast array {} to {}",
66 array.encoding(),
67 dtype
68 );
69 }
70
71 Ok(cast(array.to_canonical()?.as_ref(), dtype)?.into())
72 }
73
74 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
75 let CastArgs { dtype, .. } = CastArgs::try_from(args)?;
76 Ok(dtype.clone())
77 }
78
79 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
80 let CastArgs { array, .. } = CastArgs::try_from(args)?;
81 Ok(array.len())
82 }
83
84 fn is_elementwise(&self) -> bool {
85 true
86 }
87}
88
89struct CastArgs<'a> {
90 array: &'a dyn Array,
91 dtype: &'a DType,
92}
93
94impl<'a> TryFrom<&InvocationArgs<'a>> for CastArgs<'a> {
95 type Error = VortexError;
96
97 fn try_from(args: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
98 if args.inputs.len() != 2 {
99 vortex_bail!(
100 "Cast function requires 2 arguments, but got {}",
101 args.inputs.len()
102 );
103 }
104 let array = args.inputs[0]
105 .array()
106 .ok_or_else(|| vortex_err!("Missing array argument"))?;
107 let dtype = args.inputs[1]
108 .dtype()
109 .ok_or_else(|| vortex_err!("Missing dtype argument"))?;
110
111 Ok(CastArgs { array, dtype })
112 }
113}
114
115pub struct CastKernelRef(ArcRef<dyn Kernel>);
116inventory::collect!(CastKernelRef);
117
118pub trait CastKernel: Encoding {
119 fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<ArrayRef>;
120}
121
122#[derive(Debug)]
123pub struct CastKernelAdapter<E: Encoding>(pub E);
124
125impl<E: Encoding + CastKernel> CastKernelAdapter<E> {
126 pub const fn lift(&'static self) -> CastKernelRef {
127 CastKernelRef(ArcRef::new_ref(self))
128 }
129}
130
131impl<E: Encoding + CastKernel> Kernel for CastKernelAdapter<E> {
132 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
133 let CastArgs { array, dtype } = CastArgs::try_from(args)?;
134 let Some(array) = array.as_any().downcast_ref::<E::Array>() else {
135 return Ok(None);
136 };
137 Ok(Some(E::cast(&self.0, array, dtype)?.into()))
138 }
139}