Skip to main content

vortex_array/arrays/extension/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::arrays::ExtensionArray;
9use crate::arrays::ExtensionVTable;
10use crate::builtins::ArrayBuiltins;
11use crate::compute::CastReduce;
12
13impl CastReduce for ExtensionVTable {
14    fn cast(array: &ExtensionArray, dtype: &DType) -> vortex_error::VortexResult<Option<ArrayRef>> {
15        if !array.dtype().eq_ignore_nullability(dtype) {
16            return Ok(None);
17        }
18
19        let DType::Extension(ext_dtype) = dtype else {
20            unreachable!("Already verified we have an extension dtype");
21        };
22
23        let new_storage = match array
24            .storage()
25            .cast(ext_dtype.storage_dtype().clone())
26            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
27        {
28            Ok(arr) => arr,
29            Err(e) => {
30                tracing::warn!("Failed to cast storage array: {e}");
31                return Ok(None);
32            }
33        };
34
35        Ok(Some(
36            ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
37        ))
38    }
39}
40
#[cfg(test)]
mod tests {

    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::datetime::TimeUnit;
    use vortex_dtype::datetime::Timestamp;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;

    /// Casting to the exact same extension dtype keeps both the length and the dtype.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let target = DType::Extension(ext_dtype.clone());

        let arr = ExtensionArray::new(ext_dtype, Buffer::<i64>::empty().into_array());

        let output = arr.to_array().cast(target.clone()).unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &target);
    }

    /// Casting a non-nullable extension array to the nullable version of the same extension
    /// dtype succeeds and yields exactly the requested (nullable) dtype.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        // Go through the accessor, like every other assertion in this module.
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = arr.to_array().cast(new_dtype.clone()).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting (followed by canonicalization) to a timestamp extension dtype with a different
    /// time unit is expected to produce an error.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype =
            Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        // Note NS here instead of MS
        let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        let result = arr
            .to_array()
            .cast(DType::Extension(target_dtype))
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        assert!(result.is_err());
    }

    /// Runs `test_cast_conformance` against timestamp extension arrays built with each time
    /// unit, alternating between non-nullable and nullable storage.
    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a five-element UTC timestamp extension array using `time_unit`.
    ///
    /// When `nullable` is true the storage holds nulls at positions 1 and 4; otherwise all five
    /// values are present. The raw `i64` values are interpreted in `time_unit`, so the instant
    /// they denote differs between callers (e.g. `1_000_000` is one second only for
    /// microseconds).
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let ext_dtype =
            Timestamp::new_with_tz(time_unit, Some("UTC".into()), nullable.into()).erased();

        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}