vortex_array/compute/conformance/
binary_numeric.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! # Binary Numeric Conformance Tests
//!
//! This module provides conformance testing for binary numeric operations on Vortex arrays.
//! It checks that every numeric array encoding produces the same results as element-wise
//! scalar arithmetic on its canonical (primitive) values when performing arithmetic
//! operations (add, subtract, multiply, divide).
//!
//! ## Test Strategy
//!
//! For each array encoding, we test:
//! 1. All binary numeric operators against the constant scalar value `1`
//! 2. Both left-hand and right-hand side operations for that constant
//!    (e.g., `array + 1` and `1 + array`)
//! 3. Edge-case scalars with the array as the left operand: zero and the type's maximum;
//!    for signed integers and floats also `-1` and the type's minimum; for floats also
//!    epsilon, the smallest positive normal value, NaN and ±infinity
//!    (division operators are skipped when the scalar is zero)
//! 4. That results match checked scalar arithmetic applied to each canonicalized value;
//!    elements whose expected result overflows or underflows are not compared
//!
//! ## Supported Operations
//!
//! - Addition (`+`)
//! - Subtraction (`-`)
//! - Reverse Subtraction (scalar - array)
//! - Multiplication (`*`)
//! - Division (`/`)
//! - Reverse Division (scalar / array)
25
26use itertools::Itertools;
27use num_traits::Num;
28use vortex_dtype::DType;
29use vortex_dtype::NativePType;
30use vortex_dtype::PType;
31use vortex_error::VortexExpect;
32use vortex_error::vortex_err;
33use vortex_error::vortex_panic;
34use vortex_scalar::NumericOperator;
35use vortex_scalar::PrimitiveScalar;
36use vortex_scalar::Scalar;
37
38use crate::Array;
39use crate::ArrayRef;
40use crate::IntoArray;
41use crate::ToCanonical;
42use crate::arrays::ConstantArray;
43use crate::compute::numeric::numeric;
44
45fn to_vec_of_scalar(array: &dyn Array) -> Vec<Scalar> {
46    // Not fast, but obviously correct
47    (0..array.len())
48        .map(|index| array.scalar_at(index))
49        .collect_vec()
50}
51
52/// Tests binary numeric operations for conformance across array encodings.
53///
54/// # Type Parameters
55///
56/// * `T` - The native numeric type (e.g., i32, f64) that the array contains
57///
58/// # Arguments
59///
60/// * `array` - The array to test, which should contain numeric values of type `T`
61///
62/// # Test Details
63///
64/// This function:
65/// 1. Canonicalizes the input array to primitive form to get expected values
66/// 2. Tests all binary numeric operators against a constant value of 1
67/// 3. Verifies results match the expected primitive array computation
68/// 4. Tests both array-operator-scalar and scalar-operator-array forms
69/// 5. Gracefully skips operations that would cause overflow/underflow
70///
71/// # Panics
72///
73/// Panics if:
74/// - The array cannot be converted to primitive form
75/// - Results don't match expected values (for operations that don't overflow)
76fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
77where
78    Scalar: From<T>,
79{
80    // First test with the standard scalar value of 1
81    test_standard_binary_numeric::<T>(array.clone());
82
83    // Then test edge cases
84    test_binary_numeric_edge_cases(array);
85}
86
87fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
88where
89    Scalar: From<T>,
90{
91    let canonicalized_array = array.to_primitive();
92    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
93
94    let one = T::from(1)
95        .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
96        .vortex_expect("operation should succeed in conformance test");
97    let scalar_one = Scalar::from(one)
98        .cast(array.dtype())
99        .vortex_expect("operation should succeed in conformance test");
100
101    let operators: [NumericOperator; 6] = [
102        NumericOperator::Add,
103        NumericOperator::Sub,
104        NumericOperator::RSub,
105        NumericOperator::Mul,
106        NumericOperator::Div,
107        NumericOperator::RDiv,
108    ];
109
110    for operator in operators {
111        // Test array operator scalar (e.g., array + 1)
112        let result = numeric(
113            &array,
114            &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
115            operator,
116        );
117
118        // Skip this operator if the entire operation fails
119        // This can happen for some edge cases in specific encodings
120        let Ok(result) = result else {
121            continue;
122        };
123
124        let actual_values = to_vec_of_scalar(&result);
125
126        // Check each element for overflow/underflow
127        let expected_results: Vec<Option<Scalar>> = original_values
128            .iter()
129            .map(|x| {
130                x.as_primitive()
131                    .checked_binary_numeric(&scalar_one.as_primitive(), operator)
132                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
133            })
134            .collect();
135
136        // For elements that didn't overflow, check they match
137        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
138            if let Some(expected_value) = expected {
139                assert_eq!(
140                    actual,
141                    expected_value,
142                    "Binary numeric operation failed for encoding {} at index {}: \
143                     ({array:?})[{idx}] {operator:?} {scalar_one} \
144                     expected {expected_value:?}, got {actual:?}",
145                    array.encoding_id(),
146                    idx,
147                );
148            }
149        }
150
151        // Test scalar operator array (e.g., 1 + array)
152        let result = numeric(
153            &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
154            &array,
155            operator,
156        );
157
158        // Skip this operator if the entire operation fails
159        let Ok(result) = result else {
160            continue;
161        };
162
163        let actual_values = to_vec_of_scalar(&result);
164
165        // Check each element for overflow/underflow
166        let expected_results: Vec<Option<Scalar>> = original_values
167            .iter()
168            .map(|x| {
169                scalar_one
170                    .as_primitive()
171                    .checked_binary_numeric(&x.as_primitive(), operator)
172                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
173            })
174            .collect();
175
176        // For elements that didn't overflow, check they match
177        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
178            if let Some(expected_value) = expected {
179                assert_eq!(
180                    actual,
181                    expected_value,
182                    "Binary numeric operation failed for encoding {} at index {}: \
183                     {scalar_one} {operator:?} ({array:?})[{idx}] \
184                     expected {expected_value:?}, got {actual:?}",
185                    array.encoding_id(),
186                    idx,
187                );
188            }
189        }
190    }
191}
192
193/// Entry point for binary numeric conformance testing for any array type.
194///
195/// This function automatically detects the array's numeric type and runs
196/// the appropriate tests. It's designed to be called from rstest parameterized
197/// tests without requiring explicit type parameters.
198///
199/// # Example
200///
201/// ```ignore
202/// #[rstest]
203/// #[case::i32_array(create_i32_array())]
204/// #[case::f64_array(create_f64_array())]
205/// fn test_my_encoding_binary_numeric(#[case] array: MyArray) {
206///     test_binary_numeric_array(array.into_array());
207/// }
208/// ```
209pub fn test_binary_numeric_array(array: ArrayRef) {
210    match array.dtype() {
211        DType::Primitive(ptype, _) => match ptype {
212            PType::I8 => test_binary_numeric_conformance::<i8>(array),
213            PType::I16 => test_binary_numeric_conformance::<i16>(array),
214            PType::I32 => test_binary_numeric_conformance::<i32>(array),
215            PType::I64 => test_binary_numeric_conformance::<i64>(array),
216            PType::U8 => test_binary_numeric_conformance::<u8>(array),
217            PType::U16 => test_binary_numeric_conformance::<u16>(array),
218            PType::U32 => test_binary_numeric_conformance::<u32>(array),
219            PType::U64 => test_binary_numeric_conformance::<u64>(array),
220            PType::F16 => {
221                // F16 not supported in num-traits, skip
222                eprintln!("Skipping f16 binary numeric tests (not supported)");
223            }
224            PType::F32 => test_binary_numeric_conformance::<f32>(array),
225            PType::F64 => test_binary_numeric_conformance::<f64>(array),
226        },
227        dtype => vortex_panic!(
228            "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
229        ),
230    }
231}
232
233/// Tests binary numeric operations with edge case scalar values.
234///
235/// This function tests operations with scalar values:
236/// - Zero (identity for addition/subtraction, absorbing for multiplication)
237/// - Negative one (tests signed arithmetic)
238/// - Maximum value (tests overflow behavior)
239/// - Minimum value (tests underflow behavior)
240fn test_binary_numeric_edge_cases(array: ArrayRef) {
241    match array.dtype() {
242        DType::Primitive(ptype, _) => match ptype {
243            PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
244            PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
245            PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
246            PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
247            PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
248            PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
249            PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
250            PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
251            PType::F16 => {
252                eprintln!("Skipping f16 edge case tests (not supported)");
253            }
254            PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
255            PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
256        },
257        dtype => vortex_panic!(
258            "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
259        ),
260    }
261}
262
263fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
264where
265    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
266    Scalar: From<T>,
267{
268    // Test with zero
269    test_binary_numeric_with_scalar(array.clone(), T::zero());
270
271    // Test with -1
272    test_binary_numeric_with_scalar(array.clone(), -T::one());
273
274    // Test with max value
275    test_binary_numeric_with_scalar(array.clone(), T::max_value());
276
277    // Test with min value
278    test_binary_numeric_with_scalar(array, T::min_value());
279}
280
281fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
282where
283    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
284    Scalar: From<T>,
285{
286    // Test with zero
287    test_binary_numeric_with_scalar(array.clone(), T::zero());
288
289    // Test with max value
290    test_binary_numeric_with_scalar(array, T::max_value());
291}
292
293fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
294where
295    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
296    Scalar: From<T>,
297{
298    // Test with zero
299    test_binary_numeric_with_scalar(array.clone(), T::zero());
300
301    // Test with -1
302    test_binary_numeric_with_scalar(array.clone(), -T::one());
303
304    // Test with max value
305    test_binary_numeric_with_scalar(array.clone(), T::max_value());
306
307    // Test with min value
308    test_binary_numeric_with_scalar(array.clone(), T::min_value());
309
310    // Test with small positive value
311    test_binary_numeric_with_scalar(array.clone(), T::epsilon());
312
313    // Test with min positive value (subnormal)
314    test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
315
316    // Test with special float values (NaN, Infinity)
317    test_binary_numeric_with_scalar(array.clone(), T::nan());
318    test_binary_numeric_with_scalar(array.clone(), T::infinity());
319    test_binary_numeric_with_scalar(array, T::neg_infinity());
320}
321
322fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
323where
324    T: NativePType + Num + Copy + std::fmt::Debug,
325    Scalar: From<T>,
326{
327    let canonicalized_array = array.to_primitive();
328    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
329
330    let scalar = Scalar::from(scalar_value)
331        .cast(array.dtype())
332        .vortex_expect("operation should succeed in conformance test");
333
334    // Only test operators that make sense for the given scalar
335    let operators = if scalar_value == T::zero() {
336        // Skip division by zero
337        vec![
338            NumericOperator::Add,
339            NumericOperator::Sub,
340            NumericOperator::RSub,
341            NumericOperator::Mul,
342        ]
343    } else {
344        vec![
345            NumericOperator::Add,
346            NumericOperator::Sub,
347            NumericOperator::RSub,
348            NumericOperator::Mul,
349            NumericOperator::Div,
350            NumericOperator::RDiv,
351        ]
352    };
353
354    for operator in operators {
355        // Test array operator scalar
356        let result = numeric(
357            &array,
358            &ConstantArray::new(scalar.clone(), array.len()).into_array(),
359            operator,
360        );
361
362        // Skip if the entire operation fails
363        if result.is_err() {
364            continue;
365        }
366
367        let result = result.vortex_expect("operation should succeed in conformance test");
368        let actual_values = to_vec_of_scalar(&result);
369
370        // Check each element for overflow/underflow
371        let expected_results: Vec<Option<Scalar>> = original_values
372            .iter()
373            .map(|x| {
374                x.as_primitive()
375                    .checked_binary_numeric(&scalar.as_primitive(), operator)
376                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
377            })
378            .collect();
379
380        // For elements that didn't overflow, check they match
381        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
382            if let Some(expected_value) = expected {
383                assert_eq!(
384                    actual,
385                    expected_value,
386                    "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
387                     ({array:?})[{idx}] {operator:?} {scalar} \
388                     expected {expected_value:?}, got {actual:?}",
389                    array.encoding_id(),
390                    idx,
391                    scalar_value,
392                );
393            }
394        }
395    }
396}