Skip to main content

vortex_array/compute/conformance/
binary_numeric.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! # Binary Numeric Conformance Tests
//!
//! This module provides conformance testing for binary numeric operations on Vortex arrays.
//! It checks that every numeric array encoding produces the same results as its canonical
//! form when performing arithmetic operations (add, subtract, multiply, divide).
//!
//! ## Test Strategy
//!
//! For each array encoding, we test:
//! 1. All binary numeric operators against a constant scalar value of `1`
//! 2. Both left-hand and right-hand side operations with that constant (e.g., `array + 1`
//!    and `1 + array`)
//! 3. Type-specific edge-case scalars (zero, extreme values, and for floats NaN and the
//!    infinities) on the right-hand side
//! 4. That each result element matches the checked scalar operation applied to the
//!    canonicalized input; elements whose checked result is undefined (e.g. on overflow)
//!    are not compared
//!
//! ## Supported Operations
//!
//! - Addition (`+`)
//! - Subtraction (`-`)
//! - Multiplication (`*`)
//! - Division (`/`)
23
24use std::fmt::Debug;
25
26use itertools::Itertools;
27use num_traits::Bounded;
28use num_traits::Float;
29use num_traits::Num;
30use num_traits::Signed;
31use vortex_error::VortexExpect;
32use vortex_error::vortex_err;
33use vortex_error::vortex_panic;
34
35use crate::ArrayRef;
36use crate::Canonical;
37use crate::ExecutionCtx;
38use crate::IntoArray;
39use crate::RecursiveCanonical;
40use crate::arrays::ConstantArray;
41use crate::builtins::ArrayBuiltins;
42use crate::dtype::DType;
43use crate::dtype::NativePType;
44use crate::dtype::PType;
45use crate::scalar::NumericOperator;
46use crate::scalar::PrimitiveScalar;
47use crate::scalar::Scalar;
48
49fn to_vec_of_scalar(array: &ArrayRef, ctx: &mut ExecutionCtx) -> Vec<Scalar> {
50    // Not fast, but obviously correct
51    (0..array.len())
52        .map(|index| {
53            array
54                .execute_scalar(index, ctx)
55                .vortex_expect("scalar_at should succeed in conformance test")
56        })
57        .collect_vec()
58}
59
60/// Tests binary numeric operations for conformance across array encodings.
61///
62/// # Type Parameters
63///
64/// * `T` - The native numeric type (e.g., i32, f64) that the array contains
65///
66/// # Arguments
67///
68/// * `array` - The array to test, which should contain numeric values of type `T`
69///
70/// # Test Details
71///
72/// This function:
73/// 1. Canonicalizes the input array to primitive form to get expected values
74/// 2. Tests all binary numeric operators against a constant value of 1
75/// 3. Verifies results match the expected primitive array computation
76/// 4. Tests both array-operator-scalar and scalar-operator-array forms
77/// 5. Gracefully skips operations that would cause overflow/underflow
78///
79/// # Panics
80///
81/// Panics if:
82/// - The array cannot be converted to primitive form
83/// - Results don't match expected values (for operations that don't overflow)
84fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(
85    array: &ArrayRef,
86    ctx: &mut ExecutionCtx,
87) where
88    Scalar: From<T>,
89{
90    // First test with the standard scalar value of 1
91    test_standard_binary_numeric::<T>(array, ctx);
92
93    // Then test edge cases
94    test_binary_numeric_edge_cases(array, ctx);
95}
96
97fn test_standard_binary_numeric<T: NativePType + Num + Copy>(
98    array: &ArrayRef,
99    ctx: &mut ExecutionCtx,
100) where
101    Scalar: From<T>,
102{
103    let canonicalized_array = array
104        .clone()
105        .execute::<Canonical>(ctx)
106        .vortex_expect("Must be able to canonicalise")
107        .into_array();
108    let original_values = to_vec_of_scalar(&canonicalized_array, ctx);
109
110    let one = T::from(1)
111        .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
112        .vortex_expect("operation should succeed in conformance test");
113    let scalar_one = Scalar::from(one)
114        .cast(array.dtype())
115        .vortex_expect("operation should succeed in conformance test");
116
117    let operators: [NumericOperator; 4] = [
118        NumericOperator::Add,
119        NumericOperator::Sub,
120        NumericOperator::Mul,
121        NumericOperator::Div,
122    ];
123
124    for operator in operators {
125        let op = operator;
126        let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array();
127
128        // Test array operator scalar (e.g., array + 1)
129        let result = array
130            .binary(rhs_const.clone(), op.into())
131            .vortex_expect("apply shouldn't fail")
132            .execute::<RecursiveCanonical>(ctx)
133            .map(|c| c.0.into_array());
134
135        // Skip this operator if the entire operation fails
136        // This can happen for some edge cases in specific encodings
137        let Ok(result) = result else {
138            continue;
139        };
140
141        let actual_values = to_vec_of_scalar(&result, ctx);
142
143        // Check each element for overflow/underflow
144        let expected_results: Vec<Option<Scalar>> = original_values
145            .iter()
146            .map(|x| {
147                x.as_primitive()
148                    .checked_binary_numeric(&scalar_one.as_primitive(), op)
149                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
150            })
151            .collect();
152
153        // For elements that didn't overflow, check they match
154        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
155            if let Some(expected_value) = expected {
156                assert_eq!(
157                    actual,
158                    expected_value,
159                    "Binary numeric operation failed for encoding {} at index {}: \
160                     ({array:?})[{idx}] {operator:?} {scalar_one} \
161                     expected {expected_value:?}, got {actual:?}",
162                    array.encoding_id(),
163                    idx,
164                );
165            }
166        }
167
168        // Test scalar operator array (e.g., 1 + array)
169        let result = rhs_const.binary(array.clone(), op.into()).and_then(|a| {
170            a.execute::<RecursiveCanonical>(ctx)
171                .map(|c| c.0.into_array())
172        });
173
174        // Skip this operator if the entire operation fails
175        let Ok(result) = result else {
176            continue;
177        };
178
179        let actual_values = to_vec_of_scalar(&result, ctx);
180
181        // Check each element for overflow/underflow
182        let expected_results: Vec<Option<Scalar>> = original_values
183            .iter()
184            .map(|x| {
185                scalar_one
186                    .as_primitive()
187                    .checked_binary_numeric(&x.as_primitive(), op)
188                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
189            })
190            .collect();
191
192        // For elements that didn't overflow, check they match
193        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
194            if let Some(expected_value) = expected {
195                assert_eq!(
196                    actual,
197                    expected_value,
198                    "Binary numeric operation failed for encoding {} at index {}: \
199                     {scalar_one} {operator:?} ({array:?})[{idx}] \
200                     expected {expected_value:?}, got {actual:?}",
201                    array.encoding_id(),
202                    idx,
203                );
204            }
205        }
206    }
207}
208
209/// Entry point for binary numeric conformance testing for any array type.
210///
211/// This function automatically detects the array's numeric type and runs
212/// the appropriate tests. It's designed to be called from rstest parameterized
213/// tests without requiring explicit type parameters.
214///
215/// # Example
216///
217/// ```ignore
218/// #[rstest]
219/// #[case::i32_array(create_i32_array())]
220/// #[case::f64_array(create_f64_array())]
221/// fn test_my_encoding_binary_numeric(#[case] array: MyArray) {
222///     test_binary_numeric_array(array.into_array());
223/// }
224/// ```
225pub fn test_binary_numeric_array(array: &ArrayRef, ctx: &mut ExecutionCtx) {
226    match array.dtype() {
227        DType::Primitive(ptype, _) => match ptype {
228            PType::I8 => test_binary_numeric_conformance::<i8>(array, ctx),
229            PType::I16 => test_binary_numeric_conformance::<i16>(array, ctx),
230            PType::I32 => test_binary_numeric_conformance::<i32>(array, ctx),
231            PType::I64 => test_binary_numeric_conformance::<i64>(array, ctx),
232            PType::U8 => test_binary_numeric_conformance::<u8>(array, ctx),
233            PType::U16 => test_binary_numeric_conformance::<u16>(array, ctx),
234            PType::U32 => test_binary_numeric_conformance::<u32>(array, ctx),
235            PType::U64 => test_binary_numeric_conformance::<u64>(array, ctx),
236            PType::F16 => {
237                // F16 not supported in num-traits, skip
238                eprintln!("Skipping f16 binary numeric tests (not supported)");
239            }
240            PType::F32 => test_binary_numeric_conformance::<f32>(array, ctx),
241            PType::F64 => test_binary_numeric_conformance::<f64>(array, ctx),
242        },
243        dtype => vortex_panic!(
244            "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
245        ),
246    }
247}
248
249/// Tests binary numeric operations with edge case scalar values.
250///
251/// This function tests operations with scalar values:
252/// - Zero (identity for addition/subtraction, absorbing for multiplication)
253/// - Negative one (tests signed arithmetic)
254/// - Maximum value (tests overflow behavior)
255/// - Minimum value (tests underflow behavior)
256fn test_binary_numeric_edge_cases(array: &ArrayRef, ctx: &mut ExecutionCtx) {
257    match array.dtype() {
258        DType::Primitive(ptype, _) => match ptype {
259            PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array, ctx),
260            PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array, ctx),
261            PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array, ctx),
262            PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array, ctx),
263            PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array, ctx),
264            PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array, ctx),
265            PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array, ctx),
266            PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array, ctx),
267            PType::F16 => {
268                eprintln!("Skipping f16 edge case tests (not supported)");
269            }
270            PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array, ctx),
271            PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array, ctx),
272        },
273        dtype => vortex_panic!(
274            "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
275        ),
276    }
277}
278
279fn test_binary_numeric_edge_cases_signed<T>(array: &ArrayRef, ctx: &mut ExecutionCtx)
280where
281    T: NativePType + Num + Copy + Debug + Bounded + Signed,
282    Scalar: From<T>,
283{
284    // Test with zero
285    test_binary_numeric_with_scalar(array, T::zero(), ctx);
286
287    // Test with -1
288    test_binary_numeric_with_scalar(array, -T::one(), ctx);
289
290    // Test with max value
291    test_binary_numeric_with_scalar(array, T::max_value(), ctx);
292
293    // Test with min value
294    test_binary_numeric_with_scalar(array, T::min_value(), ctx);
295}
296
297fn test_binary_numeric_edge_cases_unsigned<T>(array: &ArrayRef, ctx: &mut ExecutionCtx)
298where
299    T: NativePType + Num + Copy + Debug + Bounded,
300    Scalar: From<T>,
301{
302    // Test with zero
303    test_binary_numeric_with_scalar(array, T::zero(), ctx);
304
305    // Test with max value
306    test_binary_numeric_with_scalar(array, T::max_value(), ctx);
307}
308
309fn test_binary_numeric_edge_cases_float<T>(array: &ArrayRef, ctx: &mut ExecutionCtx)
310where
311    T: NativePType + Num + Copy + Debug + Float,
312    Scalar: From<T>,
313{
314    // Test with zero
315    test_binary_numeric_with_scalar(array, T::zero(), ctx);
316
317    // Test with -1
318    test_binary_numeric_with_scalar(array, -T::one(), ctx);
319
320    // Test with max value
321    test_binary_numeric_with_scalar(array, T::max_value(), ctx);
322
323    // Test with min value
324    test_binary_numeric_with_scalar(array, T::min_value(), ctx);
325
326    // Test with small positive value
327    test_binary_numeric_with_scalar(array, T::epsilon(), ctx);
328
329    // Test with min positive value (subnormal)
330    test_binary_numeric_with_scalar(array, T::min_positive_value(), ctx);
331
332    // Test with special float values (NaN, Infinity)
333    test_binary_numeric_with_scalar(array, T::nan(), ctx);
334    test_binary_numeric_with_scalar(array, T::infinity(), ctx);
335    test_binary_numeric_with_scalar(array, T::neg_infinity(), ctx);
336}
337
338fn test_binary_numeric_with_scalar<T>(array: &ArrayRef, scalar_value: T, ctx: &mut ExecutionCtx)
339where
340    T: NativePType + Num + Copy + Debug,
341    Scalar: From<T>,
342{
343    let canonicalized_array = array
344        .clone()
345        .execute::<Canonical>(ctx)
346        .vortex_expect("Must be able to canonicalise")
347        .into_array();
348    let original_values = to_vec_of_scalar(&canonicalized_array, ctx);
349
350    let scalar = Scalar::from(scalar_value)
351        .cast(array.dtype())
352        .vortex_expect("operation should succeed in conformance test");
353
354    // Only test operators that make sense for the given scalar
355    let operators = if scalar_value == T::zero() {
356        // Skip division by zero
357        vec![
358            NumericOperator::Add,
359            NumericOperator::Sub,
360            NumericOperator::Mul,
361        ]
362    } else {
363        vec![
364            NumericOperator::Add,
365            NumericOperator::Sub,
366            NumericOperator::Mul,
367            NumericOperator::Div,
368        ]
369    };
370
371    for operator in operators {
372        let op = operator;
373        let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array();
374
375        // Test array operator scalar
376        let result = array
377            .binary(rhs_const, op.into())
378            .vortex_expect("apply failed")
379            .execute::<RecursiveCanonical>(ctx)
380            .map(|x| x.0.into_array());
381
382        // Skip if the entire operation fails
383        // TODO(joe): this is odd.
384        if result.is_err() {
385            continue;
386        }
387
388        let result = result.vortex_expect("operation should succeed in conformance test");
389        let actual_values = to_vec_of_scalar(&result, ctx);
390
391        // Check each element for overflow/underflow
392        let expected_results: Vec<Option<Scalar>> = original_values
393            .iter()
394            .map(|x| {
395                x.as_primitive()
396                    .checked_binary_numeric(&scalar.as_primitive(), op)
397                    .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
398            })
399            .collect();
400
401        // For elements that didn't overflow, check they match
402        for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
403            if let Some(expected_value) = expected {
404                assert_eq!(
405                    actual,
406                    expected_value,
407                    "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
408                     ({array:?})[{idx}] {operator:?} {scalar} \
409                     expected {expected_value:?}, got {actual:?}",
410                    array.encoding_id(),
411                    idx,
412                    scalar_value,
413                );
414            }
415        }
416    }
417}