Skip to main content

vortex_array/arrays/
assertions.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_error::VortexExpect;
8
9use crate::ArrayRef;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::RecursiveCanonical;
13use crate::aggregate_fn::fns::all_non_distinct::all_non_distinct;
14
15fn format_indices<I: IntoIterator<Item = usize>>(indices: I) -> impl Display {
16    indices.into_iter().format(",")
17}
18
19/// Executes an array to recursive canonical form with the given execution context.
20fn execute_to_canonical(array: ArrayRef, ctx: &mut ExecutionCtx) -> ArrayRef {
21    array
22        .execute::<RecursiveCanonical>(ctx)
23        .vortex_expect("failed to execute array to recursive canonical form")
24        .0
25        .into_array()
26}
27
28/// Finds indices where two arrays differ based on `scalar_at` comparison.
29#[expect(clippy::unwrap_used)]
30fn find_mismatched_indices(
31    left: &ArrayRef,
32    right: &ArrayRef,
33    ctx: &mut ExecutionCtx,
34) -> Vec<usize> {
35    assert_eq!(left.len(), right.len());
36    (0..left.len())
37        .filter(|i| left.execute_scalar(*i, ctx).unwrap() != right.execute_scalar(*i, ctx).unwrap())
38        .collect()
39}
40
/// Asserts that the scalar at position `$n` in array `$arr` equals `$expected`.
///
/// This is a convenience macro for testing that avoids verbose scalar comparison code.
/// - `$arr` is cloned and converted into an `ArrayRef`.
/// - `$expected` is converted with `TryInto`.
/// - The scalar is read with `execute_scalar` using the execution context `$ctx`.
///
/// Each macro argument is expanded exactly once.
///
/// # Panics
///
/// Panics in any of these cases; the mismatch message includes the index:
/// - converting `$expected` fails;
/// - reading the scalar fails;
/// - the scalar does not equal `$expected`.
///
/// # Example
/// ```ignore
/// let arr = PrimitiveArray::from_iter([1, 2, 3]);
/// assert_nth_scalar!(arr, 0, 1, &mut ctx);
/// assert_nth_scalar!(arr, 1, 2, &mut ctx);
/// ```
#[macro_export]
macro_rules! assert_nth_scalar {
    ($arr:expr, $n:expr, $expected:expr, $ctx:expr) => {{
        use $crate::IntoArray as _;
        let arr_ref: $crate::ArrayRef = $crate::IntoArray::into_array($arr.clone());
        let expected = $expected.try_into().unwrap();
        // Bind the index once so it can be reported in the failure message without
        // re-evaluating the caller's expression.
        let index = $n;
        let actual = arr_ref.execute_scalar(index, $ctx).unwrap();
        assert_eq!(
            actual, expected,
            "scalar at index {} does not match the expected value",
            index
        );
    }};
}
60
/// Asserts that the scalar at position `$n` in array `$arr` is null.
///
/// `$arr` is cloned and converted into an `ArrayRef`. The scalar is read with `execute_scalar`
/// using the execution context `$ctx`. Each macro argument is expanded exactly once.
///
/// # Panics
///
/// Panics if reading the scalar fails, or if the scalar is not null. The failure message
/// includes the index and the scalar's `Debug` representation.
///
/// # Example
///
/// ```ignore
/// let arr = PrimitiveArray::from_option_iter([Some(1), None, Some(3)]);
/// assert_nth_scalar_is_null!(arr, 1, &mut ctx);
/// ```
#[macro_export]
macro_rules! assert_nth_scalar_is_null {
    ($arr:expr, $n:expr, $ctx:expr) => {{
        let arr_ref: $crate::ArrayRef = $crate::IntoArray::into_array($arr.clone());
        // Bind the index once: it is needed both for the lookup and for the failure message,
        // and the caller's expression must not be evaluated twice.
        let index = $n;
        let scalar = arr_ref.execute_scalar(index, $ctx).unwrap();
        assert!(
            scalar.is_null(),
            "expected scalar at index {} to be null, but was {:?}",
            index,
            scalar
        );
    }};
}
82
/// Asserts that two arrays are equal, panicking with a descriptive message otherwise.
///
/// `$left` and `$right` are each cloned and converted into an `ArrayRef`. The macro then:
/// 1. Checks that the two dtypes match.
/// 2. Checks that the two lengths match.
/// 3. Delegates the value comparison to `assert_arrays_eq_impl`, passing the execution context
///    `$ctx`.
///
/// When the dtype or length check fails, the panic message includes both arrays' values.
///
/// # Example
///
/// ```ignore
/// let expected = PrimitiveArray::from_iter([1, 2, 3]);
/// assert_arrays_eq!(actual, expected, &mut ctx);
/// ```
#[macro_export]
macro_rules! assert_arrays_eq {
    ($left:expr, $right:expr, $ctx:expr) => {{
        let left: $crate::ArrayRef = $crate::IntoArray::into_array($left.clone());
        let right: $crate::ArrayRef = $crate::IntoArray::into_array($right.clone());
        if left.dtype() != right.dtype() {
            panic!(
                "assertion left == right failed: arrays differ in type: {} != {}.\n  left: {}\n right: {}",
                left.dtype(),
                right.dtype(),
                left.display_values(),
                right.display_values()
            )
        }

        assert_eq!(
            left.len(),
            right.len(),
            "assertion left == right failed: arrays differ in length: {} != {}.\n  left: {}\n right: {}",
            left.len(),
            right.len(),
            left.display_values(),
            right.display_values()
        );

        // `left` and `right` are already owned locals; borrow them directly rather than
        // cloning them a second time.
        $crate::arrays::assert_arrays_eq_impl(&left, &right, $ctx);
    }};
}
113
114/// Implementation of `assert_arrays_eq!` — called by the macro after converting inputs to
115/// `ArrayRef`.
116#[track_caller]
117#[expect(clippy::panic)]
118pub fn assert_arrays_eq_impl(left: &ArrayRef, right: &ArrayRef, ctx: &mut ExecutionCtx) {
119    let executed = execute_to_canonical(left.clone(), ctx);
120
121    let left_right_the_same =
122        all_non_distinct(left, right, ctx).vortex_expect("failed to compare left and right");
123    let executed_right_the_same = all_non_distinct(&executed, right, ctx)
124        .vortex_expect("failed to compare executed left and right");
125
126    if !left_right_the_same || !executed_right_the_same {
127        let left_right = find_mismatched_indices(left, right, ctx);
128
129        let mut msg = String::new();
130        if !left_right.is_empty() {
131            msg.push_str(&format!(
132                "\n  left != right at indices: {}",
133                format_indices(left_right)
134            ));
135        }
136
137        let executed_right = find_mismatched_indices(&executed, right, ctx);
138        if !executed_right.is_empty() {
139            msg.push_str(&format!(
140                "\n  executed != right at indices: {}",
141                format_indices(executed_right)
142            ));
143        }
144        panic!(
145            "assertion failed: arrays do not match:{}\n     left: {}\n    right: {}\n executed: {}",
146            msg,
147            left.display_values(),
148            right.display_values(),
149            executed.display_values()
150        )
151    }
152}