vortex_array/compute/conformance/
binary_numeric.rs1use itertools::Itertools;
25use num_traits::Num;
26use vortex_error::VortexExpect;
27use vortex_error::vortex_err;
28use vortex_error::vortex_panic;
29
30use crate::ArrayRef;
31use crate::IntoArray;
32use crate::LEGACY_SESSION;
33use crate::RecursiveCanonical;
34#[expect(deprecated)]
35use crate::ToCanonical as _;
36use crate::VortexSessionExecute;
37use crate::arrays::ConstantArray;
38use crate::builtins::ArrayBuiltins;
39use crate::dtype::DType;
40use crate::dtype::NativePType;
41use crate::dtype::PType;
42use crate::scalar::NumericOperator;
43use crate::scalar::PrimitiveScalar;
44use crate::scalar::Scalar;
45
46fn to_vec_of_scalar(array: &ArrayRef) -> Vec<Scalar> {
47 (0..array.len())
49 .map(|index| {
50 array
51 .execute_scalar(index, &mut LEGACY_SESSION.create_execution_ctx())
52 .vortex_expect("scalar_at should succeed in conformance test")
53 })
54 .collect_vec()
55}
56
57fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
82where
83 Scalar: From<T>,
84{
85 test_standard_binary_numeric::<T>(array.clone());
87
88 test_binary_numeric_edge_cases(array);
90}
91
92fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
93where
94 Scalar: From<T>,
95{
96 #[expect(deprecated)]
97 let canonicalized_array = array.to_primitive();
98 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
99
100 let one = T::from(1)
101 .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
102 .vortex_expect("operation should succeed in conformance test");
103 let scalar_one = Scalar::from(one)
104 .cast(array.dtype())
105 .vortex_expect("operation should succeed in conformance test");
106
107 let operators: [NumericOperator; 4] = [
108 NumericOperator::Add,
109 NumericOperator::Sub,
110 NumericOperator::Mul,
111 NumericOperator::Div,
112 ];
113
114 for operator in operators {
115 let op = operator;
116 let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array();
117
118 let result = array
120 .binary(rhs_const.clone(), op.into())
121 .vortex_expect("apply shouldn't fail")
122 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
123 .map(|c| c.0.into_array());
124
125 let Ok(result) = result else {
128 continue;
129 };
130
131 let actual_values = to_vec_of_scalar(&result);
132
133 let expected_results: Vec<Option<Scalar>> = original_values
135 .iter()
136 .map(|x| {
137 x.as_primitive()
138 .checked_binary_numeric(&scalar_one.as_primitive(), op)
139 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
140 })
141 .collect();
142
143 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
145 if let Some(expected_value) = expected {
146 assert_eq!(
147 actual,
148 expected_value,
149 "Binary numeric operation failed for encoding {} at index {}: \
150 ({array:?})[{idx}] {operator:?} {scalar_one} \
151 expected {expected_value:?}, got {actual:?}",
152 array.encoding_id(),
153 idx,
154 );
155 }
156 }
157
158 let result = rhs_const.binary(array.clone(), op.into()).and_then(|a| {
160 a.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
161 .map(|c| c.0.into_array())
162 });
163
164 let Ok(result) = result else {
166 continue;
167 };
168
169 let actual_values = to_vec_of_scalar(&result);
170
171 let expected_results: Vec<Option<Scalar>> = original_values
173 .iter()
174 .map(|x| {
175 scalar_one
176 .as_primitive()
177 .checked_binary_numeric(&x.as_primitive(), op)
178 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
179 })
180 .collect();
181
182 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
184 if let Some(expected_value) = expected {
185 assert_eq!(
186 actual,
187 expected_value,
188 "Binary numeric operation failed for encoding {} at index {}: \
189 {scalar_one} {operator:?} ({array:?})[{idx}] \
190 expected {expected_value:?}, got {actual:?}",
191 array.encoding_id(),
192 idx,
193 );
194 }
195 }
196 }
197}
198
199pub fn test_binary_numeric_array(array: ArrayRef) {
216 match array.dtype() {
217 DType::Primitive(ptype, _) => match ptype {
218 PType::I8 => test_binary_numeric_conformance::<i8>(array),
219 PType::I16 => test_binary_numeric_conformance::<i16>(array),
220 PType::I32 => test_binary_numeric_conformance::<i32>(array),
221 PType::I64 => test_binary_numeric_conformance::<i64>(array),
222 PType::U8 => test_binary_numeric_conformance::<u8>(array),
223 PType::U16 => test_binary_numeric_conformance::<u16>(array),
224 PType::U32 => test_binary_numeric_conformance::<u32>(array),
225 PType::U64 => test_binary_numeric_conformance::<u64>(array),
226 PType::F16 => {
227 eprintln!("Skipping f16 binary numeric tests (not supported)");
229 }
230 PType::F32 => test_binary_numeric_conformance::<f32>(array),
231 PType::F64 => test_binary_numeric_conformance::<f64>(array),
232 },
233 dtype => vortex_panic!(
234 "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
235 ),
236 }
237}
238
239fn test_binary_numeric_edge_cases(array: ArrayRef) {
247 match array.dtype() {
248 DType::Primitive(ptype, _) => match ptype {
249 PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
250 PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
251 PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
252 PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
253 PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
254 PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
255 PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
256 PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
257 PType::F16 => {
258 eprintln!("Skipping f16 edge case tests (not supported)");
259 }
260 PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
261 PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
262 },
263 dtype => vortex_panic!(
264 "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
265 ),
266 }
267}
268
269fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
270where
271 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
272 Scalar: From<T>,
273{
274 test_binary_numeric_with_scalar(array.clone(), T::zero());
276
277 test_binary_numeric_with_scalar(array.clone(), -T::one());
279
280 test_binary_numeric_with_scalar(array.clone(), T::max_value());
282
283 test_binary_numeric_with_scalar(array, T::min_value());
285}
286
287fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
288where
289 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
290 Scalar: From<T>,
291{
292 test_binary_numeric_with_scalar(array.clone(), T::zero());
294
295 test_binary_numeric_with_scalar(array, T::max_value());
297}
298
299fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
300where
301 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
302 Scalar: From<T>,
303{
304 test_binary_numeric_with_scalar(array.clone(), T::zero());
306
307 test_binary_numeric_with_scalar(array.clone(), -T::one());
309
310 test_binary_numeric_with_scalar(array.clone(), T::max_value());
312
313 test_binary_numeric_with_scalar(array.clone(), T::min_value());
315
316 test_binary_numeric_with_scalar(array.clone(), T::epsilon());
318
319 test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
321
322 test_binary_numeric_with_scalar(array.clone(), T::nan());
324 test_binary_numeric_with_scalar(array.clone(), T::infinity());
325 test_binary_numeric_with_scalar(array, T::neg_infinity());
326}
327
328fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
329where
330 T: NativePType + Num + Copy + std::fmt::Debug,
331 Scalar: From<T>,
332{
333 #[expect(deprecated)]
334 let canonicalized_array = array.to_primitive();
335 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
336
337 let scalar = Scalar::from(scalar_value)
338 .cast(array.dtype())
339 .vortex_expect("operation should succeed in conformance test");
340
341 let operators = if scalar_value == T::zero() {
343 vec![
345 NumericOperator::Add,
346 NumericOperator::Sub,
347 NumericOperator::Mul,
348 ]
349 } else {
350 vec![
351 NumericOperator::Add,
352 NumericOperator::Sub,
353 NumericOperator::Mul,
354 NumericOperator::Div,
355 ]
356 };
357
358 for operator in operators {
359 let op = operator;
360 let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array();
361
362 let result = array
364 .binary(rhs_const, op.into())
365 .vortex_expect("apply failed")
366 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
367 .map(|x| x.0.into_array());
368
369 if result.is_err() {
372 continue;
373 }
374
375 let result = result.vortex_expect("operation should succeed in conformance test");
376 let actual_values = to_vec_of_scalar(&result);
377
378 let expected_results: Vec<Option<Scalar>> = original_values
380 .iter()
381 .map(|x| {
382 x.as_primitive()
383 .checked_binary_numeric(&scalar.as_primitive(), op)
384 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
385 })
386 .collect();
387
388 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
390 if let Some(expected_value) = expected {
391 assert_eq!(
392 actual,
393 expected_value,
394 "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
395 ({array:?})[{idx}] {operator:?} {scalar} \
396 expected {expected_value:?}, got {actual:?}",
397 array.encoding_id(),
398 idx,
399 scalar_value,
400 );
401 }
402 }
403 }
404}