vortex_array/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9
10use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
11use crate::vtable::VTable;
12use crate::{Array, ArrayRef};
13
14/// Attempt to cast an array to a desired DType.
15///
16/// Some array support the ability to narrow or upcast.
17pub fn cast(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
18    CAST_FN
19        .invoke(&InvocationArgs {
20            inputs: &[array.into(), dtype.into()],
21            options: &(),
22        })?
23        .unwrap_array()
24}
25
26pub static CAST_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
27    let compute = ComputeFn::new("cast".into(), ArcRef::new_ref(&Cast));
28    for kernel in inventory::iter::<CastKernelRef> {
29        compute.register_kernel(kernel.0.clone());
30    }
31    compute
32});
33
34struct Cast;
35
36impl ComputeFnVTable for Cast {
37    fn invoke(
38        &self,
39        args: &InvocationArgs,
40        kernels: &[ArcRef<dyn Kernel>],
41    ) -> VortexResult<Output> {
42        let CastArgs { array, dtype } = CastArgs::try_from(args)?;
43
44        if array.dtype() == dtype {
45            return Ok(array.to_array().into());
46        }
47
48        // TODO(ngates): check for null_count if dtype is non-nullable
49
50        for kernel in kernels {
51            if let Some(output) = kernel.invoke(args)? {
52                return Ok(output);
53            }
54        }
55        if let Some(output) = array.invoke(&CAST_FN, args)? {
56            return Ok(output);
57        }
58
59        // Otherwise, we fall back to the canonical implementations.
60        log::debug!(
61            "Falling back to canonical cast for encoding {} and dtype {} to {}",
62            array.encoding_id(),
63            array.dtype(),
64            dtype
65        );
66
67        if array.is_canonical() {
68            vortex_bail!(
69                "No compute kernel to cast array {} with dtype {} to {}",
70                array.encoding_id(),
71                array.dtype(),
72                dtype
73            );
74        }
75
76        Ok(cast(array.to_canonical()?.as_ref(), dtype)?.into())
77    }
78
79    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
80        let CastArgs { dtype, .. } = CastArgs::try_from(args)?;
81        Ok(dtype.clone())
82    }
83
84    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
85        let CastArgs { array, .. } = CastArgs::try_from(args)?;
86        Ok(array.len())
87    }
88
89    fn is_elementwise(&self) -> bool {
90        true
91    }
92}
93
94struct CastArgs<'a> {
95    array: &'a dyn Array,
96    dtype: &'a DType,
97}
98
99impl<'a> TryFrom<&InvocationArgs<'a>> for CastArgs<'a> {
100    type Error = VortexError;
101
102    fn try_from(args: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
103        if args.inputs.len() != 2 {
104            vortex_bail!(
105                "Cast function requires 2 arguments, but got {}",
106                args.inputs.len()
107            );
108        }
109        let array = args.inputs[0]
110            .array()
111            .ok_or_else(|| vortex_err!("Missing array argument"))?;
112        let dtype = args.inputs[1]
113            .dtype()
114            .ok_or_else(|| vortex_err!("Missing dtype argument"))?;
115
116        Ok(CastArgs { array, dtype })
117    }
118}
119
/// A type-erased cast kernel, collected through the `inventory` registry.
///
/// Every collected `CastKernelRef` is registered as a kernel of [`CAST_FN`] when that static is
/// first initialized.
pub struct CastKernelRef(ArcRef<dyn Kernel>);
// Declares the `inventory` collection that `CAST_FN` iterates over to discover cast kernels.
inventory::collect!(CastKernelRef);
122
/// A cast implementation specialized to arrays of the encoding described by this `VTable`.
///
/// To take part in [`CAST_FN`] dispatch, wrap the vtable in [`CastKernelAdapter`]. Then submit
/// the [`CastKernelRef`] produced by [`CastKernelAdapter::lift`] to `inventory`.
pub trait CastKernel: VTable {
    /// Casts `array` to `dtype`.
    ///
    /// Return `Ok(None)` when this kernel does not handle the requested cast. The dispatcher then
    /// tries the remaining kernels and, failing those, the canonical fallback.
    fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>>;
}
126
/// Adapts a [`CastKernel`] implementation (the wrapped vtable `V`) to the type-erased [`Kernel`]
/// trait used by compute-function dispatch.
#[derive(Debug)]
pub struct CastKernelAdapter<V: VTable>(pub V);
129
130impl<V: VTable + CastKernel> CastKernelAdapter<V> {
131    pub const fn lift(&'static self) -> CastKernelRef {
132        CastKernelRef(ArcRef::new_ref(self))
133    }
134}
135
136impl<V: VTable + CastKernel> Kernel for CastKernelAdapter<V> {
137    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
138        let CastArgs { array, dtype } = CastArgs::try_from(args)?;
139        let Some(array) = array.as_opt::<V>() else {
140            return Ok(None);
141        };
142
143        Ok(V::cast(&self.0, array, dtype)?.map(|o| o.into()))
144    }
145}