// vortex_array/compute/cast.rs

1use std::sync::LazyLock;
2
3use vortex_dtype::DType;
4use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
5
6use crate::arcref::ArcRef;
7use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
8use crate::encoding::Encoding;
9use crate::{Array, ArrayRef};
10
11/// Attempt to cast an array to a desired DType.
12///
13/// Some array support the ability to narrow or upcast.
14pub fn cast(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
15    CAST_FN
16        .invoke(&InvocationArgs {
17            inputs: &[array.into(), dtype.into()],
18            options: &(),
19        })?
20        .unwrap_array()
21}
22
23pub static CAST_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
24    let compute = ComputeFn::new("cast".into(), ArcRef::new_ref(&Cast));
25    for kernel in inventory::iter::<CastKernelRef> {
26        compute.register_kernel(kernel.0.clone());
27    }
28    compute
29});
30
31struct Cast;
32
33impl ComputeFnVTable for Cast {
34    fn invoke(
35        &self,
36        args: &InvocationArgs,
37        kernels: &[ArcRef<dyn Kernel>],
38    ) -> VortexResult<Output> {
39        let CastArgs { array, dtype } = CastArgs::try_from(args)?;
40
41        if array.dtype() == dtype {
42            return Ok(array.to_array().into());
43        }
44
45        // TODO(ngates): check for null_count if dtype is non-nullable
46
47        for kernel in kernels {
48            if let Some(output) = kernel.invoke(args)? {
49                return Ok(output);
50            }
51        }
52        if let Some(output) = array.invoke(&CAST_FN, args)? {
53            return Ok(output);
54        }
55
56        // Otherwise, we fall back to the canonical implementations.
57        log::debug!(
58            "Falling back to canonical cast for encoding {} and dtype {} to {}",
59            array.encoding(),
60            array.dtype(),
61            dtype
62        );
63        if array.is_canonical() {
64            vortex_bail!(
65                "No compute kernel to cast array {} to {}",
66                array.encoding(),
67                dtype
68            );
69        }
70
71        Ok(cast(array.to_canonical()?.as_ref(), dtype)?.into())
72    }
73
74    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
75        let CastArgs { dtype, .. } = CastArgs::try_from(args)?;
76        Ok(dtype.clone())
77    }
78
79    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
80        let CastArgs { array, .. } = CastArgs::try_from(args)?;
81        Ok(array.len())
82    }
83
84    fn is_elementwise(&self) -> bool {
85        true
86    }
87}
88
89struct CastArgs<'a> {
90    array: &'a dyn Array,
91    dtype: &'a DType,
92}
93
94impl<'a> TryFrom<&InvocationArgs<'a>> for CastArgs<'a> {
95    type Error = VortexError;
96
97    fn try_from(args: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
98        if args.inputs.len() != 2 {
99            vortex_bail!(
100                "Cast function requires 2 arguments, but got {}",
101                args.inputs.len()
102            );
103        }
104        let array = args.inputs[0]
105            .array()
106            .ok_or_else(|| vortex_err!("Missing array argument"))?;
107        let dtype = args.inputs[1]
108            .dtype()
109            .ok_or_else(|| vortex_err!("Missing dtype argument"))?;
110
111        Ok(CastArgs { array, dtype })
112    }
113}
114
115pub struct CastKernelRef(ArcRef<dyn Kernel>);
116inventory::collect!(CastKernelRef);
117
118pub trait CastKernel: Encoding {
119    fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<ArrayRef>;
120}
121
122#[derive(Debug)]
123pub struct CastKernelAdapter<E: Encoding>(pub E);
124
125impl<E: Encoding + CastKernel> CastKernelAdapter<E> {
126    pub const fn lift(&'static self) -> CastKernelRef {
127        CastKernelRef(ArcRef::new_ref(self))
128    }
129}
130
131impl<E: Encoding + CastKernel> Kernel for CastKernelAdapter<E> {
132    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
133        let CastArgs { array, dtype } = CastArgs::try_from(args)?;
134        let Some(array) = array.as_any().downcast_ref::<E::Array>() else {
135            return Ok(None);
136        };
137        Ok(Some(E::cast(&self.0, array, dtype)?.into()))
138    }
139}