Skip to main content

vortex_array/compute/conformance/
binary_numeric.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! # Binary Numeric Conformance Tests
5//!
6//! This module provides conformance testing for binary numeric operations on Vortex arrays.
7//! It ensures that all numeric array encodings produce identical results when performing
8//! arithmetic operations (add, subtract, multiply, divide).
9//!
10//! ## Test Strategy
11//!
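//! For each array encoding, we test:
//! 1. All binary numeric operators against constant scalar values: `1`, plus type-specific
//!    edge cases (zero and the type's maximum; `-1` and the minimum for signed and float
//!    types; epsilon, the smallest positive normal value, NaN and ±infinity for floats)
//! 2. Both left-hand and right-hand side operations for the scalar `1` (e.g., `array + 1` and
//!    `1 + array`); edge-case scalars are only applied on the right-hand side
//! 3. That results match the canonical primitive array implementation, ignoring elements whose
//!    checked reference arithmetic overflows or underflows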
16//!
17//! ## Supported Operations
18//!
19//! - Addition (`+`)
20//! - Subtraction (`-`)
21//! - Multiplication (`*`)
22//! - Division (`/`)
23
24use itertools::Itertools;
25use num_traits::Num;
26use vortex_error::VortexExpect;
27use vortex_error::vortex_err;
28use vortex_error::vortex_panic;
29
30use crate::ArrayRef;
31use crate::IntoArray;
32use crate::LEGACY_SESSION;
33use crate::RecursiveCanonical;
34#[expect(deprecated)]
35use crate::ToCanonical as _;
36use crate::VortexSessionExecute;
37use crate::arrays::ConstantArray;
38use crate::builtins::ArrayBuiltins;
39use crate::dtype::DType;
40use crate::dtype::NativePType;
41use crate::dtype::PType;
42use crate::scalar::NumericOperator;
43use crate::scalar::PrimitiveScalar;
44use crate::scalar::Scalar;
45
46fn to_vec_of_scalar(array: &ArrayRef) -> Vec<Scalar> {
47    // Not fast, but obviously correct
48    (0..array.len())
49        .map(|index| {
50            array
51                .execute_scalar(index, &mut LEGACY_SESSION.create_execution_ctx())
52                .vortex_expect("scalar_at should succeed in conformance test")
53        })
54        .collect_vec()
55}
56
57/// Tests binary numeric operations for conformance across array encodings.
58///
59/// # Type Parameters
60///
61/// * `T` - The native numeric type (e.g., i32, f64) that the array contains
62///
63/// # Arguments
64///
65/// * `array` - The array to test, which should contain numeric values of type `T`
66///
67/// # Test Details
68///
69/// This function:
70/// 1. Canonicalizes the input array to primitive form to get expected values
71/// 2. Tests all binary numeric operators against a constant value of 1
72/// 3. Verifies results match the expected primitive array computation
73/// 4. Tests both array-operator-scalar and scalar-operator-array forms
74/// 5. Gracefully skips operations that would cause overflow/underflow
75///
76/// # Panics
77///
78/// Panics if:
79/// - The array cannot be converted to primitive form
80/// - Results don't match expected values (for operations that don't overflow)
81fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
82where
83    Scalar: From<T>,
84{
85    // First test with the standard scalar value of 1
86    test_standard_binary_numeric::<T>(array.clone());
87
88    // Then test edge cases
89    test_binary_numeric_edge_cases(array);
90}
91
92fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
93where
94    Scalar: From<T>,
95{
96    #[expect(deprecated)]
97    let canonicalized_array = array.to_primitive();
98    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
99
100    let one = T::from(1)
101        .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
102        .vortex_expect("operation should succeed in conformance test");
103    let scalar_one = Scalar::from(one)
104        .cast(array.dtype())
105        .vortex_expect("operation should succeed in conformance test");
106
107    let operators: [NumericOperator; 4] = [
108        NumericOperator::Add,
109        NumericOperator::Sub,
110        NumericOperator::Mul,
111        NumericOperator::Div,
112    ];
113
114    for operator in operators {
115        let op = operator;
116        let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array();
117
118        // Test array operator scalar (e.g., array + 1)
119        let result = array
120            .binary(rhs_const.clone(), op.into())
121            .vortex_expect("apply shouldn't fail")
122            .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
123            .map(|c| c.0.into_array());
124
125        // Skip this operator if the entire operation fails
126        // This can happen for some edge cases in specific encodings
127        let Ok(result) = result else {
128            continue;
129        };
130
131        let actual_values = to_vec_of_scalar(&result);
132
133        // Check each element for overflow/underflow
134        let expected_results: Vec<Option<Scalar>> = original_values
135            .iter()
136            .map(|x| {
137                x.as_primitive()
138                    .checked_binary_numeric(&scalar_one.as_primitive(), op)
139                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
140            })
141            .collect();
142
143        // For elements that didn't overflow, check they match
144        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
145            if let Some(expected_value) = expected {
146                assert_eq!(
147                    actual,
148                    expected_value,
149                    "Binary numeric operation failed for encoding {} at index {}: \
150                     ({array:?})[{idx}] {operator:?} {scalar_one} \
151                     expected {expected_value:?}, got {actual:?}",
152                    array.encoding_id(),
153                    idx,
154                );
155            }
156        }
157
158        // Test scalar operator array (e.g., 1 + array)
159        let result = rhs_const.binary(array.clone(), op.into()).and_then(|a| {
160            a.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
161                .map(|c| c.0.into_array())
162        });
163
164        // Skip this operator if the entire operation fails
165        let Ok(result) = result else {
166            continue;
167        };
168
169        let actual_values = to_vec_of_scalar(&result);
170
171        // Check each element for overflow/underflow
172        let expected_results: Vec<Option<Scalar>> = original_values
173            .iter()
174            .map(|x| {
175                scalar_one
176                    .as_primitive()
177                    .checked_binary_numeric(&x.as_primitive(), op)
178                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
179            })
180            .collect();
181
182        // For elements that didn't overflow, check they match
183        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
184            if let Some(expected_value) = expected {
185                assert_eq!(
186                    actual,
187                    expected_value,
188                    "Binary numeric operation failed for encoding {} at index {}: \
189                     {scalar_one} {operator:?} ({array:?})[{idx}] \
190                     expected {expected_value:?}, got {actual:?}",
191                    array.encoding_id(),
192                    idx,
193                );
194            }
195        }
196    }
197}
198
199/// Entry point for binary numeric conformance testing for any array type.
200///
201/// This function automatically detects the array's numeric type and runs
202/// the appropriate tests. It's designed to be called from rstest parameterized
203/// tests without requiring explicit type parameters.
204///
205/// # Example
206///
207/// ```ignore
208/// #[rstest]
209/// #[case::i32_array(create_i32_array())]
210/// #[case::f64_array(create_f64_array())]
211/// fn test_my_encoding_binary_numeric(#[case] array: MyArray) {
212///     test_binary_numeric_array(array.into_array());
213/// }
214/// ```
215pub fn test_binary_numeric_array(array: ArrayRef) {
216    match array.dtype() {
217        DType::Primitive(ptype, _) => match ptype {
218            PType::I8 => test_binary_numeric_conformance::<i8>(array),
219            PType::I16 => test_binary_numeric_conformance::<i16>(array),
220            PType::I32 => test_binary_numeric_conformance::<i32>(array),
221            PType::I64 => test_binary_numeric_conformance::<i64>(array),
222            PType::U8 => test_binary_numeric_conformance::<u8>(array),
223            PType::U16 => test_binary_numeric_conformance::<u16>(array),
224            PType::U32 => test_binary_numeric_conformance::<u32>(array),
225            PType::U64 => test_binary_numeric_conformance::<u64>(array),
226            PType::F16 => {
227                // F16 not supported in num-traits, skip
228                eprintln!("Skipping f16 binary numeric tests (not supported)");
229            }
230            PType::F32 => test_binary_numeric_conformance::<f32>(array),
231            PType::F64 => test_binary_numeric_conformance::<f64>(array),
232        },
233        dtype => vortex_panic!(
234            "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
235        ),
236    }
237}
238
239/// Tests binary numeric operations with edge case scalar values.
240///
241/// This function tests operations with scalar values:
242/// - Zero (identity for addition/subtraction, absorbing for multiplication)
243/// - Negative one (tests signed arithmetic)
244/// - Maximum value (tests overflow behavior)
245/// - Minimum value (tests underflow behavior)
246fn test_binary_numeric_edge_cases(array: ArrayRef) {
247    match array.dtype() {
248        DType::Primitive(ptype, _) => match ptype {
249            PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
250            PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
251            PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
252            PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
253            PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
254            PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
255            PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
256            PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
257            PType::F16 => {
258                eprintln!("Skipping f16 edge case tests (not supported)");
259            }
260            PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
261            PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
262        },
263        dtype => vortex_panic!(
264            "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
265        ),
266    }
267}
268
269fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
270where
271    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
272    Scalar: From<T>,
273{
274    // Test with zero
275    test_binary_numeric_with_scalar(array.clone(), T::zero());
276
277    // Test with -1
278    test_binary_numeric_with_scalar(array.clone(), -T::one());
279
280    // Test with max value
281    test_binary_numeric_with_scalar(array.clone(), T::max_value());
282
283    // Test with min value
284    test_binary_numeric_with_scalar(array, T::min_value());
285}
286
287fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
288where
289    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
290    Scalar: From<T>,
291{
292    // Test with zero
293    test_binary_numeric_with_scalar(array.clone(), T::zero());
294
295    // Test with max value
296    test_binary_numeric_with_scalar(array, T::max_value());
297}
298
299fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
300where
301    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
302    Scalar: From<T>,
303{
304    // Test with zero
305    test_binary_numeric_with_scalar(array.clone(), T::zero());
306
307    // Test with -1
308    test_binary_numeric_with_scalar(array.clone(), -T::one());
309
310    // Test with max value
311    test_binary_numeric_with_scalar(array.clone(), T::max_value());
312
313    // Test with min value
314    test_binary_numeric_with_scalar(array.clone(), T::min_value());
315
316    // Test with small positive value
317    test_binary_numeric_with_scalar(array.clone(), T::epsilon());
318
319    // Test with min positive value (subnormal)
320    test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
321
322    // Test with special float values (NaN, Infinity)
323    test_binary_numeric_with_scalar(array.clone(), T::nan());
324    test_binary_numeric_with_scalar(array.clone(), T::infinity());
325    test_binary_numeric_with_scalar(array, T::neg_infinity());
326}
327
328fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
329where
330    T: NativePType + Num + Copy + std::fmt::Debug,
331    Scalar: From<T>,
332{
333    #[expect(deprecated)]
334    let canonicalized_array = array.to_primitive();
335    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
336
337    let scalar = Scalar::from(scalar_value)
338        .cast(array.dtype())
339        .vortex_expect("operation should succeed in conformance test");
340
341    // Only test operators that make sense for the given scalar
342    let operators = if scalar_value == T::zero() {
343        // Skip division by zero
344        vec![
345            NumericOperator::Add,
346            NumericOperator::Sub,
347            NumericOperator::Mul,
348        ]
349    } else {
350        vec![
351            NumericOperator::Add,
352            NumericOperator::Sub,
353            NumericOperator::Mul,
354            NumericOperator::Div,
355        ]
356    };
357
358    for operator in operators {
359        let op = operator;
360        let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array();
361
362        // Test array operator scalar
363        let result = array
364            .binary(rhs_const, op.into())
365            .vortex_expect("apply failed")
366            .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
367            .map(|x| x.0.into_array());
368
369        // Skip if the entire operation fails
370        // TODO(joe): this is odd.
371        if result.is_err() {
372            continue;
373        }
374
375        let result = result.vortex_expect("operation should succeed in conformance test");
376        let actual_values = to_vec_of_scalar(&result);
377
378        // Check each element for overflow/underflow
379        let expected_results: Vec<Option<Scalar>> = original_values
380            .iter()
381            .map(|x| {
382                x.as_primitive()
383                    .checked_binary_numeric(&scalar.as_primitive(), op)
384                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
385            })
386            .collect();
387
388        // For elements that didn't overflow, check they match
389        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
390            if let Some(expected_value) = expected {
391                assert_eq!(
392                    actual,
393                    expected_value,
394                    "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
395                     ({array:?})[{idx}] {operator:?} {scalar} \
396                     expected {expected_value:?}, got {actual:?}",
397                    array.encoding_id(),
398                    idx,
399                    scalar_value,
400                );
401            }
402        }
403    }
404}