vortex_array/compute/conformance/
binary_numeric.rs1use itertools::Itertools;
27use num_traits::Num;
28use vortex_dtype::DType;
29use vortex_dtype::NativePType;
30use vortex_dtype::PType;
31use vortex_error::VortexExpect;
32use vortex_error::vortex_err;
33use vortex_error::vortex_panic;
34use vortex_scalar::NumericOperator;
35use vortex_scalar::PrimitiveScalar;
36use vortex_scalar::Scalar;
37
38use crate::Array;
39use crate::ArrayRef;
40use crate::IntoArray;
41use crate::ToCanonical;
42use crate::arrays::ConstantArray;
43use crate::compute::numeric::numeric;
44
45fn to_vec_of_scalar(array: &dyn Array) -> Vec<Scalar> {
46 (0..array.len())
48 .map(|index| array.scalar_at(index))
49 .collect_vec()
50}
51
52fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
77where
78 Scalar: From<T>,
79{
80 test_standard_binary_numeric::<T>(array.clone());
82
83 test_binary_numeric_edge_cases(array);
85}
86
87fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
88where
89 Scalar: From<T>,
90{
91 let canonicalized_array = array.to_primitive();
92 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
93
94 let one = T::from(1)
95 .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
96 .vortex_expect("operation should succeed in conformance test");
97 let scalar_one = Scalar::from(one)
98 .cast(array.dtype())
99 .vortex_expect("operation should succeed in conformance test");
100
101 let operators: [NumericOperator; 6] = [
102 NumericOperator::Add,
103 NumericOperator::Sub,
104 NumericOperator::RSub,
105 NumericOperator::Mul,
106 NumericOperator::Div,
107 NumericOperator::RDiv,
108 ];
109
110 for operator in operators {
111 let result = numeric(
113 &array,
114 &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
115 operator,
116 );
117
118 let Ok(result) = result else {
121 continue;
122 };
123
124 let actual_values = to_vec_of_scalar(&result);
125
126 let expected_results: Vec<Option<Scalar>> = original_values
128 .iter()
129 .map(|x| {
130 x.as_primitive()
131 .checked_binary_numeric(&scalar_one.as_primitive(), operator)
132 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
133 })
134 .collect();
135
136 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
138 if let Some(expected_value) = expected {
139 assert_eq!(
140 actual,
141 expected_value,
142 "Binary numeric operation failed for encoding {} at index {}: \
143 ({array:?})[{idx}] {operator:?} {scalar_one} \
144 expected {expected_value:?}, got {actual:?}",
145 array.encoding_id(),
146 idx,
147 );
148 }
149 }
150
151 let result = numeric(
153 &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
154 &array,
155 operator,
156 );
157
158 let Ok(result) = result else {
160 continue;
161 };
162
163 let actual_values = to_vec_of_scalar(&result);
164
165 let expected_results: Vec<Option<Scalar>> = original_values
167 .iter()
168 .map(|x| {
169 scalar_one
170 .as_primitive()
171 .checked_binary_numeric(&x.as_primitive(), operator)
172 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
173 })
174 .collect();
175
176 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
178 if let Some(expected_value) = expected {
179 assert_eq!(
180 actual,
181 expected_value,
182 "Binary numeric operation failed for encoding {} at index {}: \
183 {scalar_one} {operator:?} ({array:?})[{idx}] \
184 expected {expected_value:?}, got {actual:?}",
185 array.encoding_id(),
186 idx,
187 );
188 }
189 }
190 }
191}
192
193pub fn test_binary_numeric_array(array: ArrayRef) {
210 match array.dtype() {
211 DType::Primitive(ptype, _) => match ptype {
212 PType::I8 => test_binary_numeric_conformance::<i8>(array),
213 PType::I16 => test_binary_numeric_conformance::<i16>(array),
214 PType::I32 => test_binary_numeric_conformance::<i32>(array),
215 PType::I64 => test_binary_numeric_conformance::<i64>(array),
216 PType::U8 => test_binary_numeric_conformance::<u8>(array),
217 PType::U16 => test_binary_numeric_conformance::<u16>(array),
218 PType::U32 => test_binary_numeric_conformance::<u32>(array),
219 PType::U64 => test_binary_numeric_conformance::<u64>(array),
220 PType::F16 => {
221 eprintln!("Skipping f16 binary numeric tests (not supported)");
223 }
224 PType::F32 => test_binary_numeric_conformance::<f32>(array),
225 PType::F64 => test_binary_numeric_conformance::<f64>(array),
226 },
227 dtype => vortex_panic!(
228 "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
229 ),
230 }
231}
232
233fn test_binary_numeric_edge_cases(array: ArrayRef) {
241 match array.dtype() {
242 DType::Primitive(ptype, _) => match ptype {
243 PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
244 PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
245 PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
246 PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
247 PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
248 PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
249 PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
250 PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
251 PType::F16 => {
252 eprintln!("Skipping f16 edge case tests (not supported)");
253 }
254 PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
255 PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
256 },
257 dtype => vortex_panic!(
258 "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
259 ),
260 }
261}
262
263fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
264where
265 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
266 Scalar: From<T>,
267{
268 test_binary_numeric_with_scalar(array.clone(), T::zero());
270
271 test_binary_numeric_with_scalar(array.clone(), -T::one());
273
274 test_binary_numeric_with_scalar(array.clone(), T::max_value());
276
277 test_binary_numeric_with_scalar(array, T::min_value());
279}
280
281fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
282where
283 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
284 Scalar: From<T>,
285{
286 test_binary_numeric_with_scalar(array.clone(), T::zero());
288
289 test_binary_numeric_with_scalar(array, T::max_value());
291}
292
293fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
294where
295 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
296 Scalar: From<T>,
297{
298 test_binary_numeric_with_scalar(array.clone(), T::zero());
300
301 test_binary_numeric_with_scalar(array.clone(), -T::one());
303
304 test_binary_numeric_with_scalar(array.clone(), T::max_value());
306
307 test_binary_numeric_with_scalar(array.clone(), T::min_value());
309
310 test_binary_numeric_with_scalar(array.clone(), T::epsilon());
312
313 test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
315
316 test_binary_numeric_with_scalar(array.clone(), T::nan());
318 test_binary_numeric_with_scalar(array.clone(), T::infinity());
319 test_binary_numeric_with_scalar(array, T::neg_infinity());
320}
321
322fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
323where
324 T: NativePType + Num + Copy + std::fmt::Debug,
325 Scalar: From<T>,
326{
327 let canonicalized_array = array.to_primitive();
328 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
329
330 let scalar = Scalar::from(scalar_value)
331 .cast(array.dtype())
332 .vortex_expect("operation should succeed in conformance test");
333
334 let operators = if scalar_value == T::zero() {
336 vec![
338 NumericOperator::Add,
339 NumericOperator::Sub,
340 NumericOperator::RSub,
341 NumericOperator::Mul,
342 ]
343 } else {
344 vec![
345 NumericOperator::Add,
346 NumericOperator::Sub,
347 NumericOperator::RSub,
348 NumericOperator::Mul,
349 NumericOperator::Div,
350 NumericOperator::RDiv,
351 ]
352 };
353
354 for operator in operators {
355 let result = numeric(
357 &array,
358 &ConstantArray::new(scalar.clone(), array.len()).into_array(),
359 operator,
360 );
361
362 if result.is_err() {
364 continue;
365 }
366
367 let result = result.vortex_expect("operation should succeed in conformance test");
368 let actual_values = to_vec_of_scalar(&result);
369
370 let expected_results: Vec<Option<Scalar>> = original_values
372 .iter()
373 .map(|x| {
374 x.as_primitive()
375 .checked_binary_numeric(&scalar.as_primitive(), operator)
376 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
377 })
378 .collect();
379
380 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
382 if let Some(expected_value) = expected {
383 assert_eq!(
384 actual,
385 expected_value,
386 "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
387 ({array:?})[{idx}] {operator:?} {scalar} \
388 expected {expected_value:?}, got {actual:?}",
389 array.encoding_id(),
390 idx,
391 scalar_value,
392 );
393 }
394 }
395 }
396}