use vortex_dtype::DType;

use crate::arrays::{ExtensionArray, ExtensionVTable};
use crate::compute::{CastKernel, CastKernelAdapter, cast};
use crate::{ArrayRef, IntoArray, register_kernel};
10impl CastKernel for ExtensionVTable {
11 fn cast(
12 &self,
13 array: &ExtensionArray,
14 dtype: &DType,
15 ) -> vortex_error::VortexResult<Option<ArrayRef>> {
16 if !array.dtype().eq_ignore_nullability(dtype) {
17 return Ok(None);
18 }
19
20 let DType::Extension(ext_dtype) = dtype else {
21 unreachable!("Already verified we have an extension dtype");
22 };
23
24 let new_storage = match cast(array.storage(), ext_dtype.storage_dtype()) {
25 Ok(arr) => arr,
26 Err(e) => {
27 log::warn!("Failed to cast storage array: {e}");
28 return Ok(None);
29 }
30 };
31
32 Ok(Some(
33 ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
34 ))
35 }
36}
37
38register_kernel!(CastKernelAdapter(ExtensionVTable).lift());
39
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::datetime::{TIMESTAMP_ID, TemporalMetadata, TimeUnit};
    use vortex_dtype::{ExtDType, Nullability, PType};

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::cast::test_cast_conformance;

    /// Timestamp extension dtype with the given unit and no time zone, over `PType::I64`
    /// storage (the default, non-nullable, conversion from `PType`).
    fn timestamp_dtype(time_unit: TimeUnit) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(PType::I64.into()),
            Some(TemporalMetadata::Timestamp(time_unit, None).into()),
        ))
    }

    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = timestamp_dtype(TimeUnit::Milliseconds);
        let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();
        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = cast(arr.as_ref(), &DType::Extension(ext_dtype.clone())).unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = timestamp_dtype(TimeUnit::Milliseconds);
        let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();
        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = cast(arr.as_ref(), &new_dtype).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    #[test]
    fn cast_different_ext_dtype() {
        let storage = PrimitiveArray::from_iter(Vec::<i64>::new()).into_array();
        let arr = ExtensionArray::new(timestamp_dtype(TimeUnit::Milliseconds), storage);

        // The time unit lives in the extension metadata, so this is not a nullability-only
        // change and the cast must fail.
        let target = DType::Extension(timestamp_dtype(TimeUnit::Nanoseconds));
        assert!(cast(arr.as_ref(), &target).is_err());
    }

    #[test]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let arr = create_timestamp_array(TimeUnit::Milliseconds, true);
        assert!(arr.dtype().is_nullable());

        // The storage contains nulls, so dropping nullability cannot succeed.
        let non_nullable = arr.dtype().with_nullability(Nullability::NonNullable);
        assert!(cast(arr.as_ref(), &non_nullable).is_err());
    }

    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a five-element UTC timestamp array with `i64` storage.
    ///
    /// When `nullable` is true, the storage dtype is nullable and two of the five values are
    /// null; otherwise the storage is non-nullable and fully populated.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let nullability = if nullable {
            Nullability::Nullable
        } else {
            Nullability::NonNullable
        };
        let ext_dtype = Arc::new(ExtDType::new(
            TIMESTAMP_ID.clone(),
            Arc::new(DType::Primitive(PType::I64, nullability)),
            Some(TemporalMetadata::Timestamp(time_unit, Some("UTC".to_string())).into()),
        ));

        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}