vortex_array/compute/conformance/
binary_numeric.rs1use itertools::Itertools;
27use num_traits::Num;
28use vortex_dtype::DType;
29use vortex_dtype::NativePType;
30use vortex_dtype::PType;
31use vortex_error::VortexUnwrap;
32use vortex_error::vortex_err;
33use vortex_error::vortex_panic;
34use vortex_scalar::NumericOperator;
35use vortex_scalar::PrimitiveScalar;
36use vortex_scalar::Scalar;
37
38use crate::Array;
39use crate::ArrayRef;
40use crate::IntoArray;
41use crate::ToCanonical;
42use crate::arrays::ConstantArray;
43use crate::compute::numeric::numeric;
44
45fn to_vec_of_scalar(array: &dyn Array) -> Vec<Scalar> {
46 (0..array.len())
48 .map(|index| array.scalar_at(index))
49 .collect_vec()
50}
51
52fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
77where
78 Scalar: From<T>,
79{
80 test_standard_binary_numeric::<T>(array.clone());
82
83 test_binary_numeric_edge_cases(array);
85}
86
87fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
88where
89 Scalar: From<T>,
90{
91 let canonicalized_array = array.to_primitive();
92 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
93
94 let one = T::from(1)
95 .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
96 .vortex_unwrap();
97 let scalar_one = Scalar::from(one).cast(array.dtype()).vortex_unwrap();
98
99 let operators: [NumericOperator; 6] = [
100 NumericOperator::Add,
101 NumericOperator::Sub,
102 NumericOperator::RSub,
103 NumericOperator::Mul,
104 NumericOperator::Div,
105 NumericOperator::RDiv,
106 ];
107
108 for operator in operators {
109 let result = numeric(
111 &array,
112 &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
113 operator,
114 );
115
116 let Ok(result) = result else {
119 continue;
120 };
121
122 let actual_values = to_vec_of_scalar(&result);
123
124 let expected_results: Vec<Option<Scalar>> = original_values
126 .iter()
127 .map(|x| {
128 x.as_primitive()
129 .checked_binary_numeric(&scalar_one.as_primitive(), operator)
130 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
131 })
132 .collect();
133
134 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
136 if let Some(expected_value) = expected {
137 assert_eq!(
138 actual,
139 expected_value,
140 "Binary numeric operation failed for encoding {} at index {}: \
141 ({array:?})[{idx}] {operator:?} {scalar_one} \
142 expected {expected_value:?}, got {actual:?}",
143 array.encoding_id(),
144 idx,
145 );
146 }
147 }
148
149 let result = numeric(
151 &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
152 &array,
153 operator,
154 );
155
156 let Ok(result) = result else {
158 continue;
159 };
160
161 let actual_values = to_vec_of_scalar(&result);
162
163 let expected_results: Vec<Option<Scalar>> = original_values
165 .iter()
166 .map(|x| {
167 scalar_one
168 .as_primitive()
169 .checked_binary_numeric(&x.as_primitive(), operator)
170 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
171 })
172 .collect();
173
174 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
176 if let Some(expected_value) = expected {
177 assert_eq!(
178 actual,
179 expected_value,
180 "Binary numeric operation failed for encoding {} at index {}: \
181 {scalar_one} {operator:?} ({array:?})[{idx}] \
182 expected {expected_value:?}, got {actual:?}",
183 array.encoding_id(),
184 idx,
185 );
186 }
187 }
188 }
189}
190
191pub fn test_binary_numeric_array(array: ArrayRef) {
208 match array.dtype() {
209 DType::Primitive(ptype, _) => match ptype {
210 PType::I8 => test_binary_numeric_conformance::<i8>(array),
211 PType::I16 => test_binary_numeric_conformance::<i16>(array),
212 PType::I32 => test_binary_numeric_conformance::<i32>(array),
213 PType::I64 => test_binary_numeric_conformance::<i64>(array),
214 PType::U8 => test_binary_numeric_conformance::<u8>(array),
215 PType::U16 => test_binary_numeric_conformance::<u16>(array),
216 PType::U32 => test_binary_numeric_conformance::<u32>(array),
217 PType::U64 => test_binary_numeric_conformance::<u64>(array),
218 PType::F16 => {
219 eprintln!("Skipping f16 binary numeric tests (not supported)");
221 }
222 PType::F32 => test_binary_numeric_conformance::<f32>(array),
223 PType::F64 => test_binary_numeric_conformance::<f64>(array),
224 },
225 dtype => vortex_panic!(
226 "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
227 ),
228 }
229}
230
231fn test_binary_numeric_edge_cases(array: ArrayRef) {
239 match array.dtype() {
240 DType::Primitive(ptype, _) => match ptype {
241 PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
242 PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
243 PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
244 PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
245 PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
246 PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
247 PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
248 PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
249 PType::F16 => {
250 eprintln!("Skipping f16 edge case tests (not supported)");
251 }
252 PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
253 PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
254 },
255 dtype => vortex_panic!(
256 "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
257 ),
258 }
259}
260
261fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
262where
263 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
264 Scalar: From<T>,
265{
266 test_binary_numeric_with_scalar(array.clone(), T::zero());
268
269 test_binary_numeric_with_scalar(array.clone(), -T::one());
271
272 test_binary_numeric_with_scalar(array.clone(), T::max_value());
274
275 test_binary_numeric_with_scalar(array, T::min_value());
277}
278
279fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
280where
281 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
282 Scalar: From<T>,
283{
284 test_binary_numeric_with_scalar(array.clone(), T::zero());
286
287 test_binary_numeric_with_scalar(array, T::max_value());
289}
290
291fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
292where
293 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
294 Scalar: From<T>,
295{
296 test_binary_numeric_with_scalar(array.clone(), T::zero());
298
299 test_binary_numeric_with_scalar(array.clone(), -T::one());
301
302 test_binary_numeric_with_scalar(array.clone(), T::max_value());
304
305 test_binary_numeric_with_scalar(array.clone(), T::min_value());
307
308 test_binary_numeric_with_scalar(array.clone(), T::epsilon());
310
311 test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
313
314 test_binary_numeric_with_scalar(array.clone(), T::nan());
316 test_binary_numeric_with_scalar(array.clone(), T::infinity());
317 test_binary_numeric_with_scalar(array, T::neg_infinity());
318}
319
320fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
321where
322 T: NativePType + Num + Copy + std::fmt::Debug,
323 Scalar: From<T>,
324{
325 let canonicalized_array = array.to_primitive();
326 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
327
328 let scalar = Scalar::from(scalar_value)
329 .cast(array.dtype())
330 .vortex_unwrap();
331
332 let operators = if scalar_value == T::zero() {
334 vec![
336 NumericOperator::Add,
337 NumericOperator::Sub,
338 NumericOperator::RSub,
339 NumericOperator::Mul,
340 ]
341 } else {
342 vec![
343 NumericOperator::Add,
344 NumericOperator::Sub,
345 NumericOperator::RSub,
346 NumericOperator::Mul,
347 NumericOperator::Div,
348 NumericOperator::RDiv,
349 ]
350 };
351
352 for operator in operators {
353 let result = numeric(
355 &array,
356 &ConstantArray::new(scalar.clone(), array.len()).into_array(),
357 operator,
358 );
359
360 if result.is_err() {
362 continue;
363 }
364
365 let result = result.vortex_unwrap();
366 let actual_values = to_vec_of_scalar(&result);
367
368 let expected_results: Vec<Option<Scalar>> = original_values
370 .iter()
371 .map(|x| {
372 x.as_primitive()
373 .checked_binary_numeric(&scalar.as_primitive(), operator)
374 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
375 })
376 .collect();
377
378 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
380 if let Some(expected_value) = expected {
381 assert_eq!(
382 actual,
383 expected_value,
384 "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
385 ({array:?})[{idx}] {operator:?} {scalar} \
386 expected {expected_value:?}, got {actual:?}",
387 array.encoding_id(),
388 idx,
389 scalar_value,
390 );
391 }
392 }
393 }
394}