vortex_array/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::VortexError;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12
13use crate::Array;
14use crate::ArrayRef;
15use crate::compute::ComputeFn;
16use crate::compute::ComputeFnVTable;
17use crate::compute::InvocationArgs;
18use crate::compute::Kernel;
19use crate::compute::Output;
20use crate::vtable::VTable;
21
22static CAST_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
23    let compute = ComputeFn::new("cast".into(), ArcRef::new_ref(&Cast));
24    for kernel in inventory::iter::<CastKernelRef> {
25        compute.register_kernel(kernel.0.clone());
26    }
27    compute
28});
29
30pub(crate) fn warm_up_vtable() -> usize {
31    CAST_FN.kernels().len()
32}
33
34/// Attempt to cast an array to a desired DType.
35///
36/// Some array support the ability to narrow or upcast.
37pub fn cast(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
38    CAST_FN
39        .invoke(&InvocationArgs {
40            inputs: &[array.into(), dtype.into()],
41            options: &(),
42        })?
43        .unwrap_array()
44}
45
46struct Cast;
47
48impl ComputeFnVTable for Cast {
49    fn invoke(
50        &self,
51        args: &InvocationArgs,
52        kernels: &[ArcRef<dyn Kernel>],
53    ) -> VortexResult<Output> {
54        let CastArgs { array, dtype } = CastArgs::try_from(args)?;
55
56        if array.dtype() == dtype {
57            return Ok(array.to_array().into());
58        }
59
60        // TODO(ngates): check for null_count if dtype is non-nullable
61
62        for kernel in kernels {
63            if let Some(output) = kernel.invoke(args)? {
64                return Ok(output);
65            }
66        }
67        if let Some(output) = array.invoke(&CAST_FN, args)? {
68            return Ok(output);
69        }
70
71        // Otherwise, we fall back to the canonical implementations.
72        log::debug!(
73            "Falling back to canonical cast for encoding {} and dtype {} to {}",
74            array.encoding_id(),
75            array.dtype(),
76            dtype
77        );
78
79        if array.is_canonical() {
80            vortex_bail!(
81                "No compute kernel to cast array {} with dtype {} to {}",
82                array.encoding_id(),
83                array.dtype(),
84                dtype
85            );
86        }
87
88        Ok(cast(array.to_canonical().as_ref(), dtype)?.into())
89    }
90
91    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
92        let CastArgs { dtype, .. } = CastArgs::try_from(args)?;
93        Ok(dtype.clone())
94    }
95
96    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
97        let CastArgs { array, .. } = CastArgs::try_from(args)?;
98        Ok(array.len())
99    }
100
101    fn is_elementwise(&self) -> bool {
102        true
103    }
104}
105
106struct CastArgs<'a> {
107    array: &'a dyn Array,
108    dtype: &'a DType,
109}
110
111impl<'a> TryFrom<&InvocationArgs<'a>> for CastArgs<'a> {
112    type Error = VortexError;
113
114    fn try_from(args: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
115        if args.inputs.len() != 2 {
116            vortex_bail!(
117                "Cast function requires 2 arguments, but got {}",
118                args.inputs.len()
119            );
120        }
121        let array = args.inputs[0]
122            .array()
123            .ok_or_else(|| vortex_err!("Missing array argument"))?;
124        let dtype = args.inputs[1]
125            .dtype()
126            .ok_or_else(|| vortex_err!("Missing dtype argument"))?;
127
128        Ok(CastArgs { array, dtype })
129    }
130}
131
/// A type-erased cast [`Kernel`], collected through `inventory`.
///
/// Every collected instance is registered on the global cast compute function when that
/// function is first initialized. Instances are typically produced by
/// [`CastKernelAdapter::lift`].
pub struct CastKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(CastKernelRef);
134
/// Encoding-specific cast implementation, provided by a [`VTable`] for its array type.
///
/// To make it available to [`cast`], wrap the implementor in [`CastKernelAdapter`] and submit
/// the [`CastKernelRef`] returned by [`CastKernelAdapter::lift`] to `inventory` (presumably via
/// `inventory::submit!` — confirm against existing registrations).
pub trait CastKernel: VTable {
    /// Cast `array` to `dtype`.
    ///
    /// Return `Ok(None)` when this encoding cannot perform the requested cast. Dispatch then
    /// moves on to the next candidate and ultimately to the canonical fallback.
    fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>>;
}
138
/// Adapts a [`CastKernel`] implementation for vtable `V` into a type-erased [`Kernel`].
#[derive(Debug)]
pub struct CastKernelAdapter<V: VTable>(pub V);

impl<V: VTable + CastKernel> CastKernelAdapter<V> {
    /// Wraps this `'static` adapter in a [`CastKernelRef`].
    ///
    /// This is a `const fn`, so it can be evaluated in constant contexts such as a static
    /// registration.
    pub const fn lift(&'static self) -> CastKernelRef {
        CastKernelRef(ArcRef::new_ref(self))
    }
}
147
148impl<V: VTable + CastKernel> Kernel for CastKernelAdapter<V> {
149    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
150        let CastArgs { array, dtype } = CastArgs::try_from(args)?;
151        let Some(array) = array.as_opt::<V>() else {
152            return Ok(None);
153        };
154
155        Ok(V::cast(&self.0, array, dtype)?.map(|o| o.into()))
156    }
157}