vortex_array/compute/conformance/
binary_numeric.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! # Binary Numeric Conformance Tests
5//!
6//! This module provides conformance testing for binary numeric operations on Vortex arrays.
7//! It ensures that all numeric array encodings produce identical results when performing
8//! arithmetic operations (add, subtract, multiply, divide).
9//!
10//! ## Test Strategy
11//!
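//! For each array encoding, we test:
//! 1. All binary numeric operators against the constant scalar `1`
//! 2. Both left-hand and right-hand side operations (e.g., `array + 1` and `1 + array`)
//! 3. Additional edge-case scalars chosen per type (zero, the type's extremes, `-1` for
//!    signed and floating-point types, and special float values such as NaN and infinities)
//! 4. That results match the canonical primitive array implementation; elements whose
//!    checked computation overflows are not compared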
16//!
17//! ## Supported Operations
18//!
19//! - Addition (`+`)
20//! - Subtraction (`-`)
21//! - Reverse Subtraction (scalar - array)
22//! - Multiplication (`*`)
23//! - Division (`/`)
24//! - Reverse Division (scalar / array)
25
26use itertools::Itertools;
27use num_traits::Num;
28use vortex_dtype::DType;
29use vortex_dtype::NativePType;
30use vortex_dtype::PType;
31use vortex_error::VortexUnwrap;
32use vortex_error::vortex_err;
33use vortex_error::vortex_panic;
34use vortex_scalar::NumericOperator;
35use vortex_scalar::PrimitiveScalar;
36use vortex_scalar::Scalar;
37
38use crate::Array;
39use crate::ArrayRef;
40use crate::IntoArray;
41use crate::ToCanonical;
42use crate::arrays::ConstantArray;
43use crate::compute::numeric::numeric;
44
45fn to_vec_of_scalar(array: &dyn Array) -> Vec<Scalar> {
46    // Not fast, but obviously correct
47    (0..array.len())
48        .map(|index| array.scalar_at(index))
49        .collect_vec()
50}
51
52/// Tests binary numeric operations for conformance across array encodings.
53///
54/// # Type Parameters
55///
56/// * `T` - The native numeric type (e.g., i32, f64) that the array contains
57///
58/// # Arguments
59///
60/// * `array` - The array to test, which should contain numeric values of type `T`
61///
62/// # Test Details
63///
64/// This function:
65/// 1. Canonicalizes the input array to primitive form to get expected values
66/// 2. Tests all binary numeric operators against a constant value of 1
67/// 3. Verifies results match the expected primitive array computation
68/// 4. Tests both array-operator-scalar and scalar-operator-array forms
69/// 5. Gracefully skips operations that would cause overflow/underflow
70///
71/// # Panics
72///
73/// Panics if:
74/// - The array cannot be converted to primitive form
75/// - Results don't match expected values (for operations that don't overflow)
76fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
77where
78    Scalar: From<T>,
79{
80    // First test with the standard scalar value of 1
81    test_standard_binary_numeric::<T>(array.clone());
82
83    // Then test edge cases
84    test_binary_numeric_edge_cases(array);
85}
86
87fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
88where
89    Scalar: From<T>,
90{
91    let canonicalized_array = array.to_primitive();
92    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
93
94    let one = T::from(1)
95        .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
96        .vortex_unwrap();
97    let scalar_one = Scalar::from(one).cast(array.dtype()).vortex_unwrap();
98
99    let operators: [NumericOperator; 6] = [
100        NumericOperator::Add,
101        NumericOperator::Sub,
102        NumericOperator::RSub,
103        NumericOperator::Mul,
104        NumericOperator::Div,
105        NumericOperator::RDiv,
106    ];
107
108    for operator in operators {
109        // Test array operator scalar (e.g., array + 1)
110        let result = numeric(
111            &array,
112            &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
113            operator,
114        );
115
116        // Skip this operator if the entire operation fails
117        // This can happen for some edge cases in specific encodings
118        let Ok(result) = result else {
119            continue;
120        };
121
122        let actual_values = to_vec_of_scalar(&result);
123
124        // Check each element for overflow/underflow
125        let expected_results: Vec<Option<Scalar>> = original_values
126            .iter()
127            .map(|x| {
128                x.as_primitive()
129                    .checked_binary_numeric(&scalar_one.as_primitive(), operator)
130                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
131            })
132            .collect();
133
134        // For elements that didn't overflow, check they match
135        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
136            if let Some(expected_value) = expected {
137                assert_eq!(
138                    actual,
139                    expected_value,
140                    "Binary numeric operation failed for encoding {} at index {}: \
141                     ({array:?})[{idx}] {operator:?} {scalar_one} \
142                     expected {expected_value:?}, got {actual:?}",
143                    array.encoding_id(),
144                    idx,
145                );
146            }
147        }
148
149        // Test scalar operator array (e.g., 1 + array)
150        let result = numeric(
151            &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
152            &array,
153            operator,
154        );
155
156        // Skip this operator if the entire operation fails
157        let Ok(result) = result else {
158            continue;
159        };
160
161        let actual_values = to_vec_of_scalar(&result);
162
163        // Check each element for overflow/underflow
164        let expected_results: Vec<Option<Scalar>> = original_values
165            .iter()
166            .map(|x| {
167                scalar_one
168                    .as_primitive()
169                    .checked_binary_numeric(&x.as_primitive(), operator)
170                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
171            })
172            .collect();
173
174        // For elements that didn't overflow, check they match
175        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
176            if let Some(expected_value) = expected {
177                assert_eq!(
178                    actual,
179                    expected_value,
180                    "Binary numeric operation failed for encoding {} at index {}: \
181                     {scalar_one} {operator:?} ({array:?})[{idx}] \
182                     expected {expected_value:?}, got {actual:?}",
183                    array.encoding_id(),
184                    idx,
185                );
186            }
187        }
188    }
189}
190
191/// Entry point for binary numeric conformance testing for any array type.
192///
193/// This function automatically detects the array's numeric type and runs
194/// the appropriate tests. It's designed to be called from rstest parameterized
195/// tests without requiring explicit type parameters.
196///
197/// # Example
198///
199/// ```ignore
200/// #[rstest]
201/// #[case::i32_array(create_i32_array())]
202/// #[case::f64_array(create_f64_array())]
203/// fn test_my_encoding_binary_numeric(#[case] array: MyArray) {
204///     test_binary_numeric_array(array.into_array());
205/// }
206/// ```
207pub fn test_binary_numeric_array(array: ArrayRef) {
208    match array.dtype() {
209        DType::Primitive(ptype, _) => match ptype {
210            PType::I8 => test_binary_numeric_conformance::<i8>(array),
211            PType::I16 => test_binary_numeric_conformance::<i16>(array),
212            PType::I32 => test_binary_numeric_conformance::<i32>(array),
213            PType::I64 => test_binary_numeric_conformance::<i64>(array),
214            PType::U8 => test_binary_numeric_conformance::<u8>(array),
215            PType::U16 => test_binary_numeric_conformance::<u16>(array),
216            PType::U32 => test_binary_numeric_conformance::<u32>(array),
217            PType::U64 => test_binary_numeric_conformance::<u64>(array),
218            PType::F16 => {
219                // F16 not supported in num-traits, skip
220                eprintln!("Skipping f16 binary numeric tests (not supported)");
221            }
222            PType::F32 => test_binary_numeric_conformance::<f32>(array),
223            PType::F64 => test_binary_numeric_conformance::<f64>(array),
224        },
225        dtype => vortex_panic!(
226            "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
227        ),
228    }
229}
230
231/// Tests binary numeric operations with edge case scalar values.
232///
233/// This function tests operations with scalar values:
234/// - Zero (identity for addition/subtraction, absorbing for multiplication)
235/// - Negative one (tests signed arithmetic)
236/// - Maximum value (tests overflow behavior)
237/// - Minimum value (tests underflow behavior)
238fn test_binary_numeric_edge_cases(array: ArrayRef) {
239    match array.dtype() {
240        DType::Primitive(ptype, _) => match ptype {
241            PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
242            PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
243            PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
244            PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
245            PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
246            PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
247            PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
248            PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
249            PType::F16 => {
250                eprintln!("Skipping f16 edge case tests (not supported)");
251            }
252            PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
253            PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
254        },
255        dtype => vortex_panic!(
256            "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
257        ),
258    }
259}
260
261fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
262where
263    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
264    Scalar: From<T>,
265{
266    // Test with zero
267    test_binary_numeric_with_scalar(array.clone(), T::zero());
268
269    // Test with -1
270    test_binary_numeric_with_scalar(array.clone(), -T::one());
271
272    // Test with max value
273    test_binary_numeric_with_scalar(array.clone(), T::max_value());
274
275    // Test with min value
276    test_binary_numeric_with_scalar(array, T::min_value());
277}
278
279fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
280where
281    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
282    Scalar: From<T>,
283{
284    // Test with zero
285    test_binary_numeric_with_scalar(array.clone(), T::zero());
286
287    // Test with max value
288    test_binary_numeric_with_scalar(array, T::max_value());
289}
290
291fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
292where
293    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
294    Scalar: From<T>,
295{
296    // Test with zero
297    test_binary_numeric_with_scalar(array.clone(), T::zero());
298
299    // Test with -1
300    test_binary_numeric_with_scalar(array.clone(), -T::one());
301
302    // Test with max value
303    test_binary_numeric_with_scalar(array.clone(), T::max_value());
304
305    // Test with min value
306    test_binary_numeric_with_scalar(array.clone(), T::min_value());
307
308    // Test with small positive value
309    test_binary_numeric_with_scalar(array.clone(), T::epsilon());
310
311    // Test with min positive value (subnormal)
312    test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
313
314    // Test with special float values (NaN, Infinity)
315    test_binary_numeric_with_scalar(array.clone(), T::nan());
316    test_binary_numeric_with_scalar(array.clone(), T::infinity());
317    test_binary_numeric_with_scalar(array, T::neg_infinity());
318}
319
320fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
321where
322    T: NativePType + Num + Copy + std::fmt::Debug,
323    Scalar: From<T>,
324{
325    let canonicalized_array = array.to_primitive();
326    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
327
328    let scalar = Scalar::from(scalar_value)
329        .cast(array.dtype())
330        .vortex_unwrap();
331
332    // Only test operators that make sense for the given scalar
333    let operators = if scalar_value == T::zero() {
334        // Skip division by zero
335        vec![
336            NumericOperator::Add,
337            NumericOperator::Sub,
338            NumericOperator::RSub,
339            NumericOperator::Mul,
340        ]
341    } else {
342        vec![
343            NumericOperator::Add,
344            NumericOperator::Sub,
345            NumericOperator::RSub,
346            NumericOperator::Mul,
347            NumericOperator::Div,
348            NumericOperator::RDiv,
349        ]
350    };
351
352    for operator in operators {
353        // Test array operator scalar
354        let result = numeric(
355            &array,
356            &ConstantArray::new(scalar.clone(), array.len()).into_array(),
357            operator,
358        );
359
360        // Skip if the entire operation fails
361        if result.is_err() {
362            continue;
363        }
364
365        let result = result.vortex_unwrap();
366        let actual_values = to_vec_of_scalar(&result);
367
368        // Check each element for overflow/underflow
369        let expected_results: Vec<Option<Scalar>> = original_values
370            .iter()
371            .map(|x| {
372                x.as_primitive()
373                    .checked_binary_numeric(&scalar.as_primitive(), operator)
374                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
375            })
376            .collect();
377
378        // For elements that didn't overflow, check they match
379        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
380            if let Some(expected_value) = expected {
381                assert_eq!(
382                    actual,
383                    expected_value,
384                    "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
385                     ({array:?})[{idx}] {operator:?} {scalar} \
386                     expected {expected_value:?}, got {actual:?}",
387                    array.encoding_id(),
388                    idx,
389                    scalar_value,
390                );
391            }
392        }
393    }
394}