Skip to main content

vortex_array/arrays/extension/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use crate::ArrayRef;
5use crate::IntoArray;
6use crate::arrays::ExtensionArray;
7use crate::arrays::ExtensionVTable;
8use crate::builtins::ArrayBuiltins;
9use crate::dtype::DType;
10use crate::scalar_fn::fns::cast::CastReduce;
11
12impl CastReduce for ExtensionVTable {
13    fn cast(array: &ExtensionArray, dtype: &DType) -> vortex_error::VortexResult<Option<ArrayRef>> {
14        if !array.dtype().eq_ignore_nullability(dtype) {
15            return Ok(None);
16        }
17
18        let DType::Extension(ext_dtype) = dtype else {
19            unreachable!("Already verified we have an extension dtype");
20        };
21
22        let new_storage = match array.storage().cast(ext_dtype.storage_dtype().clone()) {
23            Ok(arr) => arr,
24            Err(e) => {
25                tracing::warn!("Failed to cast storage array: {e}");
26                return Ok(None);
27            }
28        };
29
30        Ok(Some(
31            ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
32        ))
33    }
34}
35
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::Nullability;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;

    /// Casting to the identical extension dtype keeps both the length and the dtype.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = arr
            .clone()
            .into_array()
            .cast(DType::Extension(ext_dtype.clone()))
            .unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting to the same extension dtype with a different nullability succeeds and the
    /// result carries the requested (nullable) dtype.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        // Use the `dtype()` accessor, like every other assertion in this module.
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = arr.clone().into_array().cast(new_dtype.clone()).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting between extension dtypes that differ in more than nullability must fail.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype =
            Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        // Note NS here instead of MS
        let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        // The error may come from either the cast call or the canonicalization that follows
        // it, so both steps are chained before checking for failure.
        assert!(
            arr.into_array()
                .cast(DType::Extension(target_dtype))
                .and_then(|a| a.to_canonical().map(|c| c.into_array()))
                .is_err()
        );
    }

    /// Runs the shared cast conformance checks over timestamp extension arrays covering each
    /// time unit, alternating between non-nullable and nullable storage.
    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(&array.into_array());
    }

    /// Builds a five-element UTC timestamp extension array with the given time unit.
    ///
    /// When `nullable` is true the storage contains nulls and the extension dtype is nullable;
    /// otherwise the storage is a plain non-null `i64` buffer.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let ext_dtype =
            Timestamp::new_with_tz(time_unit, Some("UTC".into()), nullable.into()).erased();

        // Raw i64 storage values; the instant each one denotes depends on `time_unit`.
        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}