vortex_array/arrays/extension/compute/
cast.rs1use crate::ArrayRef;
5use crate::IntoArray;
6use crate::array::ArrayView;
7use crate::arrays::Extension;
8use crate::arrays::ExtensionArray;
9use crate::arrays::extension::ExtensionArrayExt;
10use crate::builtins::ArrayBuiltins;
11use crate::dtype::DType;
12use crate::scalar_fn::fns::cast::CastReduce;
13
14impl CastReduce for Extension {
15 fn cast(
16 array: ArrayView<'_, Extension>,
17 dtype: &DType,
18 ) -> vortex_error::VortexResult<Option<ArrayRef>> {
19 if !array.dtype().eq_ignore_nullability(dtype) {
20 return Ok(Some(array.storage_array().cast(dtype.clone())?));
23 }
24
25 let DType::Extension(ext_dtype) = dtype else {
26 unreachable!("Already verified we have an extension dtype");
27 };
28
29 let new_storage = match array
30 .storage_array()
31 .cast(ext_dtype.storage_dtype().clone())
32 {
33 Ok(arr) => arr,
34 Err(e) => {
35 tracing::warn!("Failed to cast storage array: {e}");
36 return Ok(None);
37 }
38 };
39
40 Ok(Some(
41 ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
42 ))
43 }
44}
45
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;

    /// Casting to the array's own extension dtype keeps its length and dtype.
    #[test]
    fn cast_same_ext_dtype() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);

        let output = arr
            .clone()
            .into_array()
            .cast(DType::Extension(ext_dtype.clone()))
            .unwrap();
        assert_eq!(arr.len(), output.len());
        assert_eq!(arr.dtype(), output.dtype());
        assert_eq!(output.dtype(), &DType::Extension(ext_dtype));
    }

    /// Casting to the same extension dtype with only the nullability changed yields exactly the
    /// requested (nullable) dtype.
    #[test]
    fn cast_same_ext_dtype_different_nullability() {
        let ext_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let storage = Buffer::<i64>::empty().into_array();

        let arr = ExtensionArray::new(ext_dtype.clone(), storage);
        assert!(!arr.dtype().is_nullable());

        let new_dtype = DType::Extension(ext_dtype).with_nullability(Nullability::Nullable);

        let output = arr.clone().into_array().cast(new_dtype.clone()).unwrap();
        assert_eq!(arr.len(), output.len());
        assert!(arr.dtype().eq_ignore_nullability(output.dtype()));
        assert_eq!(output.dtype(), &new_dtype);
    }

    /// Casting between two distinct extension dtypes (here: different timestamp units) fails,
    /// at the latest once the result is canonicalized.
    #[test]
    fn cast_different_ext_dtype() {
        let original_dtype =
            Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
        let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

        let storage = buffer![1i64].into_array();
        let arr = ExtensionArray::new(original_dtype, storage);

        assert!(
            arr.into_array()
                .cast(DType::Extension(target_dtype))
                .and_then(|a| a.to_canonical().map(|c| c.into_array()))
                .is_err()
        );
    }

    /// A timestamp extension array can be cast directly to its primitive storage type.
    #[test]
    fn cast_timestamp_to_i64() -> vortex_error::VortexResult<()> {
        let ext_dtype = Timestamp::new_with_tz(
            TimeUnit::Nanoseconds,
            Some("UTC".into()),
            Nullability::NonNullable,
        )
        .erased();
        let storage = buffer![1i64, 2, 3].into_array();
        let arr = ExtensionArray::new(ext_dtype, storage).into_array();

        let result = arr.cast(DType::Primitive(PType::I64, Nullability::NonNullable))?;
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_arrays_eq!(result, buffer![1i64, 2, 3].into_array());
        Ok(())
    }

    /// Runs the shared cast conformance suite over timestamp arrays of every unit, both with
    /// and without nulls.
    #[rstest]
    #[case(create_timestamp_array(TimeUnit::Milliseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Microseconds, true))]
    #[case(create_timestamp_array(TimeUnit::Nanoseconds, false))]
    #[case(create_timestamp_array(TimeUnit::Seconds, true))]
    fn test_cast_extension_conformance(#[case] array: ExtensionArray) {
        test_cast_conformance(&array.into_array());
    }

    /// Builds a five-element UTC timestamp array with the given unit.
    ///
    /// When `nullable` is true the storage contains two nulls; otherwise all values are valid.
    fn create_timestamp_array(time_unit: TimeUnit, nullable: bool) -> ExtensionArray {
        let ext_dtype =
            Timestamp::new_with_tz(time_unit, Some("UTC".into()), nullable.into()).erased();

        let storage = if nullable {
            PrimitiveArray::from_option_iter([
                Some(1_000_000i64),
                None,
                Some(2_000_000),
                Some(3_000_000),
                None,
            ])
            .into_array()
        } else {
            buffer![1_000_000i64, 2_000_000, 3_000_000, 4_000_000, 5_000_000].into_array()
        };

        ExtensionArray::new(ext_dtype, storage)
    }
}