vortex_array/compute/conformance/
binary_numeric.rs1use itertools::Itertools;
25use num_traits::Num;
26use vortex_error::VortexExpect;
27use vortex_error::vortex_err;
28use vortex_error::vortex_panic;
29
30use crate::ArrayRef;
31use crate::IntoArray;
32use crate::LEGACY_SESSION;
33use crate::RecursiveCanonical;
34use crate::ToCanonical;
35use crate::VortexSessionExecute;
36use crate::arrays::ConstantArray;
37use crate::builtins::ArrayBuiltins;
38use crate::dtype::DType;
39use crate::dtype::NativePType;
40use crate::dtype::PType;
41use crate::scalar::NumericOperator;
42use crate::scalar::PrimitiveScalar;
43use crate::scalar::Scalar;
44
45fn to_vec_of_scalar(array: &ArrayRef) -> Vec<Scalar> {
46 (0..array.len())
48 .map(|index| {
49 array
50 .scalar_at(index)
51 .vortex_expect("scalar_at should succeed in conformance test")
52 })
53 .collect_vec()
54}
55
56fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
81where
82 Scalar: From<T>,
83{
84 test_standard_binary_numeric::<T>(array.clone());
86
87 test_binary_numeric_edge_cases(array);
89}
90
91fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
92where
93 Scalar: From<T>,
94{
95 let canonicalized_array = array.to_primitive();
96 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
97
98 let one = T::from(1)
99 .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
100 .vortex_expect("operation should succeed in conformance test");
101 let scalar_one = Scalar::from(one)
102 .cast(array.dtype())
103 .vortex_expect("operation should succeed in conformance test");
104
105 let operators: [NumericOperator; 4] = [
106 NumericOperator::Add,
107 NumericOperator::Sub,
108 NumericOperator::Mul,
109 NumericOperator::Div,
110 ];
111
112 for operator in operators {
113 let op = operator;
114 let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array();
115
116 let result = array
118 .binary(rhs_const.clone(), op.into())
119 .vortex_expect("apply shouldn't fail")
120 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
121 .map(|c| c.0.into_array());
122
123 let Ok(result) = result else {
126 continue;
127 };
128
129 let actual_values = to_vec_of_scalar(&result);
130
131 let expected_results: Vec<Option<Scalar>> = original_values
133 .iter()
134 .map(|x| {
135 x.as_primitive()
136 .checked_binary_numeric(&scalar_one.as_primitive(), op)
137 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
138 })
139 .collect();
140
141 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
143 if let Some(expected_value) = expected {
144 assert_eq!(
145 actual,
146 expected_value,
147 "Binary numeric operation failed for encoding {} at index {}: \
148 ({array:?})[{idx}] {operator:?} {scalar_one} \
149 expected {expected_value:?}, got {actual:?}",
150 array.encoding_id(),
151 idx,
152 );
153 }
154 }
155
156 let result = rhs_const.binary(array.clone(), op.into()).and_then(|a| {
158 a.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
159 .map(|c| c.0.into_array())
160 });
161
162 let Ok(result) = result else {
164 continue;
165 };
166
167 let actual_values = to_vec_of_scalar(&result);
168
169 let expected_results: Vec<Option<Scalar>> = original_values
171 .iter()
172 .map(|x| {
173 scalar_one
174 .as_primitive()
175 .checked_binary_numeric(&x.as_primitive(), op)
176 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
177 })
178 .collect();
179
180 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
182 if let Some(expected_value) = expected {
183 assert_eq!(
184 actual,
185 expected_value,
186 "Binary numeric operation failed for encoding {} at index {}: \
187 {scalar_one} {operator:?} ({array:?})[{idx}] \
188 expected {expected_value:?}, got {actual:?}",
189 array.encoding_id(),
190 idx,
191 );
192 }
193 }
194 }
195}
196
197pub fn test_binary_numeric_array(array: ArrayRef) {
214 match array.dtype() {
215 DType::Primitive(ptype, _) => match ptype {
216 PType::I8 => test_binary_numeric_conformance::<i8>(array),
217 PType::I16 => test_binary_numeric_conformance::<i16>(array),
218 PType::I32 => test_binary_numeric_conformance::<i32>(array),
219 PType::I64 => test_binary_numeric_conformance::<i64>(array),
220 PType::U8 => test_binary_numeric_conformance::<u8>(array),
221 PType::U16 => test_binary_numeric_conformance::<u16>(array),
222 PType::U32 => test_binary_numeric_conformance::<u32>(array),
223 PType::U64 => test_binary_numeric_conformance::<u64>(array),
224 PType::F16 => {
225 eprintln!("Skipping f16 binary numeric tests (not supported)");
227 }
228 PType::F32 => test_binary_numeric_conformance::<f32>(array),
229 PType::F64 => test_binary_numeric_conformance::<f64>(array),
230 },
231 dtype => vortex_panic!(
232 "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
233 ),
234 }
235}
236
237fn test_binary_numeric_edge_cases(array: ArrayRef) {
245 match array.dtype() {
246 DType::Primitive(ptype, _) => match ptype {
247 PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
248 PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
249 PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
250 PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
251 PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
252 PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
253 PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
254 PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
255 PType::F16 => {
256 eprintln!("Skipping f16 edge case tests (not supported)");
257 }
258 PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
259 PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
260 },
261 dtype => vortex_panic!(
262 "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
263 ),
264 }
265}
266
267fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
268where
269 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
270 Scalar: From<T>,
271{
272 test_binary_numeric_with_scalar(array.clone(), T::zero());
274
275 test_binary_numeric_with_scalar(array.clone(), -T::one());
277
278 test_binary_numeric_with_scalar(array.clone(), T::max_value());
280
281 test_binary_numeric_with_scalar(array, T::min_value());
283}
284
285fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
286where
287 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
288 Scalar: From<T>,
289{
290 test_binary_numeric_with_scalar(array.clone(), T::zero());
292
293 test_binary_numeric_with_scalar(array, T::max_value());
295}
296
297fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
298where
299 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
300 Scalar: From<T>,
301{
302 test_binary_numeric_with_scalar(array.clone(), T::zero());
304
305 test_binary_numeric_with_scalar(array.clone(), -T::one());
307
308 test_binary_numeric_with_scalar(array.clone(), T::max_value());
310
311 test_binary_numeric_with_scalar(array.clone(), T::min_value());
313
314 test_binary_numeric_with_scalar(array.clone(), T::epsilon());
316
317 test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
319
320 test_binary_numeric_with_scalar(array.clone(), T::nan());
322 test_binary_numeric_with_scalar(array.clone(), T::infinity());
323 test_binary_numeric_with_scalar(array, T::neg_infinity());
324}
325
326fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
327where
328 T: NativePType + Num + Copy + std::fmt::Debug,
329 Scalar: From<T>,
330{
331 let canonicalized_array = array.to_primitive();
332 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
333
334 let scalar = Scalar::from(scalar_value)
335 .cast(array.dtype())
336 .vortex_expect("operation should succeed in conformance test");
337
338 let operators = if scalar_value == T::zero() {
340 vec![
342 NumericOperator::Add,
343 NumericOperator::Sub,
344 NumericOperator::Mul,
345 ]
346 } else {
347 vec![
348 NumericOperator::Add,
349 NumericOperator::Sub,
350 NumericOperator::Mul,
351 NumericOperator::Div,
352 ]
353 };
354
355 for operator in operators {
356 let op = operator;
357 let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array();
358
359 let result = array
361 .binary(rhs_const, op.into())
362 .vortex_expect("apply failed")
363 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
364 .map(|x| x.0.into_array());
365
366 if result.is_err() {
369 continue;
370 }
371
372 let result = result.vortex_expect("operation should succeed in conformance test");
373 let actual_values = to_vec_of_scalar(&result);
374
375 let expected_results: Vec<Option<Scalar>> = original_values
377 .iter()
378 .map(|x| {
379 x.as_primitive()
380 .checked_binary_numeric(&scalar.as_primitive(), op)
381 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
382 })
383 .collect();
384
385 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
387 if let Some(expected_value) = expected {
388 assert_eq!(
389 actual,
390 expected_value,
391 "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
392 ({array:?})[{idx}] {operator:?} {scalar} \
393 expected {expected_value:?}, got {actual:?}",
394 array.encoding_id(),
395 idx,
396 scalar_value,
397 );
398 }
399 }
400 }
401}