vortex_array/arrays/extension/compute/
cast.rs1use crate::ArrayRef;
5use crate::IntoArray;
6use crate::arrays::ExtensionArray;
7use crate::arrays::ExtensionVTable;
8use crate::builtins::ArrayBuiltins;
9use crate::dtype::DType;
10use crate::scalar_fn::fns::cast::CastReduce;
11
12impl CastReduce for ExtensionVTable {
13 fn cast(array: &ExtensionArray, dtype: &DType) -> vortex_error::VortexResult<Option<ArrayRef>> {
14 if !array.dtype().eq_ignore_nullability(dtype) {
15 return Ok(None);
16 }
17
18 let DType::Extension(ext_dtype) = dtype else {
19 unreachable!("Already verified we have an extension dtype");
20 };
21
22 let new_storage = match array
23 .storage()
24 .cast(ext_dtype.storage_dtype().clone())
25 .and_then(|a| a.to_canonical().map(|c| c.into_array()))
26 {
27 Ok(arr) => arr,
28 Err(e) => {
29 tracing::warn!("Failed to cast storage array: {e}");
30 return Ok(None);
31 }
32 };
33
34 Ok(Some(
35 ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
36 ))
37 }
38}
39
#[cfg(test)]
mod tests {
    // `IntoArray` and `ArrayBuiltins` already come into scope through `use super::*`,
    // so they are not re-imported here.
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use super::*;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::Nullability;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;

    /// Casting an extension array to its own extension dtype keeps its length and dtype.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = arr
            .to_array()
            .cast(DType::Extension(ext_dtype.clone()))
            .unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting to the same extension type with only the nullability changed succeeds.
    /// The output has the requested nullability.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = arr.to_array().cast(new_dtype.clone()).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting between different extension dtypes (millisecond to nanosecond timestamps)
    /// is expected to fail, at the latest when the result is canonicalized.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype =
            Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(
            arr.to_array()
                .cast(DType::Extension(target_dtype))
                .and_then(|a| a.to_canonical().map(|c| c.into_array()))
                .is_err()
        );
    }

    /// Runs the shared cast conformance suite over UTC timestamp arrays.
    /// The cases cover several time units, both nullable and non-nullable.
    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a five-element UTC timestamp extension array in the given `time_unit`.
    ///
    /// When `nullable` is true, the dtype is nullable and two of the five entries are null.
    /// Otherwise every entry is valid.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let ext_dtype =
            Timestamp::new_with_tz(time_unit, Some("UTC".into()), nullable.into()).erased();

        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}