vortex_array/arrays/extension/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5
6use crate::arrays::{ExtensionArray, ExtensionVTable};
7use crate::compute::{CastKernel, CastKernelAdapter, cast};
8use crate::{ArrayRef, IntoArray, register_kernel};
9
10impl CastKernel for ExtensionVTable {
11    fn cast(
12        &self,
13        array: &ExtensionArray,
14        dtype: &DType,
15    ) -> vortex_error::VortexResult<Option<ArrayRef>> {
16        if !array.dtype().eq_ignore_nullability(dtype) {
17            return Ok(None);
18        }
19
20        let DType::Extension(ext_dtype) = dtype else {
21            unreachable!("Already verified we have an extension dtype");
22        };
23
24        let new_storage = match cast(array.storage(), ext_dtype.storage_dtype()) {
25            Ok(arr) => arr,
26            Err(e) => {
27                log::warn!("Failed to cast storage array: {e}");
28                return Ok(None);
29            }
30        };
31
32        Ok(Some(
33            ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
34        ))
35    }
36}
37
38register_kernel!(CastKernelAdapter(ExtensionVTable).lift());
39
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::datetime::{TIMESTAMP_ID, TemporalMetadata, TimeUnit};
    use vortex_dtype::{ExtDType, Nullability, PType};

    use super::*;
    use crate::arrays::PrimitiveArray;

    /// Builds the `TIMESTAMP_ID` extension dtype with the given `unit`, backed by `i64`
    /// storage.
    fn timestamp_dtype(unit: TimeUnit) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(PType::I64.into()),
            Some(TemporalMetadata::Timestamp(unit, None).into()),
        ))
    }

    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = timestamp_dtype(TimeUnit::Ms);
        let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = cast(arr.as_ref(), &DType::Extension(ext_dtype.clone())).unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = timestamp_dtype(TimeUnit::Ms);
        let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = cast(arr.as_ref(), &new_dtype).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    #[test]
    fn cast_nullability_round_trip_preserves_len() {
        let ext_dtype = timestamp_dtype(TimeUnit::Ms);
        let storage = PrimitiveArray::from_iter([1i64, 2, 3]).into_array();
        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        // Widen to nullable...
        let nullable =
            DType::Extension(ext_dtype.clone()).with_nullability(Nullability::Nullable);
        let widened = cast(arr.as_ref(), &nullable).unwrap();
        assert_eq!(widened.len(), 3);
        assert_eq!(widened.dtype(), &nullable);

        // ...and narrow back: there are no nulls, so this must succeed.
        let non_nullable = DType::Extension(ext_dtype);
        let narrowed = cast(widened.as_ref(), &non_nullable).unwrap();
        assert_eq!(narrowed.len(), 3);
        assert_eq!(narrowed.dtype(), &non_nullable);
    }

    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype = timestamp_dtype(TimeUnit::Ms);
        // Note NS here instead of MS
        let target_dtype = timestamp_dtype(TimeUnit::Ns);

        let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(cast(arr.as_ref(), &DType::Extension(target_dtype)).is_err());
    }
}