vortex_array/compute/
cast.rs

1use std::sync::LazyLock;
2
3use arcref::ArcRef;
4use vortex_dtype::Nullability::Nullable;
5use vortex_dtype::{DType, PType};
6use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
7
8use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
9use crate::vtable::VTable;
10use crate::{Array, ArrayRef};
11
12/// Attempt to cast an array to a desired DType.
13///
14/// Some array support the ability to narrow or upcast.
15pub fn cast(array: &dyn Array, dtype: &DType) -> VortexResult<ArrayRef> {
16    CAST_FN
17        .invoke(&InvocationArgs {
18            inputs: &[array.into(), dtype.into()],
19            options: &(),
20        })?
21        .unwrap_array()
22}
23
24pub static CAST_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
25    let compute = ComputeFn::new("cast".into(), ArcRef::new_ref(&Cast));
26    for kernel in inventory::iter::<CastKernelRef> {
27        compute.register_kernel(kernel.0.clone());
28    }
29    compute
30});
31
32struct Cast;
33
34impl ComputeFnVTable for Cast {
35    fn invoke(
36        &self,
37        args: &InvocationArgs,
38        kernels: &[ArcRef<dyn Kernel>],
39    ) -> VortexResult<Output> {
40        let CastArgs { array, dtype } = CastArgs::try_from(args)?;
41
42        if array.dtype() == dtype {
43            return Ok(array.to_array().into());
44        }
45
46        // TODO(ngates): check for null_count if dtype is non-nullable
47
48        for kernel in kernels {
49            if let Some(output) = kernel.invoke(args)? {
50                return Ok(output);
51            }
52        }
53        if let Some(output) = array.invoke(&CAST_FN, args)? {
54            return Ok(output);
55        }
56
57        // Otherwise, we fall back to the canonical implementations.
58        log::debug!(
59            "Falling back to canonical cast for encoding {} and dtype {} to {}",
60            array.encoding_id(),
61            array.dtype(),
62            dtype
63        );
64        if array.is_canonical() {
65            vortex_bail!(
66                "No compute kernel to cast array {} to {}",
67                array.encoding_id(),
68                dtype
69            );
70        }
71
72        Ok(cast(array.to_canonical()?.as_ref(), dtype)?.into())
73    }
74
75    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
76        let CastArgs { dtype, .. } = CastArgs::try_from(args)?;
77        Ok(dtype.clone())
78    }
79
80    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
81        let CastArgs { array, .. } = CastArgs::try_from(args)?;
82        Ok(array.len())
83    }
84
85    fn is_elementwise(&self) -> bool {
86        true
87    }
88}
89
90struct CastArgs<'a> {
91    array: &'a dyn Array,
92    dtype: &'a DType,
93}
94
95impl<'a> TryFrom<&InvocationArgs<'a>> for CastArgs<'a> {
96    type Error = VortexError;
97
98    fn try_from(args: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
99        if args.inputs.len() != 2 {
100            vortex_bail!(
101                "Cast function requires 2 arguments, but got {}",
102                args.inputs.len()
103            );
104        }
105        let array = args.inputs[0]
106            .array()
107            .ok_or_else(|| vortex_err!("Missing array argument"))?;
108        let dtype = args.inputs[1]
109            .dtype()
110            .ok_or_else(|| vortex_err!("Missing dtype argument"))?;
111
112        Ok(CastArgs { array, dtype })
113    }
114}
115
116pub struct CastKernelRef(ArcRef<dyn Kernel>);
117inventory::collect!(CastKernelRef);
118
119pub trait CastKernel: VTable {
120    fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<ArrayRef>;
121}
122
123#[derive(Debug)]
124pub struct CastKernelAdapter<V: VTable>(pub V);
125
126impl<V: VTable + CastKernel> CastKernelAdapter<V> {
127    pub const fn lift(&'static self) -> CastKernelRef {
128        CastKernelRef(ArcRef::new_ref(self))
129    }
130}
131
132impl<V: VTable + CastKernel> Kernel for CastKernelAdapter<V> {
133    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
134        let CastArgs { array, dtype } = CastArgs::try_from(args)?;
135        let Some(array) = array.as_opt::<V>() else {
136            return Ok(None);
137        };
138        Ok(Some(V::cast(&self.0, array, dtype)?.into()))
139    }
140}
141
142#[derive(Debug, Clone, Copy, PartialEq, Eq)]
143pub enum CastOutcome {
144    Fallible,
145    Infallible,
146}
147
148pub fn allowed_casting(from: &DType, to: &DType) -> Option<CastOutcome> {
149    // Can cast to include nullability
150    if &from.with_nullability(Nullable) == to {
151        return Some(CastOutcome::Infallible);
152    }
153    match (from, to) {
154        (DType::Primitive(from_ptype, _), DType::Primitive(to_ptype, _)) => {
155            allowed_casting_ptype(*from_ptype, *to_ptype)
156        }
157        _ => None,
158    }
159}
160
161pub fn allowed_casting_ptype(from: PType, to: PType) -> Option<CastOutcome> {
162    use CastOutcome::*;
163    use PType::*;
164
165    match (from, to) {
166        // Identity casts
167        (a, b) if a == b => Some(Infallible),
168
169        // Integer widening (always infallible)
170        (U8, U16 | U32 | U64)
171        | (U16, U32 | U64)
172        | (U32, U64)
173        | (I8, I16 | I32 | I64)
174        | (I16, I32 | I64)
175        | (I32, I64) => Some(Infallible),
176
177        // Integer narrowing (may truncate)
178        (U16 | U32 | U64, U8)
179        | (U32 | U64, U16)
180        | (U64, U32)
181        | (I16 | I32 | I64, I8)
182        | (I32 | I64, I16)
183        | (I64, I32) => Some(Fallible),
184
185        // Between signed and unsigned (fallible if negative or too big)
186        (I8 | I16 | I32 | I64, U8 | U16 | U32 | U64)
187        | (U8 | U16 | U32 | U64, I8 | I16 | I32 | I64) => Some(Fallible),
188
189        // TODO(joe): shall we allow float/int casting?
190        // Integer -> Float
191        // (U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64, F16 | F32 | F64) => Some(Fallible),
192
193        // Float -> Integer (truncates, overflows possible)
194        // (F16 | F32 | F64, U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64) => Some(Fallible),
195
196        // Float widening (safe)
197        (F16, F32 | F64) | (F32, F64) => Some(Infallible),
198
199        // Float narrowing (lossy)
200        (F64, F32 | F16) | (F32, F16) => Some(Fallible),
201
202        _ => None,
203    }
204}