// vortex_array/arrays/extension/compute/cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_dtype::DType;
5
6use crate::arrays::{ExtensionArray, ExtensionVTable};
7use crate::compute::{self, CastKernel, CastKernelAdapter};
8use crate::{ArrayRef, IntoArray, register_kernel};
9
10impl CastKernel for ExtensionVTable {
11    fn cast(
12        &self,
13        array: &ExtensionArray,
14        dtype: &DType,
15    ) -> vortex_error::VortexResult<Option<ArrayRef>> {
16        if !array.dtype().eq_ignore_nullability(dtype) {
17            return Ok(None);
18        }
19
20        let DType::Extension(ext_dtype) = dtype else {
21            unreachable!("Already verified we have an extension dtype");
22        };
23
24        let new_storage = match compute::cast(array.storage(), ext_dtype.storage_dtype()) {
25            Ok(arr) => arr,
26            Err(e) => {
27                log::warn!("Failed to cast storage array: {e}");
28                return Ok(None);
29            }
30        };
31
32        Ok(Some(
33            ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
34        ))
35    }
36}
37
38register_kernel!(CastKernelAdapter(ExtensionVTable).lift());
39
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::datetime::{TIMESTAMP_ID, TemporalMetadata, TimeUnit};
    use vortex_dtype::{ExtDType, Nullability, PType};

    // `DType`, `ExtensionArray` and `IntoArray` are brought into scope by this glob.
    use super::*;
    use crate::arrays::PrimitiveArray;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;

    /// Timestamp extension dtype in `unit`, backed by non-nullable `i64` storage and
    /// carrying no time zone.
    fn timestamp_ext_dtype(unit: TimeUnit) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(PType::I64.into()),
            Some(TemporalMetadata::Timestamp(unit, None).into()),
        ))
    }

    /// Casting to the array's own extension dtype preserves length and dtype.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = cast(arr.as_ref(), &DType::Extension(ext_dtype.clone())).unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting to the same extension dtype with a different nullability changes only
    /// the nullability.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = cast(arr.as_ref(), &new_dtype).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting between different extension dtypes (here: different time units) fails.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        // Same storage type, but nanoseconds instead of milliseconds.
        let target_dtype = timestamp_ext_dtype(TimeUnit::Nanoseconds);

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(cast(arr.as_ref(), &DType::Extension(target_dtype)).is_err());
    }

    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a five-element timestamp extension array in `time_unit` with a `"UTC"`
    /// time zone. When `nullable` is true the storage is nullable and positions 1 and 4
    /// are null.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let storage_nullability = if nullable {
            Nullability::Nullable
        } else {
            Nullability::NonNullable
        };
        let ext_dtype = Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(DType::Primitive(PType::I64, storage_nullability)),
            Some(TemporalMetadata::Timestamp(time_unit, Some("UTC".to_string())).into()),
        ));

        // Raw `i64` tick counts; their meaning depends on `time_unit`.
        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}