vortex_array/arrays/extension/compute/
cast.rs1use vortex_dtype::DType;
5
6use crate::arrays::{ExtensionArray, ExtensionVTable};
7use crate::compute::{self, CastKernel, CastKernelAdapter};
8use crate::{ArrayRef, IntoArray, register_kernel};
9
10impl CastKernel for ExtensionVTable {
11 fn cast(
12 &self,
13 array: &ExtensionArray,
14 dtype: &DType,
15 ) -> vortex_error::VortexResult<Option<ArrayRef>> {
16 if !array.dtype().eq_ignore_nullability(dtype) {
17 return Ok(None);
18 }
19
20 let DType::Extension(ext_dtype) = dtype else {
21 unreachable!("Already verified we have an extension dtype");
22 };
23
24 let new_storage = match compute::cast(array.storage(), ext_dtype.storage_dtype()) {
25 Ok(arr) => arr,
26 Err(e) => {
27 log::warn!("Failed to cast storage array: {e}");
28 return Ok(None);
29 }
30 };
31
32 Ok(Some(
33 ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
34 ))
35 }
36}
37
38register_kernel!(CastKernelAdapter(ExtensionVTable).lift());
39
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::datetime::{TIMESTAMP_ID, TemporalMetadata, TimeUnit};
    use vortex_dtype::{ExtDType, Nullability, PType};

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;

    /// Builds a timezone-less timestamp extension dtype over non-nullable `i64` storage.
    fn timestamp_ext_dtype(time_unit: TimeUnit) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(PType::I64.into()),
            Some(TemporalMetadata::Timestamp(time_unit, None).into()),
        ))
    }

    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = cast(arr.as_ref(), &DType::Extension(ext_dtype.clone())).unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = cast(arr.as_ref(), &new_dtype).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    #[test]
    fn cast_different_ext_dtype() {
        // Changing the timestamp unit is a metadata change, not a nullability change, so the
        // cast must fail.
        let original_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        let target_dtype = timestamp_ext_dtype(TimeUnit::Nanoseconds);

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(cast(arr.as_ref(), &DType::Extension(target_dtype)).is_err());
    }

    #[test]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        // The storage contains nulls, so dropping nullability cannot succeed and must be an error.
        let arr = create_timestamp_array(TimeUnit::Milliseconds, true);
        let non_nullable = arr.dtype().with_nullability(Nullability::NonNullable);

        assert!(cast(arr.as_ref(), &non_nullable).is_err());
    }

    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a five-element UTC timestamp array; when `nullable`, two of the entries are null.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let nullability = if nullable {
            Nullability::Nullable
        } else {
            Nullability::NonNullable
        };
        let ext_dtype = Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(DType::Primitive(PType::I64, nullability)),
            Some(TemporalMetadata::Timestamp(time_unit, Some("UTC".to_string())).into()),
        ));

        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}