vortex_array/compute/conformance/
binary_numeric.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! # Binary Numeric Conformance Tests
5//!
6//! This module provides conformance testing for binary numeric operations on Vortex arrays.
7//! It ensures that all numeric array encodings produce identical results when performing
8//! arithmetic operations (add, subtract, multiply, divide).
9//!
10//! ## Test Strategy
11//!
12//! For each array encoding, we test:
13//! 1. All binary numeric operators against a constant scalar value
14//! 2. Both left-hand and right-hand side operations (e.g., array + 1 and 1 + array)
15//! 3. That results match the canonical primitive array implementation
16//!
17//! ## Supported Operations
18//!
19//! - Addition (`+`)
20//! - Subtraction (`-`)
21//! - Reverse Subtraction (scalar - array)
22//! - Multiplication (`*`)
23//! - Division (`/`)
24//! - Reverse Division (scalar / array)
25
26use itertools::Itertools;
27use num_traits::Num;
28use vortex_dtype::{DType, NativePType, PType};
29use vortex_error::{VortexUnwrap, vortex_err, vortex_panic};
30use vortex_scalar::{NumericOperator, PrimitiveScalar, Scalar};
31
32use crate::arrays::ConstantArray;
33use crate::compute::numeric::numeric;
34use crate::{Array, ArrayRef, IntoArray, ToCanonical};
35
36fn to_vec_of_scalar(array: &dyn Array) -> Vec<Scalar> {
37    // Not fast, but obviously correct
38    (0..array.len())
39        .map(|index| array.scalar_at(index))
40        .collect_vec()
41}
42
43/// Tests binary numeric operations for conformance across array encodings.
44///
45/// # Type Parameters
46///
47/// * `T` - The native numeric type (e.g., i32, f64) that the array contains
48///
49/// # Arguments
50///
51/// * `array` - The array to test, which should contain numeric values of type `T`
52///
53/// # Test Details
54///
55/// This function:
56/// 1. Canonicalizes the input array to primitive form to get expected values
57/// 2. Tests all binary numeric operators against a constant value of 1
58/// 3. Verifies results match the expected primitive array computation
59/// 4. Tests both array-operator-scalar and scalar-operator-array forms
60/// 5. Gracefully skips operations that would cause overflow/underflow
61///
62/// # Panics
63///
64/// Panics if:
65/// - The array cannot be converted to primitive form
66/// - Results don't match expected values (for operations that don't overflow)
67fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
68where
69    Scalar: From<T>,
70{
71    // First test with the standard scalar value of 1
72    test_standard_binary_numeric::<T>(array.clone());
73
74    // Then test edge cases
75    test_binary_numeric_edge_cases(array);
76}
77
78fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
79where
80    Scalar: From<T>,
81{
82    let canonicalized_array = array.to_primitive();
83    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
84
85    let one = T::from(1)
86        .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
87        .vortex_unwrap();
88    let scalar_one = Scalar::from(one).cast(array.dtype()).vortex_unwrap();
89
90    let operators: [NumericOperator; 6] = [
91        NumericOperator::Add,
92        NumericOperator::Sub,
93        NumericOperator::RSub,
94        NumericOperator::Mul,
95        NumericOperator::Div,
96        NumericOperator::RDiv,
97    ];
98
99    for operator in operators {
100        // Test array operator scalar (e.g., array + 1)
101        let result = numeric(
102            &array,
103            &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
104            operator,
105        );
106
107        // Skip this operator if the entire operation fails
108        // This can happen for some edge cases in specific encodings
109        let Ok(result) = result else {
110            continue;
111        };
112
113        let actual_values = to_vec_of_scalar(&result);
114
115        // Check each element for overflow/underflow
116        let expected_results: Vec<Option<Scalar>> = original_values
117            .iter()
118            .map(|x| {
119                x.as_primitive()
120                    .checked_binary_numeric(&scalar_one.as_primitive(), operator)
121                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
122            })
123            .collect();
124
125        // For elements that didn't overflow, check they match
126        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
127            if let Some(expected_value) = expected {
128                assert_eq!(
129                    actual,
130                    expected_value,
131                    "Binary numeric operation failed for encoding {} at index {}: \
132                     ({array:?})[{idx}] {operator:?} {scalar_one} \
133                     expected {expected_value:?}, got {actual:?}",
134                    array.encoding_id(),
135                    idx,
136                );
137            }
138        }
139
140        // Test scalar operator array (e.g., 1 + array)
141        let result = numeric(
142            &ConstantArray::new(scalar_one.clone(), array.len()).into_array(),
143            &array,
144            operator,
145        );
146
147        // Skip this operator if the entire operation fails
148        let Ok(result) = result else {
149            continue;
150        };
151
152        let actual_values = to_vec_of_scalar(&result);
153
154        // Check each element for overflow/underflow
155        let expected_results: Vec<Option<Scalar>> = original_values
156            .iter()
157            .map(|x| {
158                scalar_one
159                    .as_primitive()
160                    .checked_binary_numeric(&x.as_primitive(), operator)
161                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
162            })
163            .collect();
164
165        // For elements that didn't overflow, check they match
166        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
167            if let Some(expected_value) = expected {
168                assert_eq!(
169                    actual,
170                    expected_value,
171                    "Binary numeric operation failed for encoding {} at index {}: \
172                     {scalar_one} {operator:?} ({array:?})[{idx}] \
173                     expected {expected_value:?}, got {actual:?}",
174                    array.encoding_id(),
175                    idx,
176                );
177            }
178        }
179    }
180}
181
182/// Entry point for binary numeric conformance testing for any array type.
183///
184/// This function automatically detects the array's numeric type and runs
185/// the appropriate tests. It's designed to be called from rstest parameterized
186/// tests without requiring explicit type parameters.
187///
188/// # Example
189///
190/// ```ignore
191/// #[rstest]
192/// #[case::i32_array(create_i32_array())]
193/// #[case::f64_array(create_f64_array())]
194/// fn test_my_encoding_binary_numeric(#[case] array: MyArray) {
195///     test_binary_numeric_array(array.into_array());
196/// }
197/// ```
198pub fn test_binary_numeric_array(array: ArrayRef) {
199    match array.dtype() {
200        DType::Primitive(ptype, _) => match ptype {
201            PType::I8 => test_binary_numeric_conformance::<i8>(array),
202            PType::I16 => test_binary_numeric_conformance::<i16>(array),
203            PType::I32 => test_binary_numeric_conformance::<i32>(array),
204            PType::I64 => test_binary_numeric_conformance::<i64>(array),
205            PType::U8 => test_binary_numeric_conformance::<u8>(array),
206            PType::U16 => test_binary_numeric_conformance::<u16>(array),
207            PType::U32 => test_binary_numeric_conformance::<u32>(array),
208            PType::U64 => test_binary_numeric_conformance::<u64>(array),
209            PType::F16 => {
210                // F16 not supported in num-traits, skip
211                eprintln!("Skipping f16 binary numeric tests (not supported)");
212            }
213            PType::F32 => test_binary_numeric_conformance::<f32>(array),
214            PType::F64 => test_binary_numeric_conformance::<f64>(array),
215        },
216        dtype => vortex_panic!(
217            "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
218        ),
219    }
220}
221
222/// Tests binary numeric operations with edge case scalar values.
223///
224/// This function tests operations with scalar values:
225/// - Zero (identity for addition/subtraction, absorbing for multiplication)
226/// - Negative one (tests signed arithmetic)
227/// - Maximum value (tests overflow behavior)
228/// - Minimum value (tests underflow behavior)
229fn test_binary_numeric_edge_cases(array: ArrayRef) {
230    match array.dtype() {
231        DType::Primitive(ptype, _) => match ptype {
232            PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
233            PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
234            PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
235            PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
236            PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
237            PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
238            PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
239            PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
240            PType::F16 => {
241                eprintln!("Skipping f16 edge case tests (not supported)");
242            }
243            PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
244            PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
245        },
246        dtype => vortex_panic!(
247            "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
248        ),
249    }
250}
251
252fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
253where
254    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
255    Scalar: From<T>,
256{
257    // Test with zero
258    test_binary_numeric_with_scalar(array.clone(), T::zero());
259
260    // Test with -1
261    test_binary_numeric_with_scalar(array.clone(), -T::one());
262
263    // Test with max value
264    test_binary_numeric_with_scalar(array.clone(), T::max_value());
265
266    // Test with min value
267    test_binary_numeric_with_scalar(array, T::min_value());
268}
269
270fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
271where
272    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
273    Scalar: From<T>,
274{
275    // Test with zero
276    test_binary_numeric_with_scalar(array.clone(), T::zero());
277
278    // Test with max value
279    test_binary_numeric_with_scalar(array, T::max_value());
280}
281
282fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
283where
284    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
285    Scalar: From<T>,
286{
287    // Test with zero
288    test_binary_numeric_with_scalar(array.clone(), T::zero());
289
290    // Test with -1
291    test_binary_numeric_with_scalar(array.clone(), -T::one());
292
293    // Test with max value
294    test_binary_numeric_with_scalar(array.clone(), T::max_value());
295
296    // Test with min value
297    test_binary_numeric_with_scalar(array.clone(), T::min_value());
298
299    // Test with small positive value
300    test_binary_numeric_with_scalar(array.clone(), T::epsilon());
301
302    // Test with min positive value (subnormal)
303    test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
304
305    // Test with special float values (NaN, Infinity)
306    test_binary_numeric_with_scalar(array.clone(), T::nan());
307    test_binary_numeric_with_scalar(array.clone(), T::infinity());
308    test_binary_numeric_with_scalar(array, T::neg_infinity());
309}
310
311fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
312where
313    T: NativePType + Num + Copy + std::fmt::Debug,
314    Scalar: From<T>,
315{
316    let canonicalized_array = array.to_primitive();
317    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
318
319    let scalar = Scalar::from(scalar_value)
320        .cast(array.dtype())
321        .vortex_unwrap();
322
323    // Only test operators that make sense for the given scalar
324    let operators = if scalar_value == T::zero() {
325        // Skip division by zero
326        vec![
327            NumericOperator::Add,
328            NumericOperator::Sub,
329            NumericOperator::RSub,
330            NumericOperator::Mul,
331        ]
332    } else {
333        vec![
334            NumericOperator::Add,
335            NumericOperator::Sub,
336            NumericOperator::RSub,
337            NumericOperator::Mul,
338            NumericOperator::Div,
339            NumericOperator::RDiv,
340        ]
341    };
342
343    for operator in operators {
344        // Test array operator scalar
345        let result = numeric(
346            &array,
347            &ConstantArray::new(scalar.clone(), array.len()).into_array(),
348            operator,
349        );
350
351        // Skip if the entire operation fails
352        if result.is_err() {
353            continue;
354        }
355
356        let result = result.vortex_unwrap();
357        let actual_values = to_vec_of_scalar(&result);
358
359        // Check each element for overflow/underflow
360        let expected_results: Vec<Option<Scalar>> = original_values
361            .iter()
362            .map(|x| {
363                x.as_primitive()
364                    .checked_binary_numeric(&scalar.as_primitive(), operator)
365                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
366            })
367            .collect();
368
369        // For elements that didn't overflow, check they match
370        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
371            if let Some(expected_value) = expected {
372                assert_eq!(
373                    actual,
374                    expected_value,
375                    "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
376                     ({array:?})[{idx}] {operator:?} {scalar} \
377                     expected {expected_value:?}, got {actual:?}",
378                    array.encoding_id(),
379                    idx,
380                    scalar_value,
381                );
382            }
383        }
384    }
385}