vortex_array/arrays/extension/compute/
cast.rs1use vortex_dtype::DType;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::arrays::ExtensionArray;
9use crate::arrays::ExtensionVTable;
10use crate::builtins::ArrayBuiltins;
11use crate::compute::CastReduce;
12
13impl CastReduce for ExtensionVTable {
14 fn cast(array: &ExtensionArray, dtype: &DType) -> vortex_error::VortexResult<Option<ArrayRef>> {
15 if !array.dtype().eq_ignore_nullability(dtype) {
16 return Ok(None);
17 }
18
19 let DType::Extension(ext_dtype) = dtype else {
20 unreachable!("Already verified we have an extension dtype");
21 };
22
23 let new_storage = match array
24 .storage()
25 .cast(ext_dtype.storage_dtype().clone())
26 .and_then(|a| a.to_canonical().map(|c| c.into_array()))
27 {
28 Ok(arr) => arr,
29 Err(e) => {
30 tracing::warn!("Failed to cast storage array: {e}");
31 return Ok(None);
32 }
33 };
34
35 Ok(Some(
36 ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
37 ))
38 }
39}
40
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::datetime::TimeUnit;
    use vortex_dtype::datetime::Timestamp;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;

    /// Casting to the identical extension dtype keeps the length and the dtype unchanged.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = arr
            .to_array()
            .cast(DType::Extension(ext_dtype.clone()))
            .unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting a non-nullable extension array to the nullable variant of the same
    /// extension dtype succeeds and changes only the nullability.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        // Use the public accessor rather than reaching into the struct's field.
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = arr.to_array().cast(new_dtype.clone()).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting between different extension dtypes (here, different timestamp units)
    /// is not supported and must produce an error once the result is materialized.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype =
            Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(
            arr.to_array()
                .cast(DType::Extension(target_dtype))
                .and_then(|a| a.to_canonical().map(|c| c.into_array()))
                .is_err()
        );
    }

    /// Runs the shared cast conformance suite over timestamp extension arrays with a
    /// mix of time units and nullability.
    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a five-element UTC timestamp extension array with the given unit.
    ///
    /// When `nullable` is true, the storage contains nulls at positions 1 and 4;
    /// otherwise every element is valid.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let ext_dtype =
            Timestamp::new_with_tz(time_unit, Some("UTC".into()), nullable.into()).erased();

        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}