vortex_array/arrays/extension/compute/
cast.rs1use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::array::ArrayView;
9use crate::arrays::Extension;
10use crate::arrays::ExtensionArray;
11use crate::arrays::extension::ExtensionArrayExt;
12use crate::builtins::ArrayBuiltins;
13use crate::dtype::DType;
14use crate::scalar_fn::fns::cast::CastReduce;
15
16impl CastReduce for Extension {
17 fn cast(array: ArrayView<'_, Extension>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18 if !array.dtype().eq_ignore_nullability(dtype) {
19 return Ok(Some(array.storage_array().cast(dtype.clone())?));
22 }
23
24 let DType::Extension(ext_dtype) = dtype else {
25 unreachable!("Already verified we have an extension dtype");
26 };
27
28 let new_storage = match array
29 .storage_array()
30 .cast(ext_dtype.storage_dtype().clone())
31 {
32 Ok(arr) => arr,
33 Err(e) => {
34 tracing::warn!("Failed to cast storage array: {e}");
35 return Ok(None);
36 }
37 };
38
39 Ok(Some(
40 ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
41 ))
42 }
43}
44
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::executor::VortexSessionExecute;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;

    /// Session shared by the tests that need an execution context.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);

    /// Casting to the exact same extension dtype keeps both length and dtype.
    #[test]
    fn cast_same_ext_dtype() -> VortexResult<()> {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = arr
            .clone()
            .into_array()
            .cast(DType::Extension(ext_dtype.clone()))?;
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
        Ok(())
    }

    /// Casting to the same extension dtype with different nullability yields the requested
    /// (nullable) dtype.
    #[test]
    fn cast_same_ext_dtype_different_nullability() -> VortexResult<()> {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = arr.clone().into_array().cast(new_dtype.clone())?;
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
        Ok(())
    }

    /// Casting a millisecond timestamp array to a nanosecond timestamp dtype must fail, either
    /// when the cast is requested or when its result is executed.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype =
            Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        let result = arr
            .into_array()
            .cast(DType::Extension(target_dtype))
            .and_then(|a| {
                a.execute::<ExtensionArray>(&mut SESSION.create_execution_ctx())
                    .map(|c| c.into_array())
            });
        assert!(result.is_err());
    }

    /// Casting a timestamp extension array to `i64` yields the storage values unchanged.
    #[test]
    fn cast_timestamp_to_i64() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let ext_dtype = Timestamp::new_with_tz(
            TimeUnit::Nanoseconds,
            Some("UTC".into()),
            Nullability::NonNullable,
        )
        .erased();
        let storage = buffer![1i64, 2, 3].into_array();
        let arr = ExtensionArray::new(ext_dtype, storage).into_array();

        let result = arr.cast(DType::Primitive(PType::I64, Nullability::NonNullable))?;
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_arrays_eq!(result, buffer![1i64, 2, 3].into_array(), &mut ctx);
        Ok(())
    }

    /// Runs the shared cast conformance suite over timestamp arrays of several units, with and
    /// without nulls.
    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(&array.into_array());
    }

    /// Builds a five-element UTC timestamp array in `time_unit`; when `nullable` is set, the
    /// second and last elements are null.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let ext_dtype =
            Timestamp::new_with_tz(time_unit, Some("UTC".into()), nullable.into()).erased();

        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}