vortex_array/compute/conformance/
cast.rs1use vortex_dtype::DType;
5use vortex_dtype::Nullability;
6use vortex_dtype::PType;
7use vortex_error::VortexExpect;
8use vortex_error::vortex_panic;
9use vortex_scalar::Scalar;
10
11use crate::Array;
12use crate::compute::MinMaxResult;
13use crate::compute::cast;
14use crate::compute::min_max;
15
16pub fn test_cast_conformance(array: &dyn Array) {
26 let dtype = array.dtype();
27
28 test_cast_identity(array);
30
31 test_cast_to_non_nullable(array);
32 test_cast_to_nullable(array);
33
34 match dtype {
36 DType::Null => test_cast_from_null(array),
37 DType::Primitive(ptype, ..) => match ptype {
38 PType::U8
39 | PType::U16
40 | PType::U32
41 | PType::U64
42 | PType::I8
43 | PType::I16
44 | PType::I32
45 | PType::I64 => test_cast_to_integral_types(array),
46 PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
47 },
48 _ => {}
49 }
50}
51
52fn test_cast_identity(array: &dyn Array) {
53 let result =
55 cast(array, array.dtype()).vortex_expect("cast should succeed in conformance test");
56 assert_eq!(result.len(), array.len());
57 assert_eq!(result.dtype(), array.dtype());
58
59 for i in 0..array.len().min(10) {
61 assert_eq!(array.scalar_at(i), result.scalar_at(i));
62 }
63}
64
65fn test_cast_from_null(array: &dyn Array) {
66 let result = cast(array, &DType::Null).vortex_expect("cast should succeed in conformance test");
68 assert_eq!(result.len(), array.len());
69 assert_eq!(result.dtype(), &DType::Null);
70
71 let nullable_types = vec![
73 DType::Bool(Nullability::Nullable),
74 DType::Primitive(PType::I32, Nullability::Nullable),
75 DType::Primitive(PType::F64, Nullability::Nullable),
76 DType::Utf8(Nullability::Nullable),
77 DType::Binary(Nullability::Nullable),
78 ];
79
80 for dtype in nullable_types {
81 let result = cast(array, &dtype).vortex_expect("cast should succeed in conformance test");
82 assert_eq!(result.len(), array.len());
83 assert_eq!(result.dtype(), &dtype);
84
85 for i in 0..array.len().min(10) {
87 assert!(result.scalar_at(i).is_null());
88 }
89 }
90
91 let non_nullable_types = vec![
93 DType::Bool(Nullability::NonNullable),
94 DType::Primitive(PType::I32, Nullability::NonNullable),
95 ];
96
97 for dtype in non_nullable_types {
98 assert!(cast(array, &dtype).is_err());
99 }
100}
101
102fn test_cast_to_non_nullable(array: &dyn Array) {
103 if array.invalid_count() == 0 {
104 let non_nullable = cast(array, &array.dtype().as_nonnullable())
105 .vortex_expect("arrays without nulls can cast to non-nullable");
106 assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
107 assert_eq!(non_nullable.len(), array.len());
108
109 for i in 0..array.len().min(10) {
110 assert_eq!(array.scalar_at(i), non_nullable.scalar_at(i));
111 }
112
113 let back_to_nullable = cast(&non_nullable, array.dtype())
114 .vortex_expect("non-nullable arrays can cast to nullable");
115 assert_eq!(back_to_nullable.dtype(), array.dtype());
116 assert_eq!(back_to_nullable.len(), array.len());
117
118 for i in 0..array.len().min(10) {
119 assert_eq!(array.scalar_at(i), back_to_nullable.scalar_at(i));
120 }
121 } else {
122 if &DType::Null == array.dtype() {
123 return;
126 }
127 cast(array, &array.dtype().as_nonnullable())
128 .err()
129 .unwrap_or_else(|| {
130 vortex_panic!(
131 "arrays with nulls should error when casting to non-nullable {}",
132 array,
133 )
134 });
135 }
136}
137
138fn test_cast_to_nullable(array: &dyn Array) {
139 let nullable = cast(array, &array.dtype().as_nullable())
140 .vortex_expect("arrays without nulls can cast to nullable");
141 assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
142 assert_eq!(nullable.len(), array.len());
143
144 for i in 0..array.len().min(10) {
145 assert_eq!(array.scalar_at(i), nullable.scalar_at(i));
146 }
147
148 let back = cast(&nullable, array.dtype())
149 .vortex_expect("casting to nullable and back should be a no-op");
150 assert_eq!(back.dtype(), array.dtype());
151 assert_eq!(back.len(), array.len());
152
153 for i in 0..array.len().min(10) {
154 assert_eq!(array.scalar_at(i), back.scalar_at(i));
155 }
156}
157
158fn test_cast_from_floating_point_types(array: &dyn Array) {
159 let ptype = array.as_primitive_typed().ptype();
160 test_cast_to_primitive(array, PType::I8, false);
161 test_cast_to_primitive(array, PType::U8, false);
162 test_cast_to_primitive(array, PType::I16, false);
163 test_cast_to_primitive(array, PType::U16, false);
164 test_cast_to_primitive(array, PType::I32, false);
165 test_cast_to_primitive(array, PType::U32, false);
166 test_cast_to_primitive(array, PType::I64, false);
167 test_cast_to_primitive(array, PType::U64, false);
168 test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
169 test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
170 test_cast_to_primitive(array, PType::F64, true);
171}
172
173fn test_cast_to_integral_types(array: &dyn Array) {
174 test_cast_to_primitive(array, PType::I8, true);
175 test_cast_to_primitive(array, PType::U8, true);
176 test_cast_to_primitive(array, PType::I16, true);
177 test_cast_to_primitive(array, PType::U16, true);
178 test_cast_to_primitive(array, PType::I32, true);
179 test_cast_to_primitive(array, PType::U32, true);
180 test_cast_to_primitive(array, PType::I64, true);
181 test_cast_to_primitive(array, PType::U64, true);
182}
183
184fn fits(value: &Scalar, ptype: PType) -> bool {
186 let dtype = DType::Primitive(ptype, value.dtype().nullability());
187 value.cast(&dtype).is_ok()
188}
189
190fn test_cast_to_primitive(array: &dyn Array, target_ptype: PType, test_round_trip: bool) {
191 let maybe_min_max = min_max(array).vortex_expect("cast should succeed in conformance test");
192
193 if let Some(MinMaxResult { min, max }) = maybe_min_max
194 && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
195 {
196 cast(
197 array,
198 &DType::Primitive(target_ptype, array.dtype().nullability()),
199 )
200 .err()
201 .unwrap_or_else(|| {
202 vortex_panic!(
203 "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
204 target_ptype,
205 min,
206 max,
207 array,
208 array.display_values(),
209 )
210 });
211 return;
212 }
213
214 let casted = cast(
216 array,
217 &DType::Primitive(target_ptype, array.dtype().nullability()),
218 )
219 .unwrap_or_else(|e| {
220 vortex_panic!(
221 "Cast must succeed because all values are within bounds. {} {}: {e}",
222 target_ptype,
223 array.display_values(),
224 )
225 });
226 assert_eq!(array.validity_mask(), casted.validity_mask());
227 for i in 0..array.len().min(10) {
228 let original = array.scalar_at(i);
229 let casted = casted.scalar_at(i);
230 assert_eq!(
231 original
232 .cast(casted.dtype())
233 .vortex_expect("cast should succeed in conformance test"),
234 casted,
235 "{i} {original} {casted}"
236 );
237 if test_round_trip {
238 assert_eq!(
239 original,
240 casted
241 .cast(original.dtype())
242 .vortex_expect("cast should succeed in conformance test"),
243 "{i} {original} {casted}"
244 );
245 }
246 }
247}
248
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::NullArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::validity::Validity;

    // Unsigned values spanning several integer widths, so some narrowing casts must fail.
    #[test]
    fn test_cast_conformance_u32() {
        let input = buffer![0u32, 100, 200, 65535, 1000000].into_array();
        test_cast_conformance(input.as_ref());
    }

    // Signed values around zero, so casts to unsigned targets must fail.
    #[test]
    fn test_cast_conformance_i32() {
        let input = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(input.as_ref());
    }

    #[test]
    fn test_cast_conformance_f32() {
        let input = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(input.as_ref());
    }

    // Contains nulls, which exercises the "must fail" path of the non-nullable check.
    #[test]
    fn test_cast_conformance_nullable() {
        let input =
            PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]);
        test_cast_conformance(input.as_ref());
    }

    #[test]
    fn test_cast_conformance_bool() {
        let input = BoolArray::from_iter(vec![true, false, true, false]);
        test_cast_conformance(input.as_ref());
    }

    #[test]
    fn test_cast_conformance_null() {
        let input = NullArray::new(5);
        test_cast_conformance(input.as_ref());
    }

    #[test]
    fn test_cast_conformance_utf8() {
        let input = VarBinArray::from_iter(
            vec![Some("hello"), None, Some("world")],
            DType::Utf8(Nullability::Nullable),
        );
        test_cast_conformance(input.as_ref());
    }

    #[test]
    fn test_cast_conformance_binary() {
        let input = VarBinArray::from_iter(
            vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())],
            DType::Binary(Nullability::Nullable),
        );
        test_cast_conformance(input.as_ref());
    }

    #[test]
    fn test_cast_conformance_struct() {
        let field_a = buffer![1i32, 2, 3].into_array();
        let field_b = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let input = StructArray::try_new(
            FieldNames::from(["a", "b"]),
            vec![field_a, field_b],
            3,
            Validity::NonNullable,
        )
        .unwrap();
        test_cast_conformance(input.as_ref());
    }

    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let input = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(input.as_ref());
    }
}