unit_test_utils/
lib.rs

//! Collection of utilities for unit tests
//!
//! This crate offers tools for unit tests, especially for projects involving
//! numerical methods.
//!
6extern crate num;
7use num::{Float, Zero};
8
9fn float_max<T>(a: T, b: T) -> T
10where
11    T: Float,
12{
13    if a >= b {
14        a
15    } else {
16        b
17    }
18}
19
20fn float_min<T>(a: T, b: T) -> T
21where
22    T: Float,
23{
24    if a >= b {
25        b
26    } else {
27        a
28    }
29}
30
31/// Whether two floats are nearly equal (up to specified tolerance)
32///
33/// ## Arguments
34/// - `a` first float
35/// - `b` second float
36/// - `rel_tol` relative tolerance (must be positive)
37/// - `abs_tol` absolute tolerance (must be positive)
38///
39/// ## Results
40///
41/// Returns true if and only if `a` is nearly equal to `b`
42///
43/// In particular, this function will return true if and only if BOTH of the following
44/// conditions are satisfied
45/// - `a==b`, e.g., if the two floats are identical or both equal to infinity
46/// - `|a-b| <= max(abs_tol, rel_tol*max(|a|, |b|))`
47///
48/// The function will return false if either of `a` or `b` is NaN.
49///
50/// It works with `f64` and `f32`
51///
52/// ## Panics
53///
54/// The function will panic if the specified relative or absolute tolerance is
55/// not positive.
56///
57pub fn nearly_equal<T>(a: T, b: T, rel_tol: T, abs_tol: T) -> bool
58where
59    T: Float + Zero,
60{
61    assert!(rel_tol > T::zero(), "relative tolerance nonpositive");
62    assert!(abs_tol > T::zero(), "absolute tolerance nonpositive");
63
64    let abs_a = a.abs();
65    let abs_b = b.abs();
66    let abs_diff = (a - b).abs();
67
68    if a.is_nan() || b.is_nan() {
69        false
70    } else if a == b || abs_diff <= T::min_positive_value() {
71        true
72    } else {
73        let max_abs_a_b = float_max(abs_a, abs_b);
74        abs_diff <= float_min(abs_tol, rel_tol * max_abs_a_b)
75    }
76}
77
78/// Asserts that two numbers are nearly equal
79///
80/// ## Arguments
81/// - `a` first float
82/// - `b` second float
83/// - `rel_tol` relative tolerance (must be positive)
84/// - `abs_tol` absolute tolerance (must be positive)
85/// - `msg` an error message that will be thrown if the two numbers are not nearly equal
86///
87/// ## Panics
88///
89/// The function panics if the two floating-point numbers are not almost equal to one
90/// another up to the specified tolerances
91pub fn assert_nearly_equal<T>(a: T, b: T, rel_tol: T, abs_tol: T, msg: &'static str)
92where
93    T: Float + Zero,
94{
95    assert!(nearly_equal(a, b, rel_tol, abs_tol), "{}", msg);
96}
97
98/// Checks whether two arrays are element-wise nearly equal
99///
100/// ## Arguments
101///
102/// - `a` first array
103/// - `b` second array
104/// - `rel_tol` relative tolerance
105/// - `abs_tol` absolute tolerance
106///
107/// ## Returns
108///
109/// The function returns true if and only if the application of `nearly_equal`
110/// on all elements of the two arrays returns true, i.e., if the two arrays
111/// are element-wise almost equal
112///
113/// ## Panics
114///
115/// The function will panic in the following cases:
116/// - if the specified relative or absolute tolerance is not positive and
117/// - if the two arrays have different lengths
118///
119pub fn nearly_equal_array<T>(a: &[T], b: &[T], rel_tol: T, abs_tol: T) -> bool
120where
121    T: Float + Zero,
122{
123    assert!(a.len() == b.len());
124    for (&a, &b) in a.iter().zip(b.iter()) {
125        if !nearly_equal(a, b, rel_tol, abs_tol) {
126            return false;
127        }
128    }
129    true
130}
131
132/// Asserts that two given arrays are almost equal
133pub fn assert_nearly_equal_array<T>(a: &[T], b: &[T], rel_tol: T, abs_tol: T, msg: &'static str)
134where
135    T: Float + Zero,
136{
137    assert!(a.len() == b.len());
138    a.iter()
139        .zip(b.iter())
140        .enumerate()
141        .for_each(|(idx, (&ai, &bi))| {
142            if !nearly_equal(ai, bi, rel_tol, abs_tol) {
143                panic!("({}) arrays not equal at entry {}", msg, idx)
144            }
145        });
146}
147
148/// Checks whether a given array contains any `NaN` elements
149///
150/// ## Arguments
151///
152/// - `a` an array of floating-point numbers
153///
154/// ## Returns
155///
156/// Returns `true` if and only if there is at least one element which is `NaN`
157///
158/// ## Panics
159///
160/// No panics
161pub fn is_any_nan<T>(a: &[T]) -> bool
162where
163    T: Float,
164{
165    for &a in a.iter() {
166        if a.is_nan() {
167            return true;
168        }
169    }
170    false
171}
172
173/// Asserts that no element of an array is `NaN`
174///
175/// ## Arguments
176///
177/// - `a` an array of floating-point numbers
178/// - `msg` error name
179///
180/// ## Panics
181///
182/// This function will panic if any element of the given array is `NaN`
183///
184pub fn assert_none_is_nan<T>(a: &[T], msg: &str)
185where
186    T: Float,
187{
188    for (idx, &a) in a.iter().enumerate() {
189        if a.is_nan() {
190            panic!("({}) nan at poisition {}", msg, idx);
191        }
192    }
193}
194
195/// Asserts that all elements in an array are greater than or equal a given value
196///
197/// ## Arguments
198///
199/// - `a` given array of floating-point numbers
200/// - `lim` the lower bound on the array; all elements must be greater than or equal
201///    to `lim`, otherwise the function panics
202/// - `msg` error message
203///
204/// ## Panics
205///
206/// The function panic if there is at least on element in `a` which is smaller than `lim`
207///
208pub fn assert_all_ge<T>(a: &[T], lim: T, msg: &str)
209where
210    T: Float + std::fmt::Display,
211{
212    for (idx, &a) in a.iter().enumerate() {
213        if a < lim {
214            panic!("({}) array[{}] = {} is lower than {}", msg, idx, a, lim);
215        }
216    }
217}
218
219/// Asserts that all elements in an array are less than or equal a given value
220///
221/// ## Arguments
222///
223/// - `a` given array of floating-point numbers
224/// - `lim` the upper bound on the array; all elements must be less than or equal
225///    to `lim`, otherwise the function panics
226/// - `msg` error message
227///
228/// ## Panics
229///
230/// The function panic if there is at least on element in `a` which is greater than `lim`
231///
232pub fn assert_all_le<T>(a: &[T], lim: T, msg: &str)
233where
234    T: Float + std::fmt::Display,
235{
236    for (idx, &a) in a.iter().enumerate() {
237        if a > lim {
238            panic!("({}) array[{}] = {} is greater than {}", msg, idx, a, lim);
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn infinities() {
        let a = f64::INFINITY;
        let b = f64::INFINITY;
        assert!(nearly_equal(a, b, 0.1, 0.1));
    }

    #[test]
    fn opposite_infinities() {
        assert!(!nearly_equal(f64::INFINITY, f64::NEG_INFINITY, 0.1, 0.1));
    }

    #[test]
    fn nans() {
        let a = f64::NAN;
        let b = f64::NAN;
        let c = 1.0;
        assert!(!nearly_equal(a, b, 0.1, 0.1));
        assert!(!nearly_equal(a, c, 0.1, 0.1));
    }

    #[test]
    #[should_panic(expected = "relative tolerance nonpositive")]
    fn no_nonpositive_rel_tol() {
        nearly_equal(5.0, 6.0, 0.0, 1e-7);
    }

    #[test]
    #[should_panic(expected = "absolute tolerance nonpositive")]
    fn no_nonpositive_abs_tol() {
        nearly_equal(5.0, 6.0, 0.01, 0.0);
    }

    #[test]
    fn not_nearly_equal() {
        let a = 1e-8;
        let b = 1e-5;
        assert!(!nearly_equal(a, b, 1e-6, 1e-6))
    }

    #[test]
    fn not_nearly_equal_rel_tol() {
        // within the absolute tolerance but not the relative one: both must hold
        let a = 1e-14;
        let b = 1e-5;
        assert!(!nearly_equal(a, b, 1e-6, 0.1))
    }

    #[test]
    fn really_nearly_equal() {
        // `1 + MIN_POSITIVE` rounds to exactly `1` in f64
        let a = 1.;
        let b = 1. + f64::MIN_POSITIVE;
        assert!(nearly_equal(a, b, f64::MIN_POSITIVE, f64::MIN_POSITIVE))
    }

    #[test]
    fn absolutely_equal() {
        let a = 5.;
        let b = 5.;
        assert!(nearly_equal(a, b, f64::MIN_POSITIVE, f64::MIN_POSITIVE))
    }

    #[test]
    fn with_f32() {
        let a = 1000.0_f32;
        let b = 1001.0_f32;
        assert!(nearly_equal(a, b, 0.01, 1.0))
    }

    #[test]
    #[should_panic(expected = "wtf")]
    fn assert_numbers_equal() {
        assert_nearly_equal(1.0, 2.0, 0.01, 0.001, "wtf");
    }

    #[test]
    fn arrays_equal() {
        let x = [1.0, 2.0, 3.0];
        let y = [1.0, 2.0 + 1e-7, 3.0 + 9.9999999e-6];
        assert!(nearly_equal_array(&x, &y, 1e-4, 1e-5));
    }

    #[test]
    fn arrays_not_equal() {
        let x = [1.0, 2.0, 3.0];
        let y = [1.0, 2.0 + 1e-7, 3.0 + 1e-4];
        assert!(!nearly_equal_array(&x, &y, 1e-4, 1e-5));
    }

    #[test]
    fn arrays_identical() {
        let x = [1.0, 2.0, 3.0];
        assert!(nearly_equal_array(&x, &x, 1e-4, 1e-5));
    }

    #[test]
    #[should_panic]
    fn arrays_different_lens() {
        let x = [1.0, 2.0, 3.0];
        let y = [1.0, 2.0];
        nearly_equal_array(&x, &y, 1e-4, 1e-5);
    }

    #[test]
    #[should_panic(expected = "arrays not equal at entry 2")]
    fn assert_arrays_not_equal() {
        let x = [1.0, 2.0, 3.0];
        let y = [1.0, 2.0 + 1e-7, 3.0 + 1e-4];
        assert_nearly_equal_array(&x, &y, 1e-4, 1e-5, "arrays not equal");
    }

    #[test]
    #[should_panic]
    fn assert_arrays_different_lens() {
        let x = [1.0, 2.0, 3.0];
        let y = [1.0, 2.0 + 1e-7];
        assert_nearly_equal_array(&x, &y, 1e-4, 1e-5, "arrays not equal");
    }

    #[test]
    fn any_is_nan() {
        let x: [f64; 2] = [0.0, 1.0];
        assert!(!is_any_nan(&x));

        let y: [f64; 3] = [0.0, f64::NAN, 1.0];
        assert!(is_any_nan(&y));
    }

    #[test]
    #[should_panic(expected = "(y) nan at")]
    fn none_is_nan_panic() {
        let y: [f64; 3] = [0.0, f64::NAN, 1.0];
        assert_none_is_nan(&y, "y");
    }

    #[test]
    fn assert_all_positive() {
        let y = [0.0, 1e-10, 1e-16];
        assert_all_ge(&y, 0., "y");
    }

    #[test]
    #[should_panic(expected = "(y) array[2]")]
    fn assert_all_positive_panic() {
        let y = [0.0, 1e-10, -1e-12, 10.0];
        assert_all_ge(&y, 0., "y");
    }

    #[test]
    fn assert_all_le_one_f32() {
        let y = [0.0_f32, 1.0, 0.5, -100.0];
        assert_all_le(&y, 1.0, "y");
    }

    #[test]
    #[should_panic(expected = "(y) array[2]")]
    fn assert_all_le_one_panic() {
        let y = [0.0, 1.0, 1.0 + 4e-16, -100.0];
        assert_all_le(&y, 1.0, "y");
    }
}