Skip to main content

vortex_array/arrays/
assertions.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_error::VortexExpect;
8
9use crate::ArrayRef;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::LEGACY_SESSION;
13use crate::RecursiveCanonical;
14use crate::VortexSessionExecute;
15
16fn format_indices<I: IntoIterator<Item = usize>>(indices: I) -> impl Display {
17    indices.into_iter().format(",")
18}
19
20/// Executes an array to recursive canonical form with the given execution context.
21fn execute_to_canonical(array: ArrayRef, ctx: &mut ExecutionCtx) -> ArrayRef {
22    array
23        .execute::<RecursiveCanonical>(ctx)
24        .vortex_expect("failed to execute array to recursive canonical form")
25        .0
26        .into_array()
27}
28
29/// Finds indices where two arrays differ based on `scalar_at` comparison.
30#[expect(clippy::unwrap_used)]
31fn find_mismatched_indices(left: &ArrayRef, right: &ArrayRef) -> Vec<usize> {
32    assert_eq!(left.len(), right.len());
33    (0..left.len())
34        .filter(|i| left.scalar_at(*i).unwrap() != right.scalar_at(*i).unwrap())
35        .collect()
36}
37
/// Asserts that the scalar stored at index `$n` of `$arr` is equal to `$expected`.
///
/// `$arr` is cloned and converted to an `ArrayRef`; `$expected` is converted with
/// `try_into()` so plain Rust values can be written directly in tests instead of building
/// scalars by hand.
///
/// # Panics
///
/// Panics if `scalar_at` fails, if the `try_into()` conversion of `$expected` fails, or if
/// the two values are not equal.
///
/// # Example
/// ```ignore
/// let arr = PrimitiveArray::from_iter([1, 2, 3]);
/// assert_nth_scalar!(arr, 0, 1);
/// assert_nth_scalar!(arr, 1, 2);
/// ```
#[macro_export]
macro_rules! assert_nth_scalar {
    ($arr:expr, $n:expr, $expected:expr) => {{
        use $crate::IntoArray as _;
        let array: $crate::ArrayRef = $crate::IntoArray::into_array($arr.clone());
        let actual = array.scalar_at($n).unwrap();
        assert_eq!(actual, $expected.try_into().unwrap());
    }};
}
59
/// Asserts that the scalar at position `$n` in array `$arr` is null.
///
/// `$arr` is cloned and converted to an `ArrayRef`. The index expression `$n` is evaluated
/// exactly once and the scalar is looked up exactly once, so the failure message always
/// reports the same index and value that were actually checked.
///
/// # Panics
///
/// Panics if `scalar_at` fails or if the scalar at `$n` is not null.
///
/// # Example
///
/// ```ignore
/// let arr = PrimitiveArray::from_option_iter([Some(1), None, Some(3)]);
/// assert_nth_scalar_is_null!(arr, 1);
/// ```
#[macro_export]
macro_rules! assert_nth_scalar_is_null {
    ($arr:expr, $n:expr) => {{
        let arr_ref: $crate::ArrayRef = $crate::IntoArray::into_array($arr.clone());
        // Bind the index once: `$n` may be an arbitrary expression with side effects.
        let index = $n;
        let scalar = arr_ref.scalar_at(index).unwrap();
        assert!(
            scalar.is_null(),
            "expected scalar at index {} to be null, but was {:?}",
            index,
            scalar
        );
    }};
}
80
/// Asserts that two arrays are equal in dtype, length, and element values.
///
/// Both `$left` and `$right` are cloned and converted to `ArrayRef`, so the caller keeps
/// ownership of its arrays. The checks run in this order:
///
/// 1. the dtypes must be equal,
/// 2. the lengths must be equal,
/// 3. the element-wise comparison performed by `assert_arrays_eq_impl` must find no
///    mismatches.
///
/// # Panics
///
/// Panics with a message showing both arrays' values if any of the checks above fails.
///
/// # Example
///
/// ```ignore
/// let a = PrimitiveArray::from_iter([1, 2, 3]);
/// let b = PrimitiveArray::from_iter([1, 2, 3]);
/// assert_arrays_eq!(a, b);
/// ```
#[macro_export]
macro_rules! assert_arrays_eq {
    ($left:expr, $right:expr) => {{
        let left: $crate::ArrayRef = $crate::IntoArray::into_array($left.clone());
        let right: $crate::ArrayRef = $crate::IntoArray::into_array($right.clone());
        if left.dtype() != right.dtype() {
            panic!(
                "assertion left == right failed: arrays differ in type: {} != {}.\n  left: {}\n right: {}",
                left.dtype(),
                right.dtype(),
                left.display_values(),
                right.display_values()
            )
        }

        assert_eq!(
            left.len(),
            right.len(),
            "assertion left == right failed: arrays differ in length: {} != {}.\n  left: {}\n right: {}",
            left.len(),
            right.len(),
            left.display_values(),
            right.display_values()
        );

        // `left` and `right` are already owned `ArrayRef`s; pass them by reference directly.
        $crate::arrays::assert_arrays_eq_impl(&left, &right);
    }};
}
115
116/// Implementation of `assert_arrays_eq!` — called by the macro after converting inputs to
117/// `ArrayRef`.
118#[track_caller]
119#[allow(clippy::panic)]
120pub fn assert_arrays_eq_impl(left: &ArrayRef, right: &ArrayRef) {
121    let executed = execute_to_canonical(left.clone(), &mut LEGACY_SESSION.create_execution_ctx());
122
123    let left_right = find_mismatched_indices(left, right);
124    let executed_right = find_mismatched_indices(&executed, right);
125
126    if !left_right.is_empty() || !executed_right.is_empty() {
127        let mut msg = String::new();
128        if !left_right.is_empty() {
129            msg.push_str(&format!(
130                "\n  left != right at indices: {}",
131                format_indices(left_right)
132            ));
133        }
134        if !executed_right.is_empty() {
135            msg.push_str(&format!(
136                "\n  executed != right at indices: {}",
137                format_indices(executed_right)
138            ));
139        }
140        panic!(
141            "assertion failed: arrays do not match:{}\n     left: {}\n    right: {}\n executed: {}",
142            msg,
143            left.display_values(),
144            right.display_values(),
145            executed.display_values()
146        )
147    }
148}