vortex_array/arrays/extension/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::arrays::ExtensionArray;
9use crate::arrays::ExtensionVTable;
10use crate::compute::CastKernel;
11use crate::compute::CastKernelAdapter;
12use crate::compute::{self};
13use crate::register_kernel;
14
15impl CastKernel for ExtensionVTable {
16    fn cast(
17        &self,
18        array: &ExtensionArray,
19        dtype: &DType,
20    ) -> vortex_error::VortexResult<Option<ArrayRef>> {
21        if !array.dtype().eq_ignore_nullability(dtype) {
22            return Ok(None);
23        }
24
25        let DType::Extension(ext_dtype) = dtype else {
26            unreachable!("Already verified we have an extension dtype");
27        };
28
29        let new_storage = match compute::cast(array.storage(), ext_dtype.storage_dtype()) {
30            Ok(arr) => arr,
31            Err(e) => {
32                log::warn!("Failed to cast storage array: {e}");
33                return Ok(None);
34            }
35        };
36
37        Ok(Some(
38            ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
39        ))
40    }
41}
42
43register_kernel!(CastKernelAdapter(ExtensionVTable).lift());
44
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::ExtDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::datetime::TIMESTAMP_ID;
    use vortex_dtype::datetime::TemporalMetadata;
    use vortex_dtype::datetime::TimeUnit;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;

    /// Builds a timezone-less timestamp extension dtype with the given unit, backed by
    /// non-nullable `i64` storage.
    fn timestamp_ext_dtype(unit: TimeUnit) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(PType::I64.into()),
            Some(TemporalMetadata::Timestamp(unit, None).into()),
        ))
    }

    /// Casting to the array's own extension dtype preserves length and dtype.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = cast(arr.as_ref(), &DType::Extension(ext_dtype.clone())).unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting to the same extension dtype with different nullability succeeds and yields
    /// exactly the requested (nullable) dtype.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        // Use the public accessor rather than reaching into the struct's field.
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = cast(arr.as_ref(), &new_dtype).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Changing extension metadata (here the time unit) is not a supported cast.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        // Note NS here instead of MS
        let target_dtype = timestamp_ext_dtype(TimeUnit::Nanoseconds);

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(cast(arr.as_ref(), &DType::Extension(target_dtype)).is_err());
    }

    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a five-element UTC timestamp array in `time_unit`.
    ///
    /// When `nullable` is true, both the storage dtype and the data are nullable, and two of
    /// the five entries are null.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let nullability = if nullable {
            Nullability::Nullable
        } else {
            Nullability::NonNullable
        };

        let ext_dtype = Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(DType::Primitive(PType::I64, nullability)),
            Some(TemporalMetadata::Timestamp(time_unit, Some("UTC".to_string())).into()),
        ));

        // Raw i64 tick counts; what each value means in wall-clock time depends on
        // `time_unit`.
        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}