1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6use vortex_error::vortex_panic;
7
8use crate::ArrayRef;
9use crate::DynArray;
10use crate::IntoArray;
11use crate::LEGACY_SESSION;
12use crate::RecursiveCanonical;
13use crate::VortexSessionExecute;
14use crate::aggregate_fn::fns::min_max::MinMaxResult;
15use crate::aggregate_fn::fns::min_max::min_max;
16use crate::builtins::ArrayBuiltins;
17use crate::dtype::DType;
18use crate::dtype::Nullability;
19use crate::dtype::PType;
20use crate::scalar::Scalar;
21
22fn cast_and_execute(array: &ArrayRef, dtype: DType) -> VortexResult<ArrayRef> {
24 Ok(array
25 .cast(dtype)?
26 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())?
27 .0
28 .into_array())
29}
30
31pub fn test_cast_conformance(array: &ArrayRef) {
41 let dtype = array.dtype();
42
43 test_cast_identity(array);
45
46 test_cast_to_non_nullable(array);
47 test_cast_to_nullable(array);
48
49 match dtype {
51 DType::Null => test_cast_from_null(array),
52 DType::Primitive(ptype, ..) => match ptype {
53 PType::U8
54 | PType::U16
55 | PType::U32
56 | PType::U64
57 | PType::I8
58 | PType::I16
59 | PType::I32
60 | PType::I64 => test_cast_to_integral_types(array),
61 PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
62 },
63 _ => {}
64 }
65}
66
67fn test_cast_identity(array: &ArrayRef) {
68 let result = cast_and_execute(&array.to_array(), array.dtype().clone())
70 .vortex_expect("cast should succeed in conformance test");
71 assert_eq!(result.len(), array.len());
72 assert_eq!(result.dtype(), array.dtype());
73
74 for i in 0..array.len().min(10) {
76 assert_eq!(
77 array
78 .scalar_at(i)
79 .vortex_expect("scalar_at should succeed in conformance test"),
80 result
81 .scalar_at(i)
82 .vortex_expect("scalar_at should succeed in conformance test")
83 );
84 }
85}
86
87fn test_cast_from_null(array: &ArrayRef) {
88 let result = cast_and_execute(&array.to_array(), DType::Null)
90 .vortex_expect("cast should succeed in conformance test");
91 assert_eq!(result.len(), array.len());
92 assert_eq!(result.dtype(), &DType::Null);
93
94 let nullable_types = vec![
96 DType::Bool(Nullability::Nullable),
97 DType::Primitive(PType::I32, Nullability::Nullable),
98 DType::Primitive(PType::F64, Nullability::Nullable),
99 DType::Utf8(Nullability::Nullable),
100 DType::Binary(Nullability::Nullable),
101 ];
102
103 for dtype in nullable_types {
104 let result = cast_and_execute(&array.to_array(), dtype.clone())
105 .vortex_expect("cast should succeed in conformance test");
106 assert_eq!(result.len(), array.len());
107 assert_eq!(result.dtype(), &dtype);
108
109 for i in 0..array.len().min(10) {
111 assert!(
112 result
113 .scalar_at(i)
114 .vortex_expect("scalar_at should succeed in conformance test")
115 .is_null()
116 );
117 }
118 }
119
120 let non_nullable_types = vec![
122 DType::Bool(Nullability::NonNullable),
123 DType::Primitive(PType::I32, Nullability::NonNullable),
124 ];
125
126 for dtype in non_nullable_types {
127 assert!(cast_and_execute(&array.to_array(), dtype.clone()).is_err());
128 }
129}
130
131fn test_cast_to_non_nullable(array: &ArrayRef) {
132 if array
133 .invalid_count()
134 .vortex_expect("invalid_count should succeed in conformance test")
135 == 0
136 {
137 let non_nullable = cast_and_execute(&array.to_array(), array.dtype().as_nonnullable())
138 .vortex_expect("arrays without nulls can cast to non-nullable");
139 assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
140 assert_eq!(non_nullable.len(), array.len());
141
142 for i in 0..array.len().min(10) {
143 assert_eq!(
144 array
145 .scalar_at(i)
146 .vortex_expect("scalar_at should succeed in conformance test"),
147 non_nullable
148 .scalar_at(i)
149 .vortex_expect("scalar_at should succeed in conformance test")
150 );
151 }
152
153 let back_to_nullable = cast_and_execute(&non_nullable, array.dtype().clone())
154 .vortex_expect("non-nullable arrays can cast to nullable");
155 assert_eq!(back_to_nullable.dtype(), array.dtype());
156 assert_eq!(back_to_nullable.len(), array.len());
157
158 for i in 0..array.len().min(10) {
159 assert_eq!(
160 array
161 .scalar_at(i)
162 .vortex_expect("scalar_at should succeed in conformance test"),
163 back_to_nullable
164 .scalar_at(i)
165 .vortex_expect("scalar_at should succeed in conformance test")
166 );
167 }
168 } else {
169 if &DType::Null == array.dtype() {
170 return;
173 }
174 cast_and_execute(&array.to_array(), array.dtype().as_nonnullable())
175 .err()
176 .unwrap_or_else(|| {
177 vortex_panic!(
178 "arrays with nulls should error when casting to non-nullable {}",
179 array,
180 )
181 });
182 }
183}
184
185fn test_cast_to_nullable(array: &ArrayRef) {
186 let nullable = cast_and_execute(&array.to_array(), array.dtype().as_nullable())
187 .vortex_expect("arrays without nulls can cast to nullable");
188 assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
189 assert_eq!(nullable.len(), array.len());
190
191 for i in 0..array.len().min(10) {
192 assert_eq!(
193 array
194 .scalar_at(i)
195 .vortex_expect("scalar_at should succeed in conformance test"),
196 nullable
197 .scalar_at(i)
198 .vortex_expect("scalar_at should succeed in conformance test")
199 );
200 }
201
202 let back = cast_and_execute(&nullable, array.dtype().clone())
203 .vortex_expect("casting to nullable and back should be a no-op");
204 assert_eq!(back.dtype(), array.dtype());
205 assert_eq!(back.len(), array.len());
206
207 for i in 0..array.len().min(10) {
208 assert_eq!(
209 array
210 .scalar_at(i)
211 .vortex_expect("scalar_at should succeed in conformance test"),
212 back.scalar_at(i)
213 .vortex_expect("scalar_at should succeed in conformance test")
214 );
215 }
216}
217
218fn test_cast_from_floating_point_types(array: &ArrayRef) {
219 let ptype = array.as_primitive_typed().ptype();
220 test_cast_to_primitive(array, PType::I8, false);
221 test_cast_to_primitive(array, PType::U8, false);
222 test_cast_to_primitive(array, PType::I16, false);
223 test_cast_to_primitive(array, PType::U16, false);
224 test_cast_to_primitive(array, PType::I32, false);
225 test_cast_to_primitive(array, PType::U32, false);
226 test_cast_to_primitive(array, PType::I64, false);
227 test_cast_to_primitive(array, PType::U64, false);
228 test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
229 test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
230 test_cast_to_primitive(array, PType::F64, true);
231}
232
233fn test_cast_to_integral_types(array: &ArrayRef) {
234 test_cast_to_primitive(array, PType::I8, true);
235 test_cast_to_primitive(array, PType::U8, true);
236 test_cast_to_primitive(array, PType::I16, true);
237 test_cast_to_primitive(array, PType::U16, true);
238 test_cast_to_primitive(array, PType::I32, true);
239 test_cast_to_primitive(array, PType::U32, true);
240 test_cast_to_primitive(array, PType::I64, true);
241 test_cast_to_primitive(array, PType::U64, true);
242}
243
244fn fits(value: &Scalar, ptype: PType) -> bool {
246 let dtype = DType::Primitive(ptype, value.dtype().nullability());
247 value.cast(&dtype).is_ok()
248}
249
250fn test_cast_to_primitive(array: &ArrayRef, target_ptype: PType, test_round_trip: bool) {
251 let mut ctx = LEGACY_SESSION.create_execution_ctx();
252 let maybe_min_max =
253 min_max(array, &mut ctx).vortex_expect("cast should succeed in conformance test");
254
255 if let Some(MinMaxResult { min, max }) = maybe_min_max
256 && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
257 {
258 cast_and_execute(
259 &array.to_array(),
260 DType::Primitive(target_ptype, array.dtype().nullability()),
261 )
262 .err()
263 .unwrap_or_else(|| {
264 vortex_panic!(
265 "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
266 target_ptype,
267 min,
268 max,
269 array,
270 array.display_values(),
271 )
272 });
273 return;
274 }
275
276 let casted = cast_and_execute(
278 &array.to_array(),
279 DType::Primitive(target_ptype, array.dtype().nullability()),
280 )
281 .unwrap_or_else(|e| {
282 vortex_panic!(
283 "Cast must succeed because all values are within bounds. {} {}: {e}",
284 target_ptype,
285 array.display_values(),
286 )
287 });
288 assert_eq!(
289 array
290 .validity_mask()
291 .vortex_expect("validity_mask should succeed in conformance test"),
292 casted
293 .validity_mask()
294 .vortex_expect("validity_mask should succeed in conformance test")
295 );
296 for i in 0..array.len().min(10) {
297 let original = array
298 .scalar_at(i)
299 .vortex_expect("scalar_at should succeed in conformance test");
300 let casted = casted
301 .scalar_at(i)
302 .vortex_expect("scalar_at should succeed in conformance test");
303 assert_eq!(
304 original
305 .cast(casted.dtype())
306 .vortex_expect("cast should succeed in conformance test"),
307 casted,
308 "{i} {original} {casted}"
309 );
310 if test_round_trip {
311 assert_eq!(
312 original,
313 casted
314 .cast(original.dtype())
315 .vortex_expect("cast should succeed in conformance test"),
316 "{i} {original} {casted}"
317 );
318 }
319 }
320}
321
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::NullArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    #[test]
    fn test_cast_conformance_u32() {
        let values = buffer![0u32, 100, 200, 65535, 1000000].into_array();
        test_cast_conformance(&values);
    }

    #[test]
    fn test_cast_conformance_i32() {
        let values = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(&values);
    }

    #[test]
    fn test_cast_conformance_f32() {
        let values = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(&values);
    }

    #[test]
    fn test_cast_conformance_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None])
                .into_array();
        test_cast_conformance(&values);
    }

    #[test]
    fn test_cast_conformance_bool() {
        let flags = BoolArray::from_iter(vec![true, false, true, false]).into_array();
        test_cast_conformance(&flags);
    }

    #[test]
    fn test_cast_conformance_null() {
        let nulls = NullArray::new(5).into_array();
        test_cast_conformance(&nulls);
    }

    #[test]
    fn test_cast_conformance_utf8() {
        let strings = VarBinArray::from_iter(
            vec![Some("hello"), None, Some("world")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        test_cast_conformance(&strings);
    }

    #[test]
    fn test_cast_conformance_binary() {
        let blobs = VarBinArray::from_iter(
            vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        test_cast_conformance(&blobs);
    }

    #[test]
    fn test_cast_conformance_struct() {
        let field_names = FieldNames::from(["a", "b"]);
        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let record =
            StructArray::try_new(field_names, vec![ints, strings], 3, Validity::NonNullable)
                .unwrap()
                .into_array();
        test_cast_conformance(&record);
    }

    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let lists = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array();
        test_cast_conformance(&lists);
    }
}