Skip to main content

vortex_array/arrays/extension/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use crate::ArrayRef;
5use crate::IntoArray;
6use crate::array::ArrayView;
7use crate::arrays::Extension;
8use crate::arrays::ExtensionArray;
9use crate::arrays::extension::ExtensionArrayExt;
10use crate::builtins::ArrayBuiltins;
11use crate::dtype::DType;
12use crate::scalar_fn::fns::cast::CastReduce;
13
14impl CastReduce for Extension {
15    fn cast(
16        array: ArrayView<'_, Extension>,
17        dtype: &DType,
18    ) -> vortex_error::VortexResult<Option<ArrayRef>> {
19        if !array.dtype().eq_ignore_nullability(dtype) {
20            return Ok(None);
21        }
22
23        let DType::Extension(ext_dtype) = dtype else {
24            unreachable!("Already verified we have an extension dtype");
25        };
26
27        let new_storage = match array
28            .storage_array()
29            .cast(ext_dtype.storage_dtype().clone())
30        {
31            Ok(arr) => arr,
32            Err(e) => {
33                tracing::warn!("Failed to cast storage array: {e}");
34                return Ok(None);
35            }
36        };
37
38        Ok(Some(
39            ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
40        ))
41    }
42}
43
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::Nullability;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;

    /// Casting an extension array to its own extension dtype keeps length and dtype.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = arr
            .clone()
            .into_array()
            .cast(DType::Extension(ext_dtype.clone()))
            .unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting to the same extension dtype with a different nullability succeeds and the
    /// output carries the requested (nullable) dtype.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = arr.clone().into_array().cast(new_dtype.clone()).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting between timestamp extension dtypes with different time units must fail,
    /// either when the cast is requested or when its result is canonicalized.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype =
            Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        // The target uses nanoseconds, whereas the source uses milliseconds.
        let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(
            arr.into_array()
                .cast(DType::Extension(target_dtype))
                .and_then(|a| a.to_canonical().map(|c| c.into_array()))
                .is_err()
        );
    }

    /// Runs the shared cast conformance suite over timestamp extension arrays covering every
    /// time unit, in both nullable and non-nullable variants.
    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(&array.into_array());
    }

    /// Builds a five-element UTC timestamp extension array with the given `time_unit`.
    ///
    /// When `nullable` is true the storage contains two nulls; otherwise every slot is valid.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let ext_dtype =
            Timestamp::new_with_tz(time_unit, Some("UTC".into()), nullable.into()).erased();

        // The literals below are raw i64 storage values. The same numbers are used for every
        // `time_unit`, so the instant they denote depends on the unit under test.
        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}