Skip to main content

vortex_array/arrays/extension/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use crate::ArrayRef;
5use crate::IntoArray;
6use crate::arrays::ExtensionArray;
7use crate::arrays::ExtensionVTable;
8use crate::builtins::ArrayBuiltins;
9use crate::dtype::DType;
10use crate::scalar_fn::fns::cast::CastReduce;
11
12impl CastReduce for ExtensionVTable {
13    fn cast(array: &ExtensionArray, dtype: &DType) -> vortex_error::VortexResult<Option<ArrayRef>> {
14        if !array.dtype().eq_ignore_nullability(dtype) {
15            return Ok(None);
16        }
17
18        let DType::Extension(ext_dtype) = dtype else {
19            unreachable!("Already verified we have an extension dtype");
20        };
21
22        let new_storage = match array
23            .storage()
24            .cast(ext_dtype.storage_dtype().clone())
25            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
26        {
27            Ok(arr) => arr,
28            Err(e) => {
29                tracing::warn!("Failed to cast storage array: {e}");
30                return Ok(None);
31            }
32        };
33
34        Ok(Some(
35            ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
36        ))
37    }
38}
39
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::Nullability;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;

    /// Casting to the array's own extension dtype keeps length and dtype unchanged.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = arr
            .to_array()
            .cast(DType::Extension(ext_dtype.clone()))
            .unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting a non-nullable extension array to the nullable variant of the same extension
    /// dtype succeeds and yields exactly the requested dtype.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = arr.to_array().cast(new_dtype.clone()).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting between different extension dtypes (here: timestamp ms -> ns) is an error.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype =
            Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        // Note NS here instead of MS
        let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(
            arr.to_array()
                .cast(DType::Extension(target_dtype))
                .and_then(|a| a.to_canonical().map(|c| c.into_array()))
                .is_err()
        );
    }

    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a five-element UTC timestamp extension array with `i64` storage.
    ///
    /// When `nullable` is true the storage contains two nulls (positions 1 and 4).
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let ext_dtype =
            Timestamp::new_with_tz(time_unit, Some("UTC".into()), nullable.into()).erased();

        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64), // raw storage value; its meaning depends on `time_unit`
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}