vortex_array/compute/conformance/
binary_numeric.rs1use std::fmt::Debug;
25
26use itertools::Itertools;
27use num_traits::Bounded;
28use num_traits::Float;
29use num_traits::Num;
30use num_traits::Signed;
31use vortex_error::VortexExpect;
32use vortex_error::vortex_err;
33use vortex_error::vortex_panic;
34
35use crate::ArrayRef;
36use crate::Canonical;
37use crate::ExecutionCtx;
38use crate::IntoArray;
39use crate::RecursiveCanonical;
40use crate::arrays::ConstantArray;
41use crate::builtins::ArrayBuiltins;
42use crate::dtype::DType;
43use crate::dtype::NativePType;
44use crate::dtype::PType;
45use crate::scalar::NumericOperator;
46use crate::scalar::PrimitiveScalar;
47use crate::scalar::Scalar;
48
49fn to_vec_of_scalar(array: &ArrayRef, ctx: &mut ExecutionCtx) -> Vec<Scalar> {
50 (0..array.len())
52 .map(|index| {
53 array
54 .execute_scalar(index, ctx)
55 .vortex_expect("scalar_at should succeed in conformance test")
56 })
57 .collect_vec()
58}
59
60fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(
85 array: &ArrayRef,
86 ctx: &mut ExecutionCtx,
87) where
88 Scalar: From<T>,
89{
90 test_standard_binary_numeric::<T>(array, ctx);
92
93 test_binary_numeric_edge_cases(array, ctx);
95}
96
97fn test_standard_binary_numeric<T: NativePType + Num + Copy>(
98 array: &ArrayRef,
99 ctx: &mut ExecutionCtx,
100) where
101 Scalar: From<T>,
102{
103 let canonicalized_array = array
104 .clone()
105 .execute::<Canonical>(ctx)
106 .vortex_expect("Must be able to canonicalise")
107 .into_array();
108 let original_values = to_vec_of_scalar(&canonicalized_array, ctx);
109
110 let one = T::from(1)
111 .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
112 .vortex_expect("operation should succeed in conformance test");
113 let scalar_one = Scalar::from(one)
114 .cast(array.dtype())
115 .vortex_expect("operation should succeed in conformance test");
116
117 let operators: [NumericOperator; 4] = [
118 NumericOperator::Add,
119 NumericOperator::Sub,
120 NumericOperator::Mul,
121 NumericOperator::Div,
122 ];
123
124 for operator in operators {
125 let op = operator;
126 let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array();
127
128 let result = array
130 .binary(rhs_const.clone(), op.into())
131 .vortex_expect("apply shouldn't fail")
132 .execute::<RecursiveCanonical>(ctx)
133 .map(|c| c.0.into_array());
134
135 let Ok(result) = result else {
138 continue;
139 };
140
141 let actual_values = to_vec_of_scalar(&result, ctx);
142
143 let expected_results: Vec<Option<Scalar>> = original_values
145 .iter()
146 .map(|x| {
147 x.as_primitive()
148 .checked_binary_numeric(&scalar_one.as_primitive(), op)
149 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
150 })
151 .collect();
152
153 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
155 if let Some(expected_value) = expected {
156 assert_eq!(
157 actual,
158 expected_value,
159 "Binary numeric operation failed for encoding {} at index {}: \
160 ({array:?})[{idx}] {operator:?} {scalar_one} \
161 expected {expected_value:?}, got {actual:?}",
162 array.encoding_id(),
163 idx,
164 );
165 }
166 }
167
168 let result = rhs_const.binary(array.clone(), op.into()).and_then(|a| {
170 a.execute::<RecursiveCanonical>(ctx)
171 .map(|c| c.0.into_array())
172 });
173
174 let Ok(result) = result else {
176 continue;
177 };
178
179 let actual_values = to_vec_of_scalar(&result, ctx);
180
181 let expected_results: Vec<Option<Scalar>> = original_values
183 .iter()
184 .map(|x| {
185 scalar_one
186 .as_primitive()
187 .checked_binary_numeric(&x.as_primitive(), op)
188 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
189 })
190 .collect();
191
192 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
194 if let Some(expected_value) = expected {
195 assert_eq!(
196 actual,
197 expected_value,
198 "Binary numeric operation failed for encoding {} at index {}: \
199 {scalar_one} {operator:?} ({array:?})[{idx}] \
200 expected {expected_value:?}, got {actual:?}",
201 array.encoding_id(),
202 idx,
203 );
204 }
205 }
206 }
207}
208
209pub fn test_binary_numeric_array(array: &ArrayRef, ctx: &mut ExecutionCtx) {
226 match array.dtype() {
227 DType::Primitive(ptype, _) => match ptype {
228 PType::I8 => test_binary_numeric_conformance::<i8>(array, ctx),
229 PType::I16 => test_binary_numeric_conformance::<i16>(array, ctx),
230 PType::I32 => test_binary_numeric_conformance::<i32>(array, ctx),
231 PType::I64 => test_binary_numeric_conformance::<i64>(array, ctx),
232 PType::U8 => test_binary_numeric_conformance::<u8>(array, ctx),
233 PType::U16 => test_binary_numeric_conformance::<u16>(array, ctx),
234 PType::U32 => test_binary_numeric_conformance::<u32>(array, ctx),
235 PType::U64 => test_binary_numeric_conformance::<u64>(array, ctx),
236 PType::F16 => {
237 eprintln!("Skipping f16 binary numeric tests (not supported)");
239 }
240 PType::F32 => test_binary_numeric_conformance::<f32>(array, ctx),
241 PType::F64 => test_binary_numeric_conformance::<f64>(array, ctx),
242 },
243 dtype => vortex_panic!(
244 "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
245 ),
246 }
247}
248
249fn test_binary_numeric_edge_cases(array: &ArrayRef, ctx: &mut ExecutionCtx) {
257 match array.dtype() {
258 DType::Primitive(ptype, _) => match ptype {
259 PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array, ctx),
260 PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array, ctx),
261 PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array, ctx),
262 PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array, ctx),
263 PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array, ctx),
264 PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array, ctx),
265 PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array, ctx),
266 PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array, ctx),
267 PType::F16 => {
268 eprintln!("Skipping f16 edge case tests (not supported)");
269 }
270 PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array, ctx),
271 PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array, ctx),
272 },
273 dtype => vortex_panic!(
274 "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
275 ),
276 }
277}
278
279fn test_binary_numeric_edge_cases_signed<T>(array: &ArrayRef, ctx: &mut ExecutionCtx)
280where
281 T: NativePType + Num + Copy + Debug + Bounded + Signed,
282 Scalar: From<T>,
283{
284 test_binary_numeric_with_scalar(array, T::zero(), ctx);
286
287 test_binary_numeric_with_scalar(array, -T::one(), ctx);
289
290 test_binary_numeric_with_scalar(array, T::max_value(), ctx);
292
293 test_binary_numeric_with_scalar(array, T::min_value(), ctx);
295}
296
297fn test_binary_numeric_edge_cases_unsigned<T>(array: &ArrayRef, ctx: &mut ExecutionCtx)
298where
299 T: NativePType + Num + Copy + Debug + Bounded,
300 Scalar: From<T>,
301{
302 test_binary_numeric_with_scalar(array, T::zero(), ctx);
304
305 test_binary_numeric_with_scalar(array, T::max_value(), ctx);
307}
308
309fn test_binary_numeric_edge_cases_float<T>(array: &ArrayRef, ctx: &mut ExecutionCtx)
310where
311 T: NativePType + Num + Copy + Debug + Float,
312 Scalar: From<T>,
313{
314 test_binary_numeric_with_scalar(array, T::zero(), ctx);
316
317 test_binary_numeric_with_scalar(array, -T::one(), ctx);
319
320 test_binary_numeric_with_scalar(array, T::max_value(), ctx);
322
323 test_binary_numeric_with_scalar(array, T::min_value(), ctx);
325
326 test_binary_numeric_with_scalar(array, T::epsilon(), ctx);
328
329 test_binary_numeric_with_scalar(array, T::min_positive_value(), ctx);
331
332 test_binary_numeric_with_scalar(array, T::nan(), ctx);
334 test_binary_numeric_with_scalar(array, T::infinity(), ctx);
335 test_binary_numeric_with_scalar(array, T::neg_infinity(), ctx);
336}
337
338fn test_binary_numeric_with_scalar<T>(array: &ArrayRef, scalar_value: T, ctx: &mut ExecutionCtx)
339where
340 T: NativePType + Num + Copy + Debug,
341 Scalar: From<T>,
342{
343 let canonicalized_array = array
344 .clone()
345 .execute::<Canonical>(ctx)
346 .vortex_expect("Must be able to canonicalise")
347 .into_array();
348 let original_values = to_vec_of_scalar(&canonicalized_array, ctx);
349
350 let scalar = Scalar::from(scalar_value)
351 .cast(array.dtype())
352 .vortex_expect("operation should succeed in conformance test");
353
354 let operators = if scalar_value == T::zero() {
356 vec![
358 NumericOperator::Add,
359 NumericOperator::Sub,
360 NumericOperator::Mul,
361 ]
362 } else {
363 vec![
364 NumericOperator::Add,
365 NumericOperator::Sub,
366 NumericOperator::Mul,
367 NumericOperator::Div,
368 ]
369 };
370
371 for operator in operators {
372 let op = operator;
373 let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array();
374
375 let result = array
377 .binary(rhs_const, op.into())
378 .vortex_expect("apply failed")
379 .execute::<RecursiveCanonical>(ctx)
380 .map(|x| x.0.into_array());
381
382 if result.is_err() {
385 continue;
386 }
387
388 let result = result.vortex_expect("operation should succeed in conformance test");
389 let actual_values = to_vec_of_scalar(&result, ctx);
390
391 let expected_results: Vec<Option<Scalar>> = original_values
393 .iter()
394 .map(|x| {
395 x.as_primitive()
396 .checked_binary_numeric(&scalar.as_primitive(), op)
397 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
398 })
399 .collect();
400
401 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
403 if let Some(expected_value) = expected {
404 assert_eq!(
405 actual,
406 expected_value,
407 "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
408 ({array:?})[{idx}] {operator:?} {scalar} \
409 expected {expected_value:?}, got {actual:?}",
410 array.encoding_id(),
411 idx,
412 scalar_value,
413 );
414 }
415 }
416 }
417}