1extern crate num;
7use num::{Float, Zero};
8
9fn float_max<T>(a: T, b: T) -> T
10where
11 T: Float,
12{
13 if a >= b {
14 a
15 } else {
16 b
17 }
18}
19
20fn float_min<T>(a: T, b: T) -> T
21where
22 T: Float,
23{
24 if a >= b {
25 b
26 } else {
27 a
28 }
29}
30
31pub fn nearly_equal<T>(a: T, b: T, rel_tol: T, abs_tol: T) -> bool
58where
59 T: Float + Zero,
60{
61 assert!(rel_tol > T::zero(), "relative tolerance nonpositive");
62 assert!(abs_tol > T::zero(), "absolute tolerance nonpositive");
63
64 let abs_a = a.abs();
65 let abs_b = b.abs();
66 let abs_diff = (a - b).abs();
67
68 if a.is_nan() || b.is_nan() {
69 false
70 } else if a == b || abs_diff <= T::min_positive_value() {
71 true
72 } else {
73 let max_abs_a_b = float_max(abs_a, abs_b);
74 abs_diff <= float_min(abs_tol, rel_tol * max_abs_a_b)
75 }
76}
77
78pub fn assert_nearly_equal<T>(a: T, b: T, rel_tol: T, abs_tol: T, msg: &'static str)
92where
93 T: Float + Zero,
94{
95 assert!(nearly_equal(a, b, rel_tol, abs_tol), "{}", msg);
96}
97
98pub fn nearly_equal_array<T>(a: &[T], b: &[T], rel_tol: T, abs_tol: T) -> bool
120where
121 T: Float + Zero,
122{
123 assert!(a.len() == b.len());
124 for (&a, &b) in a.iter().zip(b.iter()) {
125 if !nearly_equal(a, b, rel_tol, abs_tol) {
126 return false;
127 }
128 }
129 true
130}
131
132pub fn assert_nearly_equal_array<T>(a: &[T], b: &[T], rel_tol: T, abs_tol: T, msg: &'static str)
134where
135 T: Float + Zero,
136{
137 assert!(a.len() == b.len());
138 a.iter()
139 .zip(b.iter())
140 .enumerate()
141 .for_each(|(idx, (&ai, &bi))| {
142 if !nearly_equal(ai, bi, rel_tol, abs_tol) {
143 panic!("({}) arrays not equal at entry {}", msg, idx)
144 }
145 });
146}
147
148pub fn is_any_nan<T>(a: &[T]) -> bool
162where
163 T: Float,
164{
165 for &a in a.iter() {
166 if a.is_nan() {
167 return true;
168 }
169 }
170 false
171}
172
173pub fn assert_none_is_nan<T>(a: &[T], msg: &str)
185where
186 T: Float,
187{
188 for (idx, &a) in a.iter().enumerate() {
189 if a.is_nan() {
190 panic!("({}) nan at poisition {}", msg, idx);
191 }
192 }
193}
194
195pub fn assert_all_ge<T>(a: &[T], lim: T, msg: &str)
209where
210 T: Float + std::fmt::Display,
211{
212 for (idx, &a) in a.iter().enumerate() {
213 if a < lim {
214 panic!("({}) array[{}] = {} is lower than {}", msg, idx, a, lim);
215 }
216 }
217}
218
219pub fn assert_all_le<T>(a: &[T], lim: T, msg: &str)
233where
234 T: Float + std::fmt::Display,
235{
236 for (idx, &a) in a.iter().enumerate() {
237 if a > lim {
238 panic!("({}) array[{}] = {} is greater than {}", msg, idx, a, lim);
239 }
240 }
241}
242
// Unit tests for the comparison and assertion helpers defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // Two positive infinities compare equal, so they count as nearly equal.
    #[test]
    fn infinities() {
        let a = std::f64::INFINITY;
        let b = std::f64::INFINITY;
        assert!(nearly_equal(a, b, 0.1, 0.1));
    }

    // NaN is never nearly equal: neither to another NaN nor to a number.
    #[test]
    fn nans() {
        let a = std::f64::NAN;
        let b = std::f64::NAN;
        let c = 1.0;
        assert!(!nearly_equal(a, b, 0.1, 0.1));
        assert!(!nearly_equal(a, c, 0.1, 0.1));
    }

    // A relative tolerance of zero is rejected with a panic.
    #[test]
    #[should_panic]
    fn no_nonpositive_rel_tol() {
        nearly_equal(5.0, 6.0, 0.0, 1e-7);
    }

    // An absolute tolerance of zero is rejected with a panic.
    #[test]
    #[should_panic]
    fn no_nonpositive_abs_tol() {
        nearly_equal(5.0, 6.0, 0.01, 0.0);
    }

    // The difference (about 1e-5) exceeds both tolerances.
    #[test]
    fn not_nearly_equal() {
        let a = 1e-8;
        let b = 1e-5;
        assert!(!nearly_equal(a, b, 1e-6, 1e-6))
    }

    // The absolute tolerance (0.1) is generous, but the relative tolerance
    // is not met, so the values must still be reported as different.
    #[test]
    fn not_nearly_equal_rel_tol() {
        let a = 1e-14;
        let b = 1e-5;
        assert!(!nearly_equal(a, b, 1e-6, 0.1))
    }

    // Uses the smallest positive normal f64 both as the offset and as the
    // tolerances; the offset is far below the spacing of f64 values near 1.0,
    // so the sum is expected to round back to 1.0.
    #[test]
    fn really_nearly_equal() {
        let a = 1.;
        let b = 1. + std::f64::MIN_POSITIVE;
        assert!(nearly_equal(
            a,
            b,
            std::f64::MIN_POSITIVE,
            std::f64::MIN_POSITIVE
        ))
    }

    // Identical values are nearly equal even with the tiniest tolerances.
    #[test]
    fn absolutely_equal() {
        let a = 5.;
        let b = 5.;
        assert!(nearly_equal(
            a,
            b,
            std::f64::MIN_POSITIVE,
            std::f64::MIN_POSITIVE
        ))
    }

    // The helpers are generic: this exercises `f32`. The difference of 1.0
    // sits exactly on the absolute tolerance.
    #[test]
    fn with_f32() {
        let a = 1000.0_f32;
        let b = 1001.0_f32;
        assert!(nearly_equal(a, b, 0.01, 1.0))
    }

    // `assert_nearly_equal` panics when the values are not nearly equal.
    #[test]
    #[should_panic]
    fn assert_numbers_equal() {
        assert_nearly_equal(1.0, 2.0, 0.01, 0.001, "wtf");
    }

    // Every element differs by less than the absolute tolerance of 1e-5
    // (the last one only just).
    #[test]
    fn arrays_equal() {
        let x = [1.0, 2.0, 3.0];
        let y = [1.0, 2.0 + 1e-7, 3.0 + 9.9999999e-6];
        assert!(nearly_equal_array(&x, &y, 1e-4, 1e-5));
    }

    // The last element differs by 1e-4, which exceeds the absolute tolerance.
    #[test]
    fn arrays_not_equal() {
        let x = [1.0, 2.0, 3.0];
        let y = [1.0, 2.0 + 1e-7, 3.0 + 1e-4];
        assert!(!nearly_equal_array(&x, &y, 1e-4, 1e-5));
    }

    // A slice is nearly equal to itself.
    #[test]
    fn arrays_identical() {
        let x = [1.0, 2.0, 3.0];
        assert!(nearly_equal_array(&x, &x, 1e-4, 1e-5));
    }

    // `assert_nearly_equal_array` panics on an element mismatch.
    #[test]
    #[should_panic]
    fn assert_arrays_not_equal() {
        let x = [1.0, 2.0, 3.0];
        let y = [1.0, 2.0 + 1e-7, 3.0 + 1e-4];
        assert_nearly_equal_array(&x, &y, 1e-4, 1e-5, "arrays not equal");
    }

    // `assert_nearly_equal_array` panics when the slice lengths differ.
    #[test]
    #[should_panic]
    fn assert_arrays_different_lens() {
        let x = [1.0, 2.0, 3.0];
        let y = [1.0, 2.0 + 1e-7];
        assert_nearly_equal_array(&x, &y, 1e-4, 1e-5, "arrays not equal");
    }

    // `is_any_nan` is false for NaN-free input and true once a NaN is present.
    #[test]
    fn any_is_nan() {
        let x: [f64; 2] = [0.0, 1.0];
        assert!(!is_any_nan(&x));

        let y: [f64; 3] = [0.0, std::f64::NAN, 1.0];
        assert!(is_any_nan(&y));
    }

    // `assert_none_is_nan` panics when the slice contains a NaN.
    #[test]
    #[should_panic]
    fn none_is_none_panic() {
        let y: [f64; 3] = [0.0, std::f64::NAN, 1.0];
        assert_none_is_nan(&y, "y");
    }

    // An element equal to the limit is accepted: the check is strict `<`.
    #[test]
    fn assert_all_positive() {
        let y = [0.0, 1e-10, 1e-16];
        assert_all_ge(&y, 0., "y");
    }

    // A single slightly negative element makes `assert_all_ge` panic.
    #[test]
    #[should_panic]
    fn assert_all_positive_panic() {
        let y = [0.0, 1e-10, -1e-12, 10.0];
        assert_all_ge(&y, 0., "y");
    }

    // An element equal to the limit is accepted: the check is strict `>`.
    #[test]
    fn assert_all_le_one_f32() {
        let y = [0.0_f32, 1.0, 0.5, -100.0];
        assert_all_le(&y, 1.0, "y");
    }

    // `1.0 + 4e-16` is intended to be just above 1.0, so it must trip the
    // upper-bound check.
    #[test]
    #[should_panic]
    fn assert_all_le_one_panic() {
        let y = [0.0, 1.0, 1.0 + 4e-16, -100.0];
        assert_all_le(&y, 1.0, "y");
    }
}