Skip to main content

vortex_array/compute/conformance/
binary_numeric.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! # Binary Numeric Conformance Tests
5//!
6//! This module provides conformance testing for binary numeric operations on Vortex arrays.
7//! It ensures that all numeric array encodings produce identical results when performing
8//! arithmetic operations (add, subtract, multiply, divide).
9//!
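//! ## Test Strategy
//!
//! For each array encoding, we test:
//! 1. All binary numeric operators against the constant scalar `1`, in both
//!    array-operator-scalar and scalar-operator-array form (e.g., `array + 1` and `1 + array`)
//! 2. The same operators against edge-case scalars (zero, type extremes and, for floats,
//!    special values such as NaN and infinities) in array-operator-scalar form
//! 3. That results match the canonical primitive array implementation, skipping elements for
//!    which the checked reference computation yields no value (e.g. on overflow)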
17//! ## Supported Operations
18//!
19//! - Addition (`+`)
20//! - Subtraction (`-`)
21//! - Reverse Subtraction (scalar - array)
22//! - Multiplication (`*`)
23//! - Division (`/`)
24//! - Reverse Division (scalar / array)
25
26use itertools::Itertools;
27use num_traits::Num;
28use vortex_error::VortexExpect;
29use vortex_error::vortex_err;
30use vortex_error::vortex_panic;
31
32use crate::Array;
33use crate::ArrayRef;
34use crate::IntoArray;
35use crate::LEGACY_SESSION;
36use crate::RecursiveCanonical;
37use crate::ToCanonical;
38use crate::VortexSessionExecute;
39use crate::arrays::ConstantArray;
40use crate::builtins::ArrayBuiltins;
41use crate::dtype::DType;
42use crate::dtype::NativePType;
43use crate::dtype::PType;
44use crate::scalar::NumericOperator;
45use crate::scalar::PrimitiveScalar;
46use crate::scalar::Scalar;
47
48fn to_vec_of_scalar(array: &dyn Array) -> Vec<Scalar> {
49    // Not fast, but obviously correct
50    (0..array.len())
51        .map(|index| {
52            array
53                .scalar_at(index)
54                .vortex_expect("scalar_at should succeed in conformance test")
55        })
56        .collect_vec()
57}
58
59/// Tests binary numeric operations for conformance across array encodings.
60///
61/// # Type Parameters
62///
63/// * `T` - The native numeric type (e.g., i32, f64) that the array contains
64///
65/// # Arguments
66///
67/// * `array` - The array to test, which should contain numeric values of type `T`
68///
69/// # Test Details
70///
71/// This function:
72/// 1. Canonicalizes the input array to primitive form to get expected values
73/// 2. Tests all binary numeric operators against a constant value of 1
74/// 3. Verifies results match the expected primitive array computation
75/// 4. Tests both array-operator-scalar and scalar-operator-array forms
76/// 5. Gracefully skips operations that would cause overflow/underflow
77///
78/// # Panics
79///
80/// Panics if:
81/// - The array cannot be converted to primitive form
82/// - Results don't match expected values (for operations that don't overflow)
83fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
84where
85    Scalar: From<T>,
86{
87    // First test with the standard scalar value of 1
88    test_standard_binary_numeric::<T>(array.clone());
89
90    // Then test edge cases
91    test_binary_numeric_edge_cases(array);
92}
93
94fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
95where
96    Scalar: From<T>,
97{
98    let canonicalized_array = array.to_primitive();
99    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
100
101    let one = T::from(1)
102        .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
103        .vortex_expect("operation should succeed in conformance test");
104    let scalar_one = Scalar::from(one)
105        .cast(array.dtype())
106        .vortex_expect("operation should succeed in conformance test");
107
108    let operators: [NumericOperator; 4] = [
109        NumericOperator::Add,
110        NumericOperator::Sub,
111        NumericOperator::Mul,
112        NumericOperator::Div,
113    ];
114
115    for operator in operators {
116        let op = operator.into();
117        let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array();
118
119        // Test array operator scalar (e.g., array + 1)
120        let result = array
121            .binary(rhs_const.clone(), op)
122            .vortex_expect("apply shouldn't fail")
123            .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
124            .map(|c| c.0.into_array());
125
126        // Skip this operator if the entire operation fails
127        // This can happen for some edge cases in specific encodings
128        let Ok(result) = result else {
129            continue;
130        };
131
132        println!("result {}", result.display_tree());
133        println!("result {}", result.display_values());
134
135        let actual_values = to_vec_of_scalar(&result);
136
137        // Check each element for overflow/underflow
138        let expected_results: Vec<Option<Scalar>> = original_values
139            .iter()
140            .map(|x| {
141                x.as_primitive()
142                    .checked_binary_numeric(&scalar_one.as_primitive(), operator)
143                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
144            })
145            .collect();
146
147        // For elements that didn't overflow, check they match
148        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
149            if let Some(expected_value) = expected {
150                assert_eq!(
151                    actual,
152                    expected_value,
153                    "Binary numeric operation failed for encoding {} at index {}: \
154                     ({array:?})[{idx}] {operator:?} {scalar_one} \
155                     expected {expected_value:?}, got {actual:?}",
156                    array.encoding_id(),
157                    idx,
158                );
159            }
160        }
161
162        // Test scalar operator array (e.g., 1 + array)
163        let result = rhs_const.binary(array.clone(), op).and_then(|a| {
164            a.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
165                .map(|c| c.0.into_array())
166        });
167
168        // Skip this operator if the entire operation fails
169        let Ok(result) = result else {
170            continue;
171        };
172
173        let actual_values = to_vec_of_scalar(&result);
174
175        // Check each element for overflow/underflow
176        let expected_results: Vec<Option<Scalar>> = original_values
177            .iter()
178            .map(|x| {
179                scalar_one
180                    .as_primitive()
181                    .checked_binary_numeric(&x.as_primitive(), operator)
182                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
183            })
184            .collect();
185
186        // For elements that didn't overflow, check they match
187        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
188            if let Some(expected_value) = expected {
189                assert_eq!(
190                    actual,
191                    expected_value,
192                    "Binary numeric operation failed for encoding {} at index {}: \
193                     {scalar_one} {operator:?} ({array:?})[{idx}] \
194                     expected {expected_value:?}, got {actual:?}",
195                    array.encoding_id(),
196                    idx,
197                );
198            }
199        }
200    }
201}
202
203/// Entry point for binary numeric conformance testing for any array type.
204///
205/// This function automatically detects the array's numeric type and runs
206/// the appropriate tests. It's designed to be called from rstest parameterized
207/// tests without requiring explicit type parameters.
208///
209/// # Example
210///
211/// ```ignore
212/// #[rstest]
213/// #[case::i32_array(create_i32_array())]
214/// #[case::f64_array(create_f64_array())]
215/// fn test_my_encoding_binary_numeric(#[case] array: MyArray) {
216///     test_binary_numeric_array(array.into_array());
217/// }
218/// ```
219pub fn test_binary_numeric_array(array: ArrayRef) {
220    match array.dtype() {
221        DType::Primitive(ptype, _) => match ptype {
222            PType::I8 => test_binary_numeric_conformance::<i8>(array),
223            PType::I16 => test_binary_numeric_conformance::<i16>(array),
224            PType::I32 => test_binary_numeric_conformance::<i32>(array),
225            PType::I64 => test_binary_numeric_conformance::<i64>(array),
226            PType::U8 => test_binary_numeric_conformance::<u8>(array),
227            PType::U16 => test_binary_numeric_conformance::<u16>(array),
228            PType::U32 => test_binary_numeric_conformance::<u32>(array),
229            PType::U64 => test_binary_numeric_conformance::<u64>(array),
230            PType::F16 => {
231                // F16 not supported in num-traits, skip
232                eprintln!("Skipping f16 binary numeric tests (not supported)");
233            }
234            PType::F32 => test_binary_numeric_conformance::<f32>(array),
235            PType::F64 => test_binary_numeric_conformance::<f64>(array),
236        },
237        dtype => vortex_panic!(
238            "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
239        ),
240    }
241}
242
243/// Tests binary numeric operations with edge case scalar values.
244///
245/// This function tests operations with scalar values:
246/// - Zero (identity for addition/subtraction, absorbing for multiplication)
247/// - Negative one (tests signed arithmetic)
248/// - Maximum value (tests overflow behavior)
249/// - Minimum value (tests underflow behavior)
250fn test_binary_numeric_edge_cases(array: ArrayRef) {
251    match array.dtype() {
252        DType::Primitive(ptype, _) => match ptype {
253            PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
254            PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
255            PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
256            PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
257            PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
258            PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
259            PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
260            PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
261            PType::F16 => {
262                eprintln!("Skipping f16 edge case tests (not supported)");
263            }
264            PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
265            PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
266        },
267        dtype => vortex_panic!(
268            "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
269        ),
270    }
271}
272
273fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
274where
275    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
276    Scalar: From<T>,
277{
278    // Test with zero
279    test_binary_numeric_with_scalar(array.clone(), T::zero());
280
281    // Test with -1
282    test_binary_numeric_with_scalar(array.clone(), -T::one());
283
284    // Test with max value
285    test_binary_numeric_with_scalar(array.clone(), T::max_value());
286
287    // Test with min value
288    test_binary_numeric_with_scalar(array, T::min_value());
289}
290
291fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
292where
293    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
294    Scalar: From<T>,
295{
296    // Test with zero
297    test_binary_numeric_with_scalar(array.clone(), T::zero());
298
299    // Test with max value
300    test_binary_numeric_with_scalar(array, T::max_value());
301}
302
303fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
304where
305    T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
306    Scalar: From<T>,
307{
308    // Test with zero
309    test_binary_numeric_with_scalar(array.clone(), T::zero());
310
311    // Test with -1
312    test_binary_numeric_with_scalar(array.clone(), -T::one());
313
314    // Test with max value
315    test_binary_numeric_with_scalar(array.clone(), T::max_value());
316
317    // Test with min value
318    test_binary_numeric_with_scalar(array.clone(), T::min_value());
319
320    // Test with small positive value
321    test_binary_numeric_with_scalar(array.clone(), T::epsilon());
322
323    // Test with min positive value (subnormal)
324    test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
325
326    // Test with special float values (NaN, Infinity)
327    test_binary_numeric_with_scalar(array.clone(), T::nan());
328    test_binary_numeric_with_scalar(array.clone(), T::infinity());
329    test_binary_numeric_with_scalar(array, T::neg_infinity());
330}
331
332fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
333where
334    T: NativePType + Num + Copy + std::fmt::Debug,
335    Scalar: From<T>,
336{
337    let canonicalized_array = array.to_primitive();
338    let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
339
340    let scalar = Scalar::from(scalar_value)
341        .cast(array.dtype())
342        .vortex_expect("operation should succeed in conformance test");
343
344    // Only test operators that make sense for the given scalar
345    let operators = if scalar_value == T::zero() {
346        // Skip division by zero
347        vec![
348            NumericOperator::Add,
349            NumericOperator::Sub,
350            NumericOperator::Mul,
351        ]
352    } else {
353        vec![
354            NumericOperator::Add,
355            NumericOperator::Sub,
356            NumericOperator::Mul,
357            NumericOperator::Div,
358        ]
359    };
360
361    for operator in operators {
362        let op = operator.into();
363        let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array();
364
365        // Test array operator scalar
366        let result = array
367            .binary(rhs_const, op)
368            .vortex_expect("apply failed")
369            .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
370            .map(|x| x.0.into_array());
371
372        // Skip if the entire operation fails
373        // TODO(joe): this is odd.
374        if result.is_err() {
375            continue;
376        }
377
378        let result = result.vortex_expect("operation should succeed in conformance test");
379        let actual_values = to_vec_of_scalar(&result);
380
381        // Check each element for overflow/underflow
382        let expected_results: Vec<Option<Scalar>> = original_values
383            .iter()
384            .map(|x| {
385                x.as_primitive()
386                    .checked_binary_numeric(&scalar.as_primitive(), operator)
387                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
388            })
389            .collect();
390
391        // For elements that didn't overflow, check they match
392        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
393            if let Some(expected_value) = expected {
394                assert_eq!(
395                    actual,
396                    expected_value,
397                    "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
398                     ({array:?})[{idx}] {operator:?} {scalar} \
399                     expected {expected_value:?}, got {actual:?}",
400                    array.encoding_id(),
401                    idx,
402                    scalar_value,
403                );
404            }
405        }
406    }
407}