vortex_array/arrays/extension/compute/
cast.rs1use vortex_dtype::DType;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::arrays::ExtensionArray;
9use crate::arrays::ExtensionVTable;
10use crate::compute::CastKernel;
11use crate::compute::CastKernelAdapter;
12use crate::compute::{self};
13use crate::register_kernel;
14
15impl CastKernel for ExtensionVTable {
16 fn cast(
17 &self,
18 array: &ExtensionArray,
19 dtype: &DType,
20 ) -> vortex_error::VortexResult<Option<ArrayRef>> {
21 if !array.dtype().eq_ignore_nullability(dtype) {
22 return Ok(None);
23 }
24
25 let DType::Extension(ext_dtype) = dtype else {
26 unreachable!("Already verified we have an extension dtype");
27 };
28
29 let new_storage = match compute::cast(array.storage(), ext_dtype.storage_dtype()) {
30 Ok(arr) => arr,
31 Err(e) => {
32 log::warn!("Failed to cast storage array: {e}");
33 return Ok(None);
34 }
35 };
36
37 Ok(Some(
38 ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
39 ))
40 }
41}
42
43register_kernel!(CastKernelAdapter(ExtensionVTable).lift());
44
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::ExtDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::datetime::TIMESTAMP_ID;
    use vortex_dtype::datetime::TemporalMetadata;
    use vortex_dtype::datetime::TimeUnit;

    // `IntoArray`, `DType`, `ExtensionArray` etc. come in through the parent module's imports.
    use super::*;
    use crate::arrays::PrimitiveArray;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;

    /// Builds a timestamp extension dtype with `PType::I64` storage, the given time unit,
    /// and no time zone.
    fn timestamp_ext_dtype(time_unit: TimeUnit) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(PType::I64.into()),
            Some(TemporalMetadata::Timestamp(time_unit, None).into()),
        ))
    }

    /// Casting to the array's own extension dtype preserves length and dtype.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = cast(arr.as_ref(), &DType::Extension(ext_dtype.clone())).unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting to the same extension dtype with different nullability succeeds and
    /// yields exactly the requested dtype.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = cast(arr.as_ref(), &new_dtype).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting between extension dtypes whose metadata differ (here: the timestamp unit)
    /// is expected to be rejected with an error.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype = timestamp_ext_dtype(TimeUnit::Milliseconds);
        let target_dtype = timestamp_ext_dtype(TimeUnit::Nanoseconds);

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(cast(arr.as_ref(), &DType::Extension(target_dtype)).is_err());
    }

    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a five-element UTC timestamp extension array with `i64` storage.
    ///
    /// When `nullable` is true the storage dtype is nullable and the array contains two
    /// nulls. Otherwise the storage is non-nullable with all values present.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let nullability = if nullable {
            Nullability::Nullable
        } else {
            Nullability::NonNullable
        };

        let ext_dtype = Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(DType::Primitive(PType::I64, nullability)),
            Some(TemporalMetadata::Timestamp(time_unit, Some("UTC".to_string())).into()),
        ));

        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}