vortex_array/compute/conformance/
binary_numeric.rs1use itertools::Itertools;
27use num_traits::Num;
28use vortex_dtype::{DType, NativePType, PType};
29use vortex_error::{VortexUnwrap, vortex_err, vortex_panic};
30use vortex_scalar::{NumericOperator, PrimitiveScalar, Scalar};
31
32use crate::arrays::ConstantArray;
33use crate::compute::numeric::numeric;
34use crate::{Array, ArrayRef, IntoArray, ToCanonical};
35
36fn to_vec_of_scalar(array: &dyn Array) -> Vec<Scalar> {
37 (0..array.len())
39 .map(|index| array.scalar_at(index))
40 .collect_vec()
41}
42
43fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
68where
69 Scalar: From<T>,
70{
71 test_standard_binary_numeric::<T>(array.clone());
73
74 test_binary_numeric_edge_cases(array);
76}
77
78fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
79where
80 Scalar: From<T>,
81{
82 let canonicalized_array = array.to_primitive();
83 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
84
85 let one = T::from(1)
86 .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
87 .vortex_unwrap();
88 let scalar_one = Scalar::from(one).cast(array.dtype()).vortex_unwrap();
89
90 let operators: [NumericOperator; 6] = [
91 NumericOperator::Add,
92 NumericOperator::Sub,
93 NumericOperator::RSub,
94 NumericOperator::Mul,
95 NumericOperator::Div,
96 NumericOperator::RDiv,
97 ];
98
99 for operator in operators {
100 let result = numeric(
102 &array,
103 &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
104 operator,
105 );
106
107 let Ok(result) = result else {
110 continue;
111 };
112
113 let actual_values = to_vec_of_scalar(&result);
114
115 let expected_results: Vec<Option<Scalar>> = original_values
117 .iter()
118 .map(|x| {
119 x.as_primitive()
120 .checked_binary_numeric(&scalar_one.as_primitive(), operator)
121 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
122 })
123 .collect();
124
125 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
127 if let Some(expected_value) = expected {
128 assert_eq!(
129 actual,
130 expected_value,
131 "Binary numeric operation failed for encoding {} at index {}: \
132 ({array:?})[{idx}] {operator:?} {scalar_one} \
133 expected {expected_value:?}, got {actual:?}",
134 array.encoding_id(),
135 idx,
136 );
137 }
138 }
139
140 let result = numeric(
142 &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
143 &array,
144 operator,
145 );
146
147 let Ok(result) = result else {
149 continue;
150 };
151
152 let actual_values = to_vec_of_scalar(&result);
153
154 let expected_results: Vec<Option<Scalar>> = original_values
156 .iter()
157 .map(|x| {
158 scalar_one
159 .as_primitive()
160 .checked_binary_numeric(&x.as_primitive(), operator)
161 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
162 })
163 .collect();
164
165 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
167 if let Some(expected_value) = expected {
168 assert_eq!(
169 actual,
170 expected_value,
171 "Binary numeric operation failed for encoding {} at index {}: \
172 {scalar_one} {operator:?} ({array:?})[{idx}] \
173 expected {expected_value:?}, got {actual:?}",
174 array.encoding_id(),
175 idx,
176 );
177 }
178 }
179 }
180}
181
182pub fn test_binary_numeric_array(array: ArrayRef) {
199 match array.dtype() {
200 DType::Primitive(ptype, _) => match ptype {
201 PType::I8 => test_binary_numeric_conformance::<i8>(array),
202 PType::I16 => test_binary_numeric_conformance::<i16>(array),
203 PType::I32 => test_binary_numeric_conformance::<i32>(array),
204 PType::I64 => test_binary_numeric_conformance::<i64>(array),
205 PType::U8 => test_binary_numeric_conformance::<u8>(array),
206 PType::U16 => test_binary_numeric_conformance::<u16>(array),
207 PType::U32 => test_binary_numeric_conformance::<u32>(array),
208 PType::U64 => test_binary_numeric_conformance::<u64>(array),
209 PType::F16 => {
210 eprintln!("Skipping f16 binary numeric tests (not supported)");
212 }
213 PType::F32 => test_binary_numeric_conformance::<f32>(array),
214 PType::F64 => test_binary_numeric_conformance::<f64>(array),
215 },
216 dtype => vortex_panic!(
217 "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
218 ),
219 }
220}
221
222fn test_binary_numeric_edge_cases(array: ArrayRef) {
230 match array.dtype() {
231 DType::Primitive(ptype, _) => match ptype {
232 PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
233 PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
234 PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
235 PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
236 PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
237 PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
238 PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
239 PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
240 PType::F16 => {
241 eprintln!("Skipping f16 edge case tests (not supported)");
242 }
243 PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
244 PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
245 },
246 dtype => vortex_panic!(
247 "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
248 ),
249 }
250}
251
252fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
253where
254 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
255 Scalar: From<T>,
256{
257 test_binary_numeric_with_scalar(array.clone(), T::zero());
259
260 test_binary_numeric_with_scalar(array.clone(), -T::one());
262
263 test_binary_numeric_with_scalar(array.clone(), T::max_value());
265
266 test_binary_numeric_with_scalar(array, T::min_value());
268}
269
270fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
271where
272 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
273 Scalar: From<T>,
274{
275 test_binary_numeric_with_scalar(array.clone(), T::zero());
277
278 test_binary_numeric_with_scalar(array, T::max_value());
280}
281
282fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
283where
284 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
285 Scalar: From<T>,
286{
287 test_binary_numeric_with_scalar(array.clone(), T::zero());
289
290 test_binary_numeric_with_scalar(array.clone(), -T::one());
292
293 test_binary_numeric_with_scalar(array.clone(), T::max_value());
295
296 test_binary_numeric_with_scalar(array.clone(), T::min_value());
298
299 test_binary_numeric_with_scalar(array.clone(), T::epsilon());
301
302 test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
304
305 test_binary_numeric_with_scalar(array.clone(), T::nan());
307 test_binary_numeric_with_scalar(array.clone(), T::infinity());
308 test_binary_numeric_with_scalar(array, T::neg_infinity());
309}
310
311fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
312where
313 T: NativePType + Num + Copy + std::fmt::Debug,
314 Scalar: From<T>,
315{
316 let canonicalized_array = array.to_primitive();
317 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
318
319 let scalar = Scalar::from(scalar_value)
320 .cast(array.dtype())
321 .vortex_unwrap();
322
323 let operators = if scalar_value == T::zero() {
325 vec![
327 NumericOperator::Add,
328 NumericOperator::Sub,
329 NumericOperator::RSub,
330 NumericOperator::Mul,
331 ]
332 } else {
333 vec![
334 NumericOperator::Add,
335 NumericOperator::Sub,
336 NumericOperator::RSub,
337 NumericOperator::Mul,
338 NumericOperator::Div,
339 NumericOperator::RDiv,
340 ]
341 };
342
343 for operator in operators {
344 let result = numeric(
346 &array,
347 &ConstantArray::new(scalar.clone(), array.len()).into_array(),
348 operator,
349 );
350
351 if result.is_err() {
353 continue;
354 }
355
356 let result = result.vortex_unwrap();
357 let actual_values = to_vec_of_scalar(&result);
358
359 let expected_results: Vec<Option<Scalar>> = original_values
361 .iter()
362 .map(|x| {
363 x.as_primitive()
364 .checked_binary_numeric(&scalar.as_primitive(), operator)
365 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
366 })
367 .collect();
368
369 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
371 if let Some(expected_value) = expected {
372 assert_eq!(
373 actual,
374 expected_value,
375 "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
376 ({array:?})[{idx}] {operator:?} {scalar} \
377 expected {expected_value:?}, got {actual:?}",
378 array.encoding_id(),
379 idx,
380 scalar_value,
381 );
382 }
383 }
384 }
385}