Skip to main content

vortex_array/arrays/extension/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use crate::ArrayRef;
5use crate::IntoArray;
6use crate::array::ArrayView;
7use crate::arrays::Extension;
8use crate::arrays::ExtensionArray;
9use crate::arrays::extension::ExtensionArrayExt;
10use crate::builtins::ArrayBuiltins;
11use crate::dtype::DType;
12use crate::scalar_fn::fns::cast::CastReduce;
13
14impl CastReduce for Extension {
15    fn cast(
16        array: ArrayView<'_, Extension>,
17        dtype: &DType,
18    ) -> vortex_error::VortexResult<Option<ArrayRef>> {
19        if !array.dtype().eq_ignore_nullability(dtype) {
20            // Target is not the same extension type.
21            // Delegate to the storage array's cast.
22            return Ok(Some(array.storage_array().cast(dtype.clone())?));
23        }
24
25        let DType::Extension(ext_dtype) = dtype else {
26            unreachable!("Already verified we have an extension dtype");
27        };
28
29        let new_storage = match array
30            .storage_array()
31            .cast(ext_dtype.storage_dtype().clone())
32        {
33            Ok(arr) => arr,
34            Err(e) => {
35                tracing::warn!("Failed to cast storage array: {e}");
36                return Ok(None);
37            }
38        };
39
40        Ok(Some(
41            ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
42        ))
43    }
44}
45
#[cfg(test)]
mod tests {
    //! Tests for casting extension arrays, using timestamp extension dtypes over `i64` storage.

    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;

    /// Casting to the exact same extension dtype keeps both the length and the dtype.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = arr
            .clone()
            .into_array()
            .cast(DType::Extension(ext_dtype.clone()))
            .unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting to the same extension dtype with a different nullability succeeds and the
    /// output carries the requested (nullable) dtype.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = arr.clone().into_array().cast(new_dtype.clone()).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting between timestamp extension dtypes with different time units must fail,
    /// either when the cast is requested or when the result is canonicalized.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype =
            Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        // Note NS here instead of MS
        let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(
            arr.into_array()
                .cast(DType::Extension(target_dtype))
                .and_then(|a| a.to_canonical().map(|c| c.into_array()))
                .is_err()
        );
    }

    /// Casting a timestamp extension array to a plain `i64` primitive dtype yields the
    /// storage values unchanged.
    #[test]
    fn cast_timestamp_to_i64() -> vortex_error::VortexResult<()> {
        let ext_dtype = Timestamp::new_with_tz(
            TimeUnit::Nanoseconds,
            Some("UTC".into()),
            Nullability::NonNullable,
        )
        .erased();
        let storage = buffer![1i64, 2, 3].into_array();
        let arr = ExtensionArray::new(ext_dtype, storage).into_array();

        let result = arr.cast(DType::Primitive(PType::I64, Nullability::NonNullable))?;
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_arrays_eq!(result, buffer![1i64, 2, 3].into_array());
        Ok(())
    }

    /// Runs `test_cast_conformance` over timestamp arrays covering every time unit used
    /// here, in both nullable and non-nullable form.
    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(&array.into_array());
    }

    /// Builds a five-element UTC timestamp extension array in the given `time_unit`.
    ///
    /// When `nullable` is true the storage holds two nulls among three values; otherwise it
    /// holds five non-null values.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let ext_dtype =
            Timestamp::new_with_tz(time_unit, Some("UTC".into()), nullable.into()).erased();

        // The raw i64 values are the same for every case; what instant they denote depends
        // on `time_unit`.
        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}