vortex_array/compute/
cast.rs1use std::sync::LazyLock;
2
3use arcref::ArcRef;
4use vortex_dtype::Nullability::Nullable;
5use vortex_dtype::{DType, PType};
6use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
7
8use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
9use crate::vtable::VTable;
10use crate::{Array, ArrayRef};
11
12pub fn cast(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
16 CAST_FN
17 .invoke(&InvocationArgs {
18 inputs: &[array.into(), dtype.into()],
19 options: &(),
20 })?
21 .unwrap_array()
22}
23
24pub static CAST_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
25 let compute = ComputeFn::new("cast".into(), ArcRef::new_ref(&Cast));
26 for kernel in inventory::iter::<CastKernelRef> {
27 compute.register_kernel(kernel.0.clone());
28 }
29 compute
30});
31
32struct Cast;
33
34impl ComputeFnVTable for Cast {
35 fn invoke(
36 &self,
37 args: &InvocationArgs,
38 kernels: &[ArcRef<dyn Kernel>],
39 ) -> VortexResult<Output> {
40 let CastArgs { array, dtype } = CastArgs::try_from(args)?;
41
42 if array.dtype() == dtype {
43 return Ok(array.to_array().into());
44 }
45
46 for kernel in kernels {
49 if let Some(output) = kernel.invoke(args)? {
50 return Ok(output);
51 }
52 }
53 if let Some(output) = array.invoke(&CAST_FN, args)? {
54 return Ok(output);
55 }
56
57 log::debug!(
59 "Falling back to canonical cast for encoding {} and dtype {} to {}",
60 array.encoding_id(),
61 array.dtype(),
62 dtype
63 );
64 if array.is_canonical() {
65 vortex_bail!(
66 "No compute kernel to cast array {} to {}",
67 array.encoding_id(),
68 dtype
69 );
70 }
71
72 Ok(cast(array.to_canonical()?.as_ref(), dtype)?.into())
73 }
74
75 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
76 let CastArgs { dtype, .. } = CastArgs::try_from(args)?;
77 Ok(dtype.clone())
78 }
79
80 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
81 let CastArgs { array, .. } = CastArgs::try_from(args)?;
82 Ok(array.len())
83 }
84
85 fn is_elementwise(&self) -> bool {
86 true
87 }
88}
89
90struct CastArgs<'a> {
91 array: &'a dyn Array,
92 dtype: &'a DType,
93}
94
95impl<'a> TryFrom<&InvocationArgs<'a>> for CastArgs<'a> {
96 type Error = VortexError;
97
98 fn try_from(args: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
99 if args.inputs.len() != 2 {
100 vortex_bail!(
101 "Cast function requires 2 arguments, but got {}",
102 args.inputs.len()
103 );
104 }
105 let array = args.inputs[0]
106 .array()
107 .ok_or_else(|| vortex_err!("Missing array argument"))?;
108 let dtype = args.inputs[1]
109 .dtype()
110 .ok_or_else(|| vortex_err!("Missing dtype argument"))?;
111
112 Ok(CastArgs { array, dtype })
113 }
114}
115
116pub struct CastKernelRef(ArcRef<dyn Kernel>);
117inventory::collect!(CastKernelRef);
118
119pub trait CastKernel: VTable {
120 fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<ArrayRef>;
121}
122
123#[derive(Debug)]
124pub struct CastKernelAdapter<V: VTable>(pub V);
125
126impl<V: VTable + CastKernel> CastKernelAdapter<V> {
127 pub const fn lift(&'static self) -> CastKernelRef {
128 CastKernelRef(ArcRef::new_ref(self))
129 }
130}
131
132impl<V: VTable + CastKernel> Kernel for CastKernelAdapter<V> {
133 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
134 let CastArgs { array, dtype } = CastArgs::try_from(args)?;
135 let Some(array) = array.as_opt::<V>() else {
136 return Ok(None);
137 };
138 Ok(Some(V::cast(&self.0, array, dtype)?.into()))
139 }
140}
141
/// Classification of a cast between two types, as reported by
/// [`allowed_casting`] and [`allowed_casting_ptype`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CastOutcome {
    /// The cast may fail (reported, for example, for integer narrowing).
    Fallible,
    /// The cast cannot fail (reported, for example, for integer widening).
    Infallible,
}
147
148pub fn allowed_casting(from: &DType, to: &DType) -> Option<CastOutcome> {
149 if &from.with_nullability(Nullable) == to {
151 return Some(CastOutcome::Infallible);
152 }
153 match (from, to) {
154 (DType::Primitive(from_ptype, _), DType::Primitive(to_ptype, _)) => {
155 allowed_casting_ptype(*from_ptype, *to_ptype)
156 }
157 _ => None,
158 }
159}
160
161pub fn allowed_casting_ptype(from: PType, to: PType) -> Option<CastOutcome> {
162 use CastOutcome::*;
163 use PType::*;
164
165 match (from, to) {
166 (a, b) if a == b => Some(Infallible),
168
169 (U8, U16 | U32 | U64)
171 | (U16, U32 | U64)
172 | (U32, U64)
173 | (I8, I16 | I32 | I64)
174 | (I16, I32 | I64)
175 | (I32, I64) => Some(Infallible),
176
177 (U16 | U32 | U64, U8)
179 | (U32 | U64, U16)
180 | (U64, U32)
181 | (I16 | I32 | I64, I8)
182 | (I32 | I64, I16)
183 | (I64, I32) => Some(Fallible),
184
185 (I8 | I16 | I32 | I64, U8 | U16 | U32 | U64)
187 | (U8 | U16 | U32 | U64, I8 | I16 | I32 | I64) => Some(Fallible),
188
189 (F16, F32 | F64) | (F32, F64) => Some(Infallible),
198
199 (F64, F32 | F16) | (F32, F16) => Some(Fallible),
201
202 _ => None,
203 }
204}