vortex_array/arrays/extension/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5
6use crate::arrays::{ExtensionArray, ExtensionVTable};
7use crate::compute::{CastKernel, CastKernelAdapter, cast};
8use crate::{ArrayRef, IntoArray, register_kernel};
9
10impl CastKernel for ExtensionVTable {
11    fn cast(
12        &self,
13        array: &ExtensionArray,
14        dtype: &DType,
15    ) -> vortex_error::VortexResult<Option<ArrayRef>> {
16        if !array.dtype().eq_ignore_nullability(dtype) {
17            return Ok(None);
18        }
19
20        let DType::Extension(ext_dtype) = dtype else {
21            unreachable!("Already verified we have an extension dtype");
22        };
23
24        let new_storage = match cast(array.storage(), ext_dtype.storage_dtype()) {
25            Ok(arr) => arr,
26            Err(e) => {
27                log::warn!("Failed to cast storage array: {e}");
28                return Ok(None);
29            }
30        };
31
32        Ok(Some(
33            ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
34        ))
35    }
36}
37
// Register the extension-array cast kernel (wrapped in `CastKernelAdapter`) via
// `register_kernel!`, presumably so that `cast` can dispatch `ExtensionArray` inputs to it.
register_kernel!(CastKernelAdapter(ExtensionVTable).lift());
39
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::datetime::{TIMESTAMP_ID, TemporalMetadata, TimeUnit};
    use vortex_dtype::{ExtDType, Nullability, PType};

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::cast::test_cast_conformance;

    /// Build a timestamp extension dtype with the given unit and no timezone.
    ///
    /// Storage is `PType::I64` converted into a `DType`; arrays built from it are
    /// non-nullable (asserted in `cast_same_ext_dtype_different_nullability`).
    fn timestamp_dtype(time_unit: TimeUnit) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(PType::I64.into()),
            Some(TemporalMetadata::Timestamp(time_unit, None).into()),
        ))
    }

    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = timestamp_dtype(TimeUnit::Milliseconds);
        let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = cast(arr.as_ref(), &DType::Extension(ext_dtype.clone())).unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = timestamp_dtype(TimeUnit::Milliseconds);
        let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = cast(arr.as_ref(), &new_dtype).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype = timestamp_dtype(TimeUnit::Milliseconds);
        // Same extension id and storage, but a different time unit (ns instead of ms): the
        // dtypes differ beyond nullability, so the cast must fail.
        let target_dtype = timestamp_dtype(TimeUnit::Nanoseconds);

        let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(cast(arr.as_ref(), &DType::Extension(target_dtype)).is_err());
    }

    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Build a 5-element UTC timestamp extension array in `time_unit`.
    ///
    /// When `nullable` is true, the `i64` storage is nullable and contains two nulls;
    /// otherwise it is non-nullable and fully populated.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let nullability = if nullable {
            Nullability::Nullable
        } else {
            Nullability::NonNullable
        };
        let ext_dtype = Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(DType::Primitive(PType::I64, nullability)),
            Some(TemporalMetadata::Timestamp(time_unit, Some("UTC".to_string())).into()),
        ));

        // Raw i64 ticks; what they represent depends on `time_unit`.
        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}
146}