vortex_array/compute/conformance/
binary_numeric.rs1use itertools::Itertools;
27use num_traits::Num;
28use vortex_error::VortexExpect;
29use vortex_error::vortex_err;
30use vortex_error::vortex_panic;
31
32use crate::Array;
33use crate::ArrayRef;
34use crate::IntoArray;
35use crate::LEGACY_SESSION;
36use crate::RecursiveCanonical;
37use crate::ToCanonical;
38use crate::VortexSessionExecute;
39use crate::arrays::ConstantArray;
40use crate::builtins::ArrayBuiltins;
41use crate::dtype::DType;
42use crate::dtype::NativePType;
43use crate::dtype::PType;
44use crate::scalar::NumericOperator;
45use crate::scalar::PrimitiveScalar;
46use crate::scalar::Scalar;
47
48fn to_vec_of_scalar(array: &dyn Array) -> Vec<Scalar> {
49 (0..array.len())
51 .map(|index| {
52 array
53 .scalar_at(index)
54 .vortex_expect("scalar_at should succeed in conformance test")
55 })
56 .collect_vec()
57}
58
59fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
84where
85 Scalar: From<T>,
86{
87 test_standard_binary_numeric::<T>(array.clone());
89
90 test_binary_numeric_edge_cases(array);
92}
93
94fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
95where
96 Scalar: From<T>,
97{
98 let canonicalized_array = array.to_primitive();
99 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
100
101 let one = T::from(1)
102 .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
103 .vortex_expect("operation should succeed in conformance test");
104 let scalar_one = Scalar::from(one)
105 .cast(array.dtype())
106 .vortex_expect("operation should succeed in conformance test");
107
108 let operators: [NumericOperator; 4] = [
109 NumericOperator::Add,
110 NumericOperator::Sub,
111 NumericOperator::Mul,
112 NumericOperator::Div,
113 ];
114
115 for operator in operators {
116 let op = operator.into();
117 let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array();
118
119 let result = array
121 .binary(rhs_const.clone(), op)
122 .vortex_expect("apply shouldn't fail")
123 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
124 .map(|c| c.0.into_array());
125
126 let Ok(result) = result else {
129 continue;
130 };
131
132 println!("result {}", result.display_tree());
133 println!("result {}", result.display_values());
134
135 let actual_values = to_vec_of_scalar(&result);
136
137 let expected_results: Vec<Option<Scalar>> = original_values
139 .iter()
140 .map(|x| {
141 x.as_primitive()
142 .checked_binary_numeric(&scalar_one.as_primitive(), operator)
143 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
144 })
145 .collect();
146
147 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
149 if let Some(expected_value) = expected {
150 assert_eq!(
151 actual,
152 expected_value,
153 "Binary numeric operation failed for encoding {} at index {}: \
154 ({array:?})[{idx}] {operator:?} {scalar_one} \
155 expected {expected_value:?}, got {actual:?}",
156 array.encoding_id(),
157 idx,
158 );
159 }
160 }
161
162 let result = rhs_const.binary(array.clone(), op).and_then(|a| {
164 a.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
165 .map(|c| c.0.into_array())
166 });
167
168 let Ok(result) = result else {
170 continue;
171 };
172
173 let actual_values = to_vec_of_scalar(&result);
174
175 let expected_results: Vec<Option<Scalar>> = original_values
177 .iter()
178 .map(|x| {
179 scalar_one
180 .as_primitive()
181 .checked_binary_numeric(&x.as_primitive(), operator)
182 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
183 })
184 .collect();
185
186 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
188 if let Some(expected_value) = expected {
189 assert_eq!(
190 actual,
191 expected_value,
192 "Binary numeric operation failed for encoding {} at index {}: \
193 {scalar_one} {operator:?} ({array:?})[{idx}] \
194 expected {expected_value:?}, got {actual:?}",
195 array.encoding_id(),
196 idx,
197 );
198 }
199 }
200 }
201}
202
203pub fn test_binary_numeric_array(array: ArrayRef) {
220 match array.dtype() {
221 DType::Primitive(ptype, _) => match ptype {
222 PType::I8 => test_binary_numeric_conformance::<i8>(array),
223 PType::I16 => test_binary_numeric_conformance::<i16>(array),
224 PType::I32 => test_binary_numeric_conformance::<i32>(array),
225 PType::I64 => test_binary_numeric_conformance::<i64>(array),
226 PType::U8 => test_binary_numeric_conformance::<u8>(array),
227 PType::U16 => test_binary_numeric_conformance::<u16>(array),
228 PType::U32 => test_binary_numeric_conformance::<u32>(array),
229 PType::U64 => test_binary_numeric_conformance::<u64>(array),
230 PType::F16 => {
231 eprintln!("Skipping f16 binary numeric tests (not supported)");
233 }
234 PType::F32 => test_binary_numeric_conformance::<f32>(array),
235 PType::F64 => test_binary_numeric_conformance::<f64>(array),
236 },
237 dtype => vortex_panic!(
238 "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
239 ),
240 }
241}
242
243fn test_binary_numeric_edge_cases(array: ArrayRef) {
251 match array.dtype() {
252 DType::Primitive(ptype, _) => match ptype {
253 PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
254 PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
255 PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
256 PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
257 PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
258 PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
259 PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
260 PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
261 PType::F16 => {
262 eprintln!("Skipping f16 edge case tests (not supported)");
263 }
264 PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
265 PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
266 },
267 dtype => vortex_panic!(
268 "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
269 ),
270 }
271}
272
273fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
274where
275 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
276 Scalar: From<T>,
277{
278 test_binary_numeric_with_scalar(array.clone(), T::zero());
280
281 test_binary_numeric_with_scalar(array.clone(), -T::one());
283
284 test_binary_numeric_with_scalar(array.clone(), T::max_value());
286
287 test_binary_numeric_with_scalar(array, T::min_value());
289}
290
291fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
292where
293 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
294 Scalar: From<T>,
295{
296 test_binary_numeric_with_scalar(array.clone(), T::zero());
298
299 test_binary_numeric_with_scalar(array, T::max_value());
301}
302
303fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
304where
305 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
306 Scalar: From<T>,
307{
308 test_binary_numeric_with_scalar(array.clone(), T::zero());
310
311 test_binary_numeric_with_scalar(array.clone(), -T::one());
313
314 test_binary_numeric_with_scalar(array.clone(), T::max_value());
316
317 test_binary_numeric_with_scalar(array.clone(), T::min_value());
319
320 test_binary_numeric_with_scalar(array.clone(), T::epsilon());
322
323 test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
325
326 test_binary_numeric_with_scalar(array.clone(), T::nan());
328 test_binary_numeric_with_scalar(array.clone(), T::infinity());
329 test_binary_numeric_with_scalar(array, T::neg_infinity());
330}
331
332fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
333where
334 T: NativePType + Num + Copy + std::fmt::Debug,
335 Scalar: From<T>,
336{
337 let canonicalized_array = array.to_primitive();
338 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
339
340 let scalar = Scalar::from(scalar_value)
341 .cast(array.dtype())
342 .vortex_expect("operation should succeed in conformance test");
343
344 let operators = if scalar_value == T::zero() {
346 vec![
348 NumericOperator::Add,
349 NumericOperator::Sub,
350 NumericOperator::Mul,
351 ]
352 } else {
353 vec![
354 NumericOperator::Add,
355 NumericOperator::Sub,
356 NumericOperator::Mul,
357 NumericOperator::Div,
358 ]
359 };
360
361 for operator in operators {
362 let op = operator.into();
363 let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array();
364
365 let result = array
367 .binary(rhs_const, op)
368 .vortex_expect("apply failed")
369 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
370 .map(|x| x.0.into_array());
371
372 if result.is_err() {
375 continue;
376 }
377
378 let result = result.vortex_expect("operation should succeed in conformance test");
379 let actual_values = to_vec_of_scalar(&result);
380
381 let expected_results: Vec<Option<Scalar>> = original_values
383 .iter()
384 .map(|x| {
385 x.as_primitive()
386 .checked_binary_numeric(&scalar.as_primitive(), operator)
387 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
388 })
389 .collect();
390
391 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
393 if let Some(expected_value) = expected {
394 assert_eq!(
395 actual,
396 expected_value,
397 "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
398 ({array:?})[{idx}] {operator:?} {scalar} \
399 expected {expected_value:?}, got {actual:?}",
400 array.encoding_id(),
401 idx,
402 scalar_value,
403 );
404 }
405 }
406 }
407}