1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6use vortex_error::vortex_panic;
7
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::LEGACY_SESSION;
11use crate::RecursiveCanonical;
12use crate::VortexSessionExecute;
13use crate::aggregate_fn::fns::min_max::MinMaxResult;
14use crate::aggregate_fn::fns::min_max::min_max;
15use crate::builtins::ArrayBuiltins;
16use crate::dtype::DType;
17use crate::dtype::Nullability;
18use crate::dtype::PType;
19use crate::scalar::Scalar;
20
21fn cast_and_execute(array: &ArrayRef, dtype: DType) -> VortexResult<ArrayRef> {
23 Ok(array
24 .cast(dtype)?
25 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())?
26 .0
27 .into_array())
28}
29
30pub fn test_cast_conformance(array: &ArrayRef) {
40 let dtype = array.dtype();
41
42 test_cast_identity(array);
44
45 test_cast_to_non_nullable(array);
46 test_cast_to_nullable(array);
47
48 match dtype {
50 DType::Null => test_cast_from_null(array),
51 DType::Primitive(ptype, ..) => match ptype {
52 PType::U8
53 | PType::U16
54 | PType::U32
55 | PType::U64
56 | PType::I8
57 | PType::I16
58 | PType::I32
59 | PType::I64 => test_cast_to_integral_types(array),
60 PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
61 },
62 _ => {}
63 }
64}
65
66fn test_cast_identity(array: &ArrayRef) {
67 let result = cast_and_execute(&array.clone(), array.dtype().clone())
69 .vortex_expect("cast should succeed in conformance test");
70 assert_eq!(result.len(), array.len());
71 assert_eq!(result.dtype(), array.dtype());
72
73 for i in 0..array.len().min(10) {
75 assert_eq!(
76 array
77 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
78 .vortex_expect("scalar_at should succeed in conformance test"),
79 result
80 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
81 .vortex_expect("scalar_at should succeed in conformance test")
82 );
83 }
84}
85
86fn test_cast_from_null(array: &ArrayRef) {
87 let result = cast_and_execute(&array.clone(), DType::Null)
89 .vortex_expect("cast should succeed in conformance test");
90 assert_eq!(result.len(), array.len());
91 assert_eq!(result.dtype(), &DType::Null);
92
93 let nullable_types = vec![
95 DType::Bool(Nullability::Nullable),
96 DType::Primitive(PType::I32, Nullability::Nullable),
97 DType::Primitive(PType::F64, Nullability::Nullable),
98 DType::Utf8(Nullability::Nullable),
99 DType::Binary(Nullability::Nullable),
100 ];
101
102 for dtype in nullable_types {
103 let result = cast_and_execute(&array.clone(), dtype.clone())
104 .vortex_expect("cast should succeed in conformance test");
105 assert_eq!(result.len(), array.len());
106 assert_eq!(result.dtype(), &dtype);
107
108 for i in 0..array.len().min(10) {
110 assert!(
111 result
112 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
113 .vortex_expect("scalar_at should succeed in conformance test")
114 .is_null()
115 );
116 }
117 }
118
119 let non_nullable_types = vec![
121 DType::Bool(Nullability::NonNullable),
122 DType::Primitive(PType::I32, Nullability::NonNullable),
123 ];
124
125 for dtype in non_nullable_types {
126 assert!(cast_and_execute(&array.clone(), dtype.clone()).is_err());
127 }
128}
129
130fn test_cast_to_non_nullable(array: &ArrayRef) {
131 if array
132 .invalid_count(&mut LEGACY_SESSION.create_execution_ctx())
133 .vortex_expect("invalid_count should succeed in conformance test")
134 == 0
135 {
136 let non_nullable = cast_and_execute(&array.clone(), array.dtype().as_nonnullable())
137 .vortex_expect("arrays without nulls can cast to non-nullable");
138 assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
139 assert_eq!(non_nullable.len(), array.len());
140
141 for i in 0..array.len().min(10) {
142 assert_eq!(
143 array
144 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
145 .vortex_expect("scalar_at should succeed in conformance test"),
146 non_nullable
147 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
148 .vortex_expect("scalar_at should succeed in conformance test")
149 );
150 }
151
152 let back_to_nullable = cast_and_execute(&non_nullable, array.dtype().clone())
153 .vortex_expect("non-nullable arrays can cast to nullable");
154 assert_eq!(back_to_nullable.dtype(), array.dtype());
155 assert_eq!(back_to_nullable.len(), array.len());
156
157 for i in 0..array.len().min(10) {
158 assert_eq!(
159 array
160 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
161 .vortex_expect("scalar_at should succeed in conformance test"),
162 back_to_nullable
163 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
164 .vortex_expect("scalar_at should succeed in conformance test")
165 );
166 }
167 } else {
168 if &DType::Null == array.dtype() {
169 return;
172 }
173 cast_and_execute(&array.clone(), array.dtype().as_nonnullable())
174 .err()
175 .unwrap_or_else(|| {
176 vortex_panic!(
177 "arrays with nulls should error when casting to non-nullable {}",
178 array,
179 )
180 });
181 }
182}
183
184fn test_cast_to_nullable(array: &ArrayRef) {
185 let nullable = cast_and_execute(&array.clone(), array.dtype().as_nullable())
186 .vortex_expect("arrays without nulls can cast to nullable");
187 assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
188 assert_eq!(nullable.len(), array.len());
189
190 for i in 0..array.len().min(10) {
191 assert_eq!(
192 array
193 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
194 .vortex_expect("scalar_at should succeed in conformance test"),
195 nullable
196 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
197 .vortex_expect("scalar_at should succeed in conformance test")
198 );
199 }
200
201 let back = cast_and_execute(&nullable, array.dtype().clone())
202 .vortex_expect("casting to nullable and back should be a no-op");
203 assert_eq!(back.dtype(), array.dtype());
204 assert_eq!(back.len(), array.len());
205
206 for i in 0..array.len().min(10) {
207 assert_eq!(
208 array
209 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
210 .vortex_expect("scalar_at should succeed in conformance test"),
211 back.execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
212 .vortex_expect("scalar_at should succeed in conformance test")
213 );
214 }
215}
216
217fn test_cast_from_floating_point_types(array: &ArrayRef) {
218 let ptype = array.as_primitive_typed().ptype();
219 test_cast_to_primitive(array, PType::I8, false);
220 test_cast_to_primitive(array, PType::U8, false);
221 test_cast_to_primitive(array, PType::I16, false);
222 test_cast_to_primitive(array, PType::U16, false);
223 test_cast_to_primitive(array, PType::I32, false);
224 test_cast_to_primitive(array, PType::U32, false);
225 test_cast_to_primitive(array, PType::I64, false);
226 test_cast_to_primitive(array, PType::U64, false);
227 test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
228 test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
229 test_cast_to_primitive(array, PType::F64, true);
230}
231
232fn test_cast_to_integral_types(array: &ArrayRef) {
233 test_cast_to_primitive(array, PType::I8, true);
234 test_cast_to_primitive(array, PType::U8, true);
235 test_cast_to_primitive(array, PType::I16, true);
236 test_cast_to_primitive(array, PType::U16, true);
237 test_cast_to_primitive(array, PType::I32, true);
238 test_cast_to_primitive(array, PType::U32, true);
239 test_cast_to_primitive(array, PType::I64, true);
240 test_cast_to_primitive(array, PType::U64, true);
241}
242
243fn fits(value: &Scalar, ptype: PType) -> bool {
245 let dtype = DType::Primitive(ptype, value.dtype().nullability());
246 value.cast(&dtype).is_ok()
247}
248
249fn test_cast_to_primitive(array: &ArrayRef, target_ptype: PType, test_round_trip: bool) {
250 let mut ctx = LEGACY_SESSION.create_execution_ctx();
251 let maybe_min_max =
252 min_max(array, &mut ctx).vortex_expect("cast should succeed in conformance test");
253
254 if let Some(MinMaxResult { min, max }) = maybe_min_max
255 && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
256 {
257 cast_and_execute(
258 &array.clone(),
259 DType::Primitive(target_ptype, array.dtype().nullability()),
260 )
261 .err()
262 .unwrap_or_else(|| {
263 vortex_panic!(
264 "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
265 target_ptype,
266 min,
267 max,
268 array,
269 array.display_values(),
270 )
271 });
272 return;
273 }
274
275 let casted = cast_and_execute(
277 &array.clone(),
278 DType::Primitive(target_ptype, array.dtype().nullability()),
279 )
280 .unwrap_or_else(|e| {
281 vortex_panic!(
282 "Cast must succeed because all values are within bounds. {} {}: {e}",
283 target_ptype,
284 array.display_values(),
285 )
286 });
287 assert_eq!(
288 array
289 .validity()
290 .vortex_expect("validity_mask should succeed in conformance test")
291 .to_mask(array.len(), &mut LEGACY_SESSION.create_execution_ctx())
292 .vortex_expect("Failed to compute validity mask"),
293 casted
294 .validity()
295 .vortex_expect("validity_mask should succeed in conformance test")
296 .to_mask(casted.len(), &mut LEGACY_SESSION.create_execution_ctx())
297 .vortex_expect("Failed to compute validity mask")
298 );
299 for i in 0..array.len().min(10) {
300 let original = array
301 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
302 .vortex_expect("scalar_at should succeed in conformance test");
303 let casted = casted
304 .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
305 .vortex_expect("scalar_at should succeed in conformance test");
306 assert_eq!(
307 original
308 .cast(casted.dtype())
309 .vortex_expect("cast should succeed in conformance test"),
310 casted,
311 "{i} {original} {casted}"
312 );
313 if test_round_trip {
314 assert_eq!(
315 original,
316 casted
317 .cast(original.dtype())
318 .vortex_expect("cast should succeed in conformance test"),
319 "{i} {original} {casted}"
320 );
321 }
322 }
323}
324
#[cfg(test)]
mod tests {
    //! Runs the cast conformance suite over one representative array per dtype family.

    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::NullArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    #[test]
    fn test_cast_conformance_u32() {
        test_cast_conformance(&buffer![0u32, 100, 200, 65535, 1000000].into_array());
    }

    #[test]
    fn test_cast_conformance_i32() {
        test_cast_conformance(&buffer![-100i32, -1, 0, 1, 100].into_array());
    }

    #[test]
    fn test_cast_conformance_f32() {
        test_cast_conformance(&buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array());
    }

    #[test]
    fn test_cast_conformance_nullable() {
        let with_nulls =
            PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None])
                .into_array();
        test_cast_conformance(&with_nulls);
    }

    #[test]
    fn test_cast_conformance_bool() {
        let bools = BoolArray::from_iter(vec![true, false, true, false]).into_array();
        test_cast_conformance(&bools);
    }

    #[test]
    fn test_cast_conformance_null() {
        test_cast_conformance(&NullArray::new(5).into_array());
    }

    #[test]
    fn test_cast_conformance_utf8() {
        let strings = VarBinArray::from_iter(
            vec![Some("hello"), None, Some("world")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        test_cast_conformance(&strings);
    }

    #[test]
    fn test_cast_conformance_binary() {
        let blobs = VarBinArray::from_iter(
            vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        test_cast_conformance(&blobs);
    }

    #[test]
    fn test_cast_conformance_struct() {
        let int_field = buffer![1i32, 2, 3].into_array();
        let str_field = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let record = StructArray::try_new(
            FieldNames::from(["a", "b"]),
            vec![int_field, str_field],
            3,
            Validity::NonNullable,
        )
        .unwrap();
        test_cast_conformance(&record.into_array());
    }

    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let list = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(&list.into_array());
    }
}