// vortex_array/compute/conformance/binary_numeric.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! # Binary Numeric Conformance Tests
//!
//! This module provides conformance testing for binary numeric operations on Vortex arrays.
//! It ensures that all numeric array encodings produce identical results when performing
//! arithmetic operations (add, subtract, multiply, divide).
//!
//! ## Test Strategy
//!
//! For each array encoding, we test:
//! 1. All binary numeric operators against a constant scalar value of 1
//! 2. Both left-hand and right-hand side operations (e.g., array + 1 and 1 + array)
//! 3. That results match the canonical primitive array implementation
//! 4. Edge-case scalar values applicable to the element type (zero, -1, the type's minimum and
//!    maximum, and for floats also epsilon, the smallest positive normal value, NaN and
//!    ±infinity), applied as `array op scalar`; division by a zero scalar is skipped
//!
//! ## Supported Operations
//!
//! - Addition (`+`)
//! - Subtraction (`-`)
//! - Multiplication (`*`)
//! - Division (`/`)

24use itertools::Itertools;
25use num_traits::Num;
26use vortex_error::VortexExpect;
27use vortex_error::vortex_err;
28use vortex_error::vortex_panic;
29
30use crate::ArrayRef;
31use crate::IntoArray;
32use crate::LEGACY_SESSION;
33use crate::RecursiveCanonical;
34use crate::ToCanonical;
35use crate::VortexSessionExecute;
36use crate::arrays::ConstantArray;
37use crate::builtins::ArrayBuiltins;
38use crate::dtype::DType;
39use crate::dtype::NativePType;
40use crate::dtype::PType;
41use crate::scalar::NumericOperator;
42use crate::scalar::PrimitiveScalar;
43use crate::scalar::Scalar;
44
45fn to_vec_of_scalar(array: &ArrayRef) -> Vec<Scalar> {
46    // Not fast, but obviously correct
47    (0..array.len())
48        .map(|index| {
49            array
50                .scalar_at(index)
51                .vortex_expect("scalar_at should succeed in conformance test")
52        })
53        .collect_vec()
54}
55
56/// Tests binary numeric operations for conformance across array encodings.
57///
58/// # Type Parameters
59///
60/// * `T` - The native numeric type (e.g., i32, f64) that the array contains
61///
62/// # Arguments
63///
64/// * `array` - The array to test, which should contain numeric values of type `T`
65///
66/// # Test Details
67///
68/// This function:
69/// 1. Canonicalizes the input array to primitive form to get expected values
70/// 2. Tests all binary numeric operators against a constant value of 1
71/// 3. Verifies results match the expected primitive array computation
72/// 4. Tests both array-operator-scalar and scalar-operator-array forms
73/// 5. Gracefully skips operations that would cause overflow/underflow
74///
75/// # Panics
76///
77/// Panics if:
78/// - The array cannot be converted to primitive form
79/// - Results don't match expected values (for operations that don't overflow)
80fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
81where
82    Scalar: From<T>,
83{
84    // First test with the standard scalar value of 1
85    test_standard_binary_numeric::<T>(array.clone());
86
87    // Then test edge cases
88    test_binary_numeric_edge_cases(array);
89}
90
91fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
92where
93    Scalar: From<T>,
94{
95    let canonicalized_array = array.to_primitive();
96    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
97
98    let one = T::from(1)
99        .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
100        .vortex_expect("operation should succeed in conformance test");
101    let scalar_one = Scalar::from(one)
102        .cast(array.dtype())
103        .vortex_expect("operation should succeed in conformance test");
104
105    let operators: [NumericOperator; 4] = [
106        NumericOperator::Add,
107        NumericOperator::Sub,
108        NumericOperator::Mul,
109        NumericOperator::Div,
110    ];
111
112    for operator in operators {
113        let op = operator;
114        let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array();
115
116        // Test array operator scalar (e.g., array + 1)
117        let result = array
118            .binary(rhs_const.clone(), op.into())
119            .vortex_expect("apply shouldn't fail")
120            .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
121            .map(|c| c.0.into_array());
122
123        // Skip this operator if the entire operation fails
124        // This can happen for some edge cases in specific encodings
125        let Ok(result) = result else {
126            continue;
127        };
128
129        let actual_values = to_vec_of_scalar(&result);
130
131        // Check each element for overflow/underflow
132        let expected_results: Vec<Option<Scalar>> = original_values
133            .iter()
134            .map(|x| {
135                x.as_primitive()
136                    .checked_binary_numeric(&scalar_one.as_primitive(), op)
137                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
138            })
139            .collect();
140
141        // For elements that didn't overflow, check they match
142        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
143            if let Some(expected_value) = expected {
144                assert_eq!(
145                    actual,
146                    expected_value,
147                    "Binary numeric operation failed for encoding {} at index {}: \
148                     ({array:?})[{idx}] {operator:?} {scalar_one} \
149                     expected {expected_value:?}, got {actual:?}",
150                    array.encoding_id(),
151                    idx,
152                );
153            }
154        }
155
156        // Test scalar operator array (e.g., 1 + array)
157        let result = rhs_const.binary(array.clone(), op.into()).and_then(|a| {
158            a.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
159                .map(|c| c.0.into_array())
160        });
161
162        // Skip this operator if the entire operation fails
163        let Ok(result) = result else {
164            continue;
165        };
166
167        let actual_values = to_vec_of_scalar(&result);
168
169        // Check each element for overflow/underflow
170        let expected_results: Vec<Option<Scalar>> = original_values
171            .iter()
172            .map(|x| {
173                scalar_one
174                    .as_primitive()
175                    .checked_binary_numeric(&x.as_primitive(), op)
176                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
177            })
178            .collect();
179
180        // For elements that didn't overflow, check they match
181        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
182            if let Some(expected_value) = expected {
183                assert_eq!(
184                    actual,
185                    expected_value,
186                    "Binary numeric operation failed for encoding {} at index {}: \
187                     {scalar_one} {operator:?} ({array:?})[{idx}] \
188                     expected {expected_value:?}, got {actual:?}",
189                    array.encoding_id(),
190                    idx,
191                );
192            }
193        }
194    }
195}
196
197/// Entry point for binary numeric conformance testing for any array type.
198///
199/// This function automatically detects the array's numeric type and runs
200/// the appropriate tests. It's designed to be called from rstest parameterized
201/// tests without requiring explicit type parameters.
202///
203/// # Example
204///
205/// ```ignore
206/// #[rstest]
207/// #[case::i32_array(create_i32_array())]
208/// #[case::f64_array(create_f64_array())]
209/// fn test_my_encoding_binary_numeric(#[case] array: MyArray) {
210///     test_binary_numeric_array(array.into_array());
211/// }
212/// ```
213pub fn test_binary_numeric_array(array: ArrayRef) {
214    match array.dtype() {
215        DType::Primitive(ptype, _) => match ptype {
216            PType::I8 => test_binary_numeric_conformance::<i8>(array),
217            PType::I16 => test_binary_numeric_conformance::<i16>(array),
218            PType::I32 => test_binary_numeric_conformance::<i32>(array),
219            PType::I64 => test_binary_numeric_conformance::<i64>(array),
220            PType::U8 => test_binary_numeric_conformance::<u8>(array),
221            PType::U16 => test_binary_numeric_conformance::<u16>(array),
222            PType::U32 => test_binary_numeric_conformance::<u32>(array),
223            PType::U64 => test_binary_numeric_conformance::<u64>(array),
224            PType::F16 => {
225                // F16 not supported in num-traits, skip
226                eprintln!("Skipping f16 binary numeric tests (not supported)");
227            }
228            PType::F32 => test_binary_numeric_conformance::<f32>(array),
229            PType::F64 => test_binary_numeric_conformance::<f64>(array),
230        },
231        dtype => vortex_panic!(
232            "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
233        ),
234    }
235}
236
237/// Tests binary numeric operations with edge case scalar values.
238///
239/// This function tests operations with scalar values:
240/// - Zero (identity for addition/subtraction, absorbing for multiplication)
241/// - Negative one (tests signed arithmetic)
242/// - Maximum value (tests overflow behavior)
243/// - Minimum value (tests underflow behavior)
244fn test_binary_numeric_edge_cases(array: ArrayRef) {
245    match array.dtype() {
246        DType::Primitive(ptype, _) => match ptype {
247            PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
248            PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
249            PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
250            PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
251            PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
252            PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
253            PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
254            PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
255            PType::F16 => {
256                eprintln!("Skipping f16 edge case tests (not supported)");
257            }
258            PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
259            PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
260        },
261        dtype => vortex_panic!(
262            "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
263        ),
264    }
265}
266
267fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
268where
269    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
270    Scalar: From<T>,
271{
272    // Test with zero
273    test_binary_numeric_with_scalar(array.clone(), T::zero());
274
275    // Test with -1
276    test_binary_numeric_with_scalar(array.clone(), -T::one());
277
278    // Test with max value
279    test_binary_numeric_with_scalar(array.clone(), T::max_value());
280
281    // Test with min value
282    test_binary_numeric_with_scalar(array, T::min_value());
283}
284
285fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
286where
287    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
288    Scalar: From<T>,
289{
290    // Test with zero
291    test_binary_numeric_with_scalar(array.clone(), T::zero());
292
293    // Test with max value
294    test_binary_numeric_with_scalar(array, T::max_value());
295}
296
297fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
298where
299    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
300    Scalar: From<T>,
301{
302    // Test with zero
303    test_binary_numeric_with_scalar(array.clone(), T::zero());
304
305    // Test with -1
306    test_binary_numeric_with_scalar(array.clone(), -T::one());
307
308    // Test with max value
309    test_binary_numeric_with_scalar(array.clone(), T::max_value());
310
311    // Test with min value
312    test_binary_numeric_with_scalar(array.clone(), T::min_value());
313
314    // Test with small positive value
315    test_binary_numeric_with_scalar(array.clone(), T::epsilon());
316
317    // Test with min positive value (subnormal)
318    test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
319
320    // Test with special float values (NaN, Infinity)
321    test_binary_numeric_with_scalar(array.clone(), T::nan());
322    test_binary_numeric_with_scalar(array.clone(), T::infinity());
323    test_binary_numeric_with_scalar(array, T::neg_infinity());
324}
325
326fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
327where
328    T: NativePType + Num + Copy + std::fmt::Debug,
329    Scalar: From<T>,
330{
331    let canonicalized_array = array.to_primitive();
332    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
333
334    let scalar = Scalar::from(scalar_value)
335        .cast(array.dtype())
336        .vortex_expect("operation should succeed in conformance test");
337
338    // Only test operators that make sense for the given scalar
339    let operators = if scalar_value == T::zero() {
340        // Skip division by zero
341        vec![
342            NumericOperator::Add,
343            NumericOperator::Sub,
344            NumericOperator::Mul,
345        ]
346    } else {
347        vec![
348            NumericOperator::Add,
349            NumericOperator::Sub,
350            NumericOperator::Mul,
351            NumericOperator::Div,
352        ]
353    };
354
355    for operator in operators {
356        let op = operator;
357        let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array();
358
359        // Test array operator scalar
360        let result = array
361            .binary(rhs_const, op.into())
362            .vortex_expect("apply failed")
363            .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
364            .map(|x| x.0.into_array());
365
366        // Skip if the entire operation fails
367        // TODO(joe): this is odd.
368        if result.is_err() {
369            continue;
370        }
371
372        let result = result.vortex_expect("operation should succeed in conformance test");
373        let actual_values = to_vec_of_scalar(&result);
374
375        // Check each element for overflow/underflow
376        let expected_results: Vec<Option<Scalar>> = original_values
377            .iter()
378            .map(|x| {
379                x.as_primitive()
380                    .checked_binary_numeric(&scalar.as_primitive(), op)
381                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
382            })
383            .collect();
384
385        // For elements that didn't overflow, check they match
386        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
387            if let Some(expected_value) = expected {
388                assert_eq!(
389                    actual,
390                    expected_value,
391                    "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
392                     ({array:?})[{idx}] {operator:?} {scalar} \
393                     expected {expected_value:?}, got {actual:?}",
394                    array.encoding_id(),
395                    idx,
396                    scalar_value,
397                );
398            }
399        }
400    }
401}