vortex_array/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9
10use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
11use crate::vtable::VTable;
12use crate::{Array, ArrayRef};
13
14static CAST_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
15    let compute = ComputeFn::new("cast".into(), ArcRef::new_ref(&Cast));
16    for kernel in inventory::iter::<CastKernelRef> {
17        compute.register_kernel(kernel.0.clone());
18    }
19    compute
20});
21
22pub(crate) fn warm_up_vtable() -> usize {
23    CAST_FN.kernels().len()
24}
25
26/// Attempt to cast an array to a desired DType.
27///
28/// Some array support the ability to narrow or upcast.
29pub fn cast(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
30    CAST_FN
31        .invoke(&InvocationArgs {
32            inputs: &[array.into(), dtype.into()],
33            options: &(),
34        })?
35        .unwrap_array()
36}
37
/// [`ComputeFnVTable`] implementation backing the [`cast`] compute function.
struct Cast;
39
40impl ComputeFnVTable for Cast {
41    fn invoke(
42        &self,
43        args: &InvocationArgs,
44        kernels: &[ArcRef<dyn Kernel>],
45    ) -> VortexResult<Output> {
46        let CastArgs { array, dtype } = CastArgs::try_from(args)?;
47
48        if array.dtype() == dtype {
49            return Ok(array.to_array().into());
50        }
51
52        // TODO(ngates): check for null_count if dtype is non-nullable
53
54        for kernel in kernels {
55            if let Some(output) = kernel.invoke(args)? {
56                return Ok(output);
57            }
58        }
59        if let Some(output) = array.invoke(&CAST_FN, args)? {
60            return Ok(output);
61        }
62
63        // Otherwise, we fall back to the canonical implementations.
64        log::debug!(
65            "Falling back to canonical cast for encoding {} and dtype {} to {}",
66            array.encoding_id(),
67            array.dtype(),
68            dtype
69        );
70
71        if array.is_canonical() {
72            vortex_bail!(
73                "No compute kernel to cast array {} with dtype {} to {}",
74                array.encoding_id(),
75                array.dtype(),
76                dtype
77            );
78        }
79
80        Ok(cast(array.to_canonical().as_ref(), dtype)?.into())
81    }
82
83    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
84        let CastArgs { dtype, .. } = CastArgs::try_from(args)?;
85        Ok(dtype.clone())
86    }
87
88    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
89        let CastArgs { array, .. } = CastArgs::try_from(args)?;
90        Ok(array.len())
91    }
92
93    fn is_elementwise(&self) -> bool {
94        true
95    }
96}
97
/// Typed view of the arguments passed to the `cast` compute function.
struct CastArgs<'a> {
    /// The array to cast (input 0).
    array: &'a dyn Array,
    /// The target dtype (input 1).
    dtype: &'a DType,
}
102
103impl<'a> TryFrom<&InvocationArgs<'a>> for CastArgs<'a> {
104    type Error = VortexError;
105
106    fn try_from(args: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
107        if args.inputs.len() != 2 {
108            vortex_bail!(
109                "Cast function requires 2 arguments, but got {}",
110                args.inputs.len()
111            );
112        }
113        let array = args.inputs[0]
114            .array()
115            .ok_or_else(|| vortex_err!("Missing array argument"))?;
116        let dtype = args.inputs[1]
117            .dtype()
118            .ok_or_else(|| vortex_err!("Missing dtype argument"))?;
119
120        Ok(CastArgs { array, dtype })
121    }
122}
123
/// A type-erased cast kernel, collected via `inventory`.
///
/// Every collected instance is registered with the [`cast`] compute function when
/// that function is first initialized.
pub struct CastKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(CastKernelRef);
126
/// Encoding-specific implementation of [`cast`].
///
/// Returning `Ok(None)` signals that this kernel does not handle the requested cast.
/// [`cast`] then moves on to other implementations and, for non-canonical arrays,
/// eventually to the canonical fallback.
pub trait CastKernel: VTable {
    /// Cast `array` to `dtype`, or return `Ok(None)` if this kernel cannot do so.
    fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>>;
}
130
/// Adapts a [`CastKernel`] for encoding `V` into a type-erased [`Kernel`].
///
/// The adapter only handles arrays whose encoding is `V`. It yields `None` for all
/// other arrays.
#[derive(Debug)]
pub struct CastKernelAdapter<V: VTable>(pub V);
133
134impl<V: VTable + CastKernel> CastKernelAdapter<V> {
135    pub const fn lift(&'static self) -> CastKernelRef {
136        CastKernelRef(ArcRef::new_ref(self))
137    }
138}
139
140impl<V: VTable + CastKernel> Kernel for CastKernelAdapter<V> {
141    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
142        let CastArgs { array, dtype } = CastArgs::try_from(args)?;
143        let Some(array) = array.as_opt::<V>() else {
144            return Ok(None);
145        };
146
147        Ok(V::cast(&self.0, array, dtype)?.map(|o| o.into()))
148    }
149}