1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6use vortex_error::vortex_panic;
7
8use crate::ArrayRef;
9use crate::DynArray;
10use crate::IntoArray;
11use crate::LEGACY_SESSION;
12use crate::RecursiveCanonical;
13use crate::VortexSessionExecute;
14use crate::builtins::ArrayBuiltins;
15use crate::compute::MinMaxResult;
16use crate::compute::min_max;
17use crate::dtype::DType;
18use crate::dtype::Nullability;
19use crate::dtype::PType;
20use crate::scalar::Scalar;
21
22fn cast_and_execute(array: &ArrayRef, dtype: DType) -> VortexResult<ArrayRef> {
24 Ok(array
25 .cast(dtype)?
26 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())?
27 .0
28 .into_array())
29}
30
31pub fn test_cast_conformance(array: &ArrayRef) {
41 let dtype = array.dtype();
42
43 test_cast_identity(array);
45
46 test_cast_to_non_nullable(array);
47 test_cast_to_nullable(array);
48
49 match dtype {
51 DType::Null => test_cast_from_null(array),
52 DType::Primitive(ptype, ..) => match ptype {
53 PType::U8
54 | PType::U16
55 | PType::U32
56 | PType::U64
57 | PType::I8
58 | PType::I16
59 | PType::I32
60 | PType::I64 => test_cast_to_integral_types(array),
61 PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
62 },
63 _ => {}
64 }
65}
66
67fn test_cast_identity(array: &ArrayRef) {
68 let result = cast_and_execute(&array.to_array(), array.dtype().clone())
70 .vortex_expect("cast should succeed in conformance test");
71 assert_eq!(result.len(), array.len());
72 assert_eq!(result.dtype(), array.dtype());
73
74 for i in 0..array.len().min(10) {
76 assert_eq!(
77 array
78 .scalar_at(i)
79 .vortex_expect("scalar_at should succeed in conformance test"),
80 result
81 .scalar_at(i)
82 .vortex_expect("scalar_at should succeed in conformance test")
83 );
84 }
85}
86
87fn test_cast_from_null(array: &ArrayRef) {
88 let result = cast_and_execute(&array.to_array(), DType::Null)
90 .vortex_expect("cast should succeed in conformance test");
91 assert_eq!(result.len(), array.len());
92 assert_eq!(result.dtype(), &DType::Null);
93
94 let nullable_types = vec![
96 DType::Bool(Nullability::Nullable),
97 DType::Primitive(PType::I32, Nullability::Nullable),
98 DType::Primitive(PType::F64, Nullability::Nullable),
99 DType::Utf8(Nullability::Nullable),
100 DType::Binary(Nullability::Nullable),
101 ];
102
103 for dtype in nullable_types {
104 let result = cast_and_execute(&array.to_array(), dtype.clone())
105 .vortex_expect("cast should succeed in conformance test");
106 assert_eq!(result.len(), array.len());
107 assert_eq!(result.dtype(), &dtype);
108
109 for i in 0..array.len().min(10) {
111 assert!(
112 result
113 .scalar_at(i)
114 .vortex_expect("scalar_at should succeed in conformance test")
115 .is_null()
116 );
117 }
118 }
119
120 let non_nullable_types = vec![
122 DType::Bool(Nullability::NonNullable),
123 DType::Primitive(PType::I32, Nullability::NonNullable),
124 ];
125
126 for dtype in non_nullable_types {
127 assert!(cast_and_execute(&array.to_array(), dtype.clone()).is_err());
128 }
129}
130
131fn test_cast_to_non_nullable(array: &ArrayRef) {
132 if array
133 .invalid_count()
134 .vortex_expect("invalid_count should succeed in conformance test")
135 == 0
136 {
137 let non_nullable = cast_and_execute(&array.to_array(), array.dtype().as_nonnullable())
138 .vortex_expect("arrays without nulls can cast to non-nullable");
139 assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
140 assert_eq!(non_nullable.len(), array.len());
141
142 for i in 0..array.len().min(10) {
143 assert_eq!(
144 array
145 .scalar_at(i)
146 .vortex_expect("scalar_at should succeed in conformance test"),
147 non_nullable
148 .scalar_at(i)
149 .vortex_expect("scalar_at should succeed in conformance test")
150 );
151 }
152
153 let back_to_nullable = cast_and_execute(&non_nullable, array.dtype().clone())
154 .vortex_expect("non-nullable arrays can cast to nullable");
155 assert_eq!(back_to_nullable.dtype(), array.dtype());
156 assert_eq!(back_to_nullable.len(), array.len());
157
158 for i in 0..array.len().min(10) {
159 assert_eq!(
160 array
161 .scalar_at(i)
162 .vortex_expect("scalar_at should succeed in conformance test"),
163 back_to_nullable
164 .scalar_at(i)
165 .vortex_expect("scalar_at should succeed in conformance test")
166 );
167 }
168 } else {
169 if &DType::Null == array.dtype() {
170 return;
173 }
174 cast_and_execute(&array.to_array(), array.dtype().as_nonnullable())
175 .err()
176 .unwrap_or_else(|| {
177 vortex_panic!(
178 "arrays with nulls should error when casting to non-nullable {}",
179 array,
180 )
181 });
182 }
183}
184
185fn test_cast_to_nullable(array: &ArrayRef) {
186 let nullable = cast_and_execute(&array.to_array(), array.dtype().as_nullable())
187 .vortex_expect("arrays without nulls can cast to nullable");
188 assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
189 assert_eq!(nullable.len(), array.len());
190
191 for i in 0..array.len().min(10) {
192 assert_eq!(
193 array
194 .scalar_at(i)
195 .vortex_expect("scalar_at should succeed in conformance test"),
196 nullable
197 .scalar_at(i)
198 .vortex_expect("scalar_at should succeed in conformance test")
199 );
200 }
201
202 let back = cast_and_execute(&nullable, array.dtype().clone())
203 .vortex_expect("casting to nullable and back should be a no-op");
204 assert_eq!(back.dtype(), array.dtype());
205 assert_eq!(back.len(), array.len());
206
207 for i in 0..array.len().min(10) {
208 assert_eq!(
209 array
210 .scalar_at(i)
211 .vortex_expect("scalar_at should succeed in conformance test"),
212 back.scalar_at(i)
213 .vortex_expect("scalar_at should succeed in conformance test")
214 );
215 }
216}
217
218fn test_cast_from_floating_point_types(array: &ArrayRef) {
219 let ptype = array.as_primitive_typed().ptype();
220 test_cast_to_primitive(array, PType::I8, false);
221 test_cast_to_primitive(array, PType::U8, false);
222 test_cast_to_primitive(array, PType::I16, false);
223 test_cast_to_primitive(array, PType::U16, false);
224 test_cast_to_primitive(array, PType::I32, false);
225 test_cast_to_primitive(array, PType::U32, false);
226 test_cast_to_primitive(array, PType::I64, false);
227 test_cast_to_primitive(array, PType::U64, false);
228 test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
229 test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
230 test_cast_to_primitive(array, PType::F64, true);
231}
232
233fn test_cast_to_integral_types(array: &ArrayRef) {
234 test_cast_to_primitive(array, PType::I8, true);
235 test_cast_to_primitive(array, PType::U8, true);
236 test_cast_to_primitive(array, PType::I16, true);
237 test_cast_to_primitive(array, PType::U16, true);
238 test_cast_to_primitive(array, PType::I32, true);
239 test_cast_to_primitive(array, PType::U32, true);
240 test_cast_to_primitive(array, PType::I64, true);
241 test_cast_to_primitive(array, PType::U64, true);
242}
243
244fn fits(value: &Scalar, ptype: PType) -> bool {
246 let dtype = DType::Primitive(ptype, value.dtype().nullability());
247 value.cast(&dtype).is_ok()
248}
249
250fn test_cast_to_primitive(array: &ArrayRef, target_ptype: PType, test_round_trip: bool) {
251 let maybe_min_max = min_max(array).vortex_expect("cast should succeed in conformance test");
252
253 if let Some(MinMaxResult { min, max }) = maybe_min_max
254 && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
255 {
256 cast_and_execute(
257 &array.to_array(),
258 DType::Primitive(target_ptype, array.dtype().nullability()),
259 )
260 .err()
261 .unwrap_or_else(|| {
262 vortex_panic!(
263 "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
264 target_ptype,
265 min,
266 max,
267 array,
268 array.display_values(),
269 )
270 });
271 return;
272 }
273
274 let casted = cast_and_execute(
276 &array.to_array(),
277 DType::Primitive(target_ptype, array.dtype().nullability()),
278 )
279 .unwrap_or_else(|e| {
280 vortex_panic!(
281 "Cast must succeed because all values are within bounds. {} {}: {e}",
282 target_ptype,
283 array.display_values(),
284 )
285 });
286 assert_eq!(
287 array
288 .validity_mask()
289 .vortex_expect("validity_mask should succeed in conformance test"),
290 casted
291 .validity_mask()
292 .vortex_expect("validity_mask should succeed in conformance test")
293 );
294 for i in 0..array.len().min(10) {
295 let original = array
296 .scalar_at(i)
297 .vortex_expect("scalar_at should succeed in conformance test");
298 let casted = casted
299 .scalar_at(i)
300 .vortex_expect("scalar_at should succeed in conformance test");
301 assert_eq!(
302 original
303 .cast(casted.dtype())
304 .vortex_expect("cast should succeed in conformance test"),
305 casted,
306 "{i} {original} {casted}"
307 );
308 if test_round_trip {
309 assert_eq!(
310 original,
311 casted
312 .cast(original.dtype())
313 .vortex_expect("cast should succeed in conformance test"),
314 "{i} {original} {casted}"
315 );
316 }
317 }
318}
319
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::NullArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Unsigned integers, including values (65535, 1000000) that do not fit narrower types.
    #[test]
    fn test_cast_conformance_u32() {
        test_cast_conformance(&buffer![0u32, 100, 200, 65535, 1000000].into_array());
    }

    /// Signed integers spanning negative, zero and positive values.
    #[test]
    fn test_cast_conformance_i32() {
        test_cast_conformance(&buffer![-100i32, -1, 0, 1, 100].into_array());
    }

    /// Floats with fractional, negative and large values.
    #[test]
    fn test_cast_conformance_f32() {
        test_cast_conformance(&buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array());
    }

    /// Nullable `u8` values containing nulls and both extremes of the type.
    #[test]
    fn test_cast_conformance_nullable() {
        let values = [Some(1u8), None, Some(255), Some(0), None];
        let array = PrimitiveArray::from_option_iter(values).into_array();
        test_cast_conformance(&array);
    }

    /// Non-nullable booleans.
    #[test]
    fn test_cast_conformance_bool() {
        let array = BoolArray::from_iter(vec![true, false, true, false]).into_array();
        test_cast_conformance(&array);
    }

    /// An all-null array.
    #[test]
    fn test_cast_conformance_null() {
        test_cast_conformance(&NullArray::new(5).into_array());
    }

    /// Nullable UTF-8 strings containing a null.
    #[test]
    fn test_cast_conformance_utf8() {
        let strings = VarBinArray::from_iter(
            vec![Some("hello"), None, Some("world")],
            DType::Utf8(Nullability::Nullable),
        );
        test_cast_conformance(&strings.into_array());
    }

    /// Nullable binary values containing a null.
    #[test]
    fn test_cast_conformance_binary() {
        let bytes = VarBinArray::from_iter(
            vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())],
            DType::Binary(Nullability::Nullable),
        );
        test_cast_conformance(&bytes.into_array());
    }

    /// A non-nullable struct with an `i32` field and a nullable UTF-8 field.
    #[test]
    fn test_cast_conformance_struct() {
        let field_names = FieldNames::from(["a", "b"]);
        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let array =
            StructArray::try_new(field_names, vec![ints, strings], 3, Validity::NonNullable)
                .unwrap();
        test_cast_conformance(&array.into_array());
    }

    /// A non-nullable list of `i32` with `i64` offsets, including an empty list.
    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(&array.into_array());
    }
}