Skip to main content

vortex_array/arrays/extension/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::array::ArrayView;
9use crate::arrays::Extension;
10use crate::arrays::ExtensionArray;
11use crate::arrays::extension::ExtensionArrayExt;
12use crate::builtins::ArrayBuiltins;
13use crate::dtype::DType;
14use crate::scalar_fn::fns::cast::CastReduce;
15
16impl CastReduce for Extension {
17    fn cast(array: ArrayView<'_, Extension>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        if !array.dtype().eq_ignore_nullability(dtype) {
19            // Target is not the same extension type.
20            // Delegate to the storage array's cast.
21            return Ok(Some(array.storage_array().cast(dtype.clone())?));
22        }
23
24        let DType::Extension(ext_dtype) = dtype else {
25            unreachable!("Already verified we have an extension dtype");
26        };
27
28        let new_storage = match array
29            .storage_array()
30            .cast(ext_dtype.storage_dtype().clone())
31        {
32            Ok(arr) => arr,
33            Err(e) => {
34                tracing::warn!("Failed to cast storage array: {e}");
35                return Ok(None);
36            }
37        };
38
39        Ok(Some(
40            ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
41        ))
42    }
43}
44
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::executor::VortexSessionExecute;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;
    use crate::session::ArraySession;

    /// Session providing the execution context used to materialize cast results.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Casting to the array's own extension dtype keeps the length and the dtype.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = arr
            .clone()
            .into_array()
            .cast(DType::Extension(ext_dtype.clone()))
            .unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting to the same extension type with a different nullability changes only the
    /// nullability of the result dtype.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = arr.clone().into_array().cast(new_dtype.clone()).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting between two different extension dtypes must fail. The error may surface
    /// either from `cast` itself or when the result is executed, so both steps are chained.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype =
            Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        // Nanoseconds instead of milliseconds, so this is a different extension dtype.
        let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        let result = arr
            .into_array()
            .cast(DType::Extension(target_dtype))
            .and_then(|a| {
                a.execute::<ExtensionArray>(&mut SESSION.create_execution_ctx())
                    .map(|c| c.into_array())
            });
        assert!(result.is_err());
    }

    /// Casting a timestamp extension array to its storage primitive type yields the
    /// underlying storage values.
    #[test]
    fn cast_timestamp_to_i64() -> VortexResult<()> {
        let ext_dtype = Timestamp::new_with_tz(
            TimeUnit::Nanoseconds,
            Some("UTC".into()),
            Nullability::NonNullable,
        )
        .erased();
        let storage = buffer![1i64, 2, 3].into_array();
        let arr = ExtensionArray::new(ext_dtype, storage).into_array();

        let result = arr.cast(DType::Primitive(PType::I64, Nullability::NonNullable))?;
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_arrays_eq!(result, buffer![1i64, 2, 3].into_array());
        Ok(())
    }

    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(&array.into_array());
    }

    /// Builds a five-element UTC timestamp extension array with `i64` storage.
    ///
    /// When `nullable` is true, the storage is nullable and positions 1 and 4 are null.
    /// The raw storage values are counts of `time_unit`, so their meaning varies by case.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let ext_dtype =
            Timestamp::new_with_tz(time_unit, Some("UTC".into()), nullable.into()).erased();

        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64), // 1_000_000 ticks of `time_unit`
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}