1#![deny(missing_docs)]
8
9use ndarray::Array1;
10use num_complex::Complex64;
11
/// Asserts that `left` is approximately equal to `right`.
///
/// For finite inputs the check is `|left - right| <= atol + rtol * |right|`. The relative
/// part is scaled by `right` only, so `right` should be the expected value.
///
/// Non-finite inputs (infinities, NaN) must instead compare exactly equal. Matching
/// infinities therefore pass, a finite value never passes against an infinity, and NaN
/// never passes.
///
/// # Panics
///
/// Panics with a message showing both values when they are not close.
#[track_caller]
pub fn assert_close_scalar(left: f64, right: f64, atol: f64, rtol: f64) {
    if left.is_finite() && right.is_finite() {
        let diff = (left - right).abs();
        let tol = atol + rtol * right.abs();
        assert!(
            diff <= tol,
            "expected |{left} - {right}| <= {tol}, got {diff}"
        );
    } else {
        // The arithmetic check breaks down here. `inf - inf` is NaN, which rejects equal
        // infinities. `rtol * inf` is inf, which accepts any finite `left` when `right`
        // is infinite. So require exact equality; NaN never equals itself.
        assert!(
            left == right,
            "expected {left} == {right} (non-finite values must match exactly)"
        );
    }
}
26
27pub fn assert_close_slice(left: &[f64], right: &[f64], atol: f64, rtol: f64) {
36 assert_eq!(left.len(), right.len(), "length mismatch");
37 for (l, r) in left.iter().zip(right.iter()) {
38 assert_close_scalar(*l, *r, atol, rtol);
39 }
40}
41
42pub fn assert_close_complex_slice(left: &[Complex64], right: &[Complex64], atol: f64, rtol: f64) {
52 assert_eq!(left.len(), right.len(), "length mismatch");
53 for (l, r) in left.iter().zip(right.iter()) {
54 let diff = (*l - *r).norm();
55 let tol = atol + r.norm() * rtol;
56 assert!(diff <= tol, "expected |{l:?} - {r:?}| <= {tol}, got {diff}");
57 }
58}
59
60pub fn assert_close_array1(left: &Array1<f64>, right: &Array1<f64>, atol: f64, rtol: f64) {
70 assert_eq!(left.len(), right.len(), "length mismatch");
71 for (l, r) in left.iter().zip(right.iter()) {
72 assert_close_scalar(*l, *r, atol, rtol);
73 }
74}
75
76pub fn assert_close_complex_array1(
87 left: &Array1<Complex64>,
88 right: &Array1<Complex64>,
89 atol: f64,
90 rtol: f64,
91) {
92 assert_eq!(left.len(), right.len(), "length mismatch");
93 for (l, r) in left.iter().zip(right.iter()) {
94 let diff = (*l - *r).norm();
95 let tol = atol + r.norm() * rtol;
96 assert!(diff <= tol, "expected |{l:?} - {r:?}| <= {tol}, got {diff}");
97 }
98}
99
/// Asserts that two numbers, slices, or 1-D arrays are approximately equal.
///
/// The macro dispatches to one of the crate's `assert_close_*` functions. Each form takes
/// either explicit `atol = …, rtol = …` tolerances or the shorthand `tol = t`, which means
/// `atol = t, rtol = 0.0`. A trailing comma is accepted.
///
/// - `assert_close!(a, b, atol = x, rtol = y)`: scalars, converted with `as f64` and
///   checked by `assert_close_scalar`.
/// - `assert_close!(a, b, slice, …)`: `&[f64]` slices, via `assert_close_slice`.
/// - `assert_close!(a, b, complex_slice, …)`: `&[Complex64]` slices, via
///   `assert_close_complex_slice`.
/// - `assert_close!(a, b, array, …)`: `&Array1<f64>`, via `assert_close_array1`.
/// - `assert_close!(a, b, complex_array, …)`: `&Array1<Complex64>`, via
///   `assert_close_complex_array1`.
///
/// The second operand is the expected value; the relative tolerance is scaled by it.
///
/// # Panics
///
/// Panics when the operands are not close, or when slice/array lengths differ.
#[macro_export]
macro_rules! assert_close {
    ($left:expr, $right:expr, atol = $atol:expr, rtol = $rtol:expr $(,)?) => {{
        $crate::assert_close_scalar($left as f64, $right as f64, $atol, $rtol);
    }};
    ($left:expr, $right:expr, tol = $tol:expr $(,)?) => {{
        $crate::assert_close_scalar($left as f64, $right as f64, $tol, 0.0);
    }};
    ($left:expr, $right:expr, slice, atol = $atol:expr, rtol = $rtol:expr $(,)?) => {{
        $crate::assert_close_slice($left, $right, $atol, $rtol);
    }};
    ($left:expr, $right:expr, slice, tol = $tol:expr $(,)?) => {{
        $crate::assert_close_slice($left, $right, $tol, 0.0);
    }};
    ($left:expr, $right:expr, complex_slice, atol = $atol:expr, rtol = $rtol:expr $(,)?) => {{
        $crate::assert_close_complex_slice($left, $right, $atol, $rtol);
    }};
    ($left:expr, $right:expr, complex_slice, tol = $tol:expr $(,)?) => {{
        $crate::assert_close_complex_slice($left, $right, $tol, 0.0);
    }};
    ($left:expr, $right:expr, array, atol = $atol:expr, rtol = $rtol:expr $(,)?) => {{
        $crate::assert_close_array1($left, $right, $atol, $rtol);
    }};
    ($left:expr, $right:expr, array, tol = $tol:expr $(,)?) => {{
        $crate::assert_close_array1($left, $right, $tol, 0.0);
    }};
    ($left:expr, $right:expr, complex_array, atol = $atol:expr, rtol = $rtol:expr $(,)?) => {{
        $crate::assert_close_complex_array1($left, $right, $atol, $rtol);
    }};
    ($left:expr, $right:expr, complex_array, tol = $tol:expr $(,)?) => {{
        $crate::assert_close_complex_array1($left, $right, $tol, 0.0);
    }};
}
139
#[cfg(test)]
mod tests {
    // Only the macro (reachable by textual scope) and these two types are needed, so no
    // `use super::*` glob, which would otherwise be reported as an unused import.
    use ndarray::Array1;
    use num_complex::Complex64;

    #[test]
    fn macro_works() {
        assert_close!(1.0, 1.0 + 1e-9, atol = 1e-8, rtol = 1e-8);
        assert_close!(1.0f32, 1.0f32 + 1e-6, tol = 1e-5);
    }

    #[test]
    fn rtol_scales_with_expected_value() {
        // diff is ~1e-3, below rtol * |expected| = 1e-2, with no absolute slack at all.
        assert_close!(1.0e6, 1.0e6 + 1.0e-3, atol = 0.0, rtol = 1e-8);
    }

    #[test]
    fn slice_macro() {
        let a = [1.0, 2.0, 3.0];
        let b = [1.0 + 1e-9, 2.0 - 1e-9, 3.0];
        assert_close!(&a, &b, slice, atol = 1e-8, rtol = 1e-8);
    }

    #[test]
    fn complex_slice_macro() {
        let a = [Complex64::new(1.0, 2.0), Complex64::new(0.5, -0.5)];
        let b = [Complex64::new(1.0, 2.0 + 1e-9), Complex64::new(0.5, -0.5)];
        assert_close!(&a, &b, complex_slice, tol = 1e-8);
    }

    #[test]
    fn array_macro() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![1.0 + 1e-9, 2.0 - 1e-9, 3.0]);
        assert_close!(&a, &b, array, atol = 1e-8, rtol = 1e-8);
    }

    #[test]
    fn complex_array_macro() {
        let a = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(0.5, -0.5)]);
        let b = Array1::from_vec(vec![
            Complex64::new(1.0, 2.0 + 1e-9),
            Complex64::new(0.5, -0.5),
        ]);
        assert_close!(&a, &b, complex_array, tol = 1e-8);
    }

    #[test]
    #[should_panic(expected = "expected |")]
    fn scalar_out_of_tolerance_panics() {
        assert_close!(1.0, 1.1, tol = 1e-8);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn slice_length_mismatch_panics() {
        let a = [1.0, 2.0];
        let b = [1.0];
        assert_close!(&a, &b, slice, tol = 1e-8);
    }

    #[test]
    #[should_panic(expected = "expected |")]
    fn complex_array_out_of_tolerance_panics() {
        let a = Array1::from_vec(vec![Complex64::new(1.0, 0.0)]);
        let b = Array1::from_vec(vec![Complex64::new(1.0, 1.0)]);
        assert_close!(&a, &b, complex_array, tol = 1e-8);
    }
}
182}