vortex_array/compute/
cast.rs

1use vortex_dtype::DType;
2use vortex_error::{VortexExpect, VortexResult, vortex_bail};
3
4use crate::encoding::Encoding;
5use crate::{Array, ArrayRef, IntoArray};
6
7pub trait CastFn<A> {
8    fn cast(&self, array: A, dtype: &DType) -> VortexResult<ArrayRef>;
9}
10
11impl<E: Encoding> CastFn<&dyn Array> for E
12where
13    E: for<'a> CastFn<&'a E::Array>,
14{
15    fn cast(&self, array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
16        let array_ref = array
17            .as_any()
18            .downcast_ref::<E::Array>()
19            .vortex_expect("Failed to downcast array");
20        CastFn::cast(self, array_ref, dtype)
21    }
22}
23
24/// Attempt to cast an array to a desired DType.
25///
26/// Some array support the ability to narrow or upcast.
27pub fn try_cast(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
28    if array.dtype() == dtype {
29        return Ok(array.to_array());
30    }
31
32    let casted = try_cast_impl(array, dtype)?;
33
34    debug_assert_eq!(
35        casted.len(),
36        array.len(),
37        "Cast length mismatch {}",
38        array.encoding()
39    );
40    debug_assert_eq!(
41        casted.dtype(),
42        dtype,
43        "Cast dtype mismatch {}",
44        array.encoding()
45    );
46
47    Ok(casted)
48}
49
50fn try_cast_impl(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
51    // TODO(ngates): check for null_count if dtype is non-nullable
52    if let Some(f) = array.vtable().cast_fn() {
53        return f.cast(array, dtype);
54    }
55
56    // Otherwise, we fall back to the canonical implementations.
57    log::debug!(
58        "Falling back to canonical cast for encoding {} and dtype {} to {}",
59        array.encoding(),
60        array.dtype(),
61        dtype
62    );
63    let canonicalized = array.to_canonical()?.into_array();
64    if let Some(f) = canonicalized.vtable().cast_fn() {
65        return f.cast(&canonicalized, dtype);
66    }
67
68    vortex_bail!(
69        "No compute kernel to cast array from {} to {}",
70        array.dtype(),
71        dtype
72    )
73}