vortex_array/compute/
cast.rs1use vortex_dtype::DType;
2use vortex_error::{VortexExpect, VortexResult, vortex_bail};
3
4use crate::encoding::Encoding;
5use crate::{Array, ArrayRef, IntoArray};
6
7pub trait CastFn<A> {
8 fn cast(&self, array: A, dtype: &DType) -> VortexResult<ArrayRef>;
9}
10
11impl<E: Encoding> CastFn<&dyn Array> for E
12where
13 E: for<'a> CastFn<&'a E::Array>,
14{
15 fn cast(&self, array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
16 let array_ref = array
17 .as_any()
18 .downcast_ref::<E::Array>()
19 .vortex_expect("Failed to downcast array");
20 CastFn::cast(self, array_ref, dtype)
21 }
22}
23
24pub fn try_cast(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
28 if array.dtype() == dtype {
29 return Ok(array.to_array());
30 }
31
32 let casted = try_cast_impl(array, dtype)?;
33
34 debug_assert_eq!(
35 casted.len(),
36 array.len(),
37 "Cast length mismatch {}",
38 array.encoding()
39 );
40 debug_assert_eq!(
41 casted.dtype(),
42 dtype,
43 "Cast dtype mismatch {}",
44 array.encoding()
45 );
46
47 Ok(casted)
48}
49
50fn try_cast_impl(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
51 if let Some(f) = array.vtable().cast_fn() {
53 return f.cast(array, dtype);
54 }
55
56 log::debug!(
58 "Falling back to canonical cast for encoding {} and dtype {} to {}",
59 array.encoding(),
60 array.dtype(),
61 dtype
62 );
63 let canonicalized = array.to_canonical()?.into_array();
64 if let Some(f) = canonicalized.vtable().cast_fn() {
65 return f.cast(&canonicalized, dtype);
66 }
67
68 vortex_bail!(
69 "No compute kernel to cast array from {} to {}",
70 array.dtype(),
71 dtype
72 )
73}