spl_math/
approximations.rs

1#![allow(clippy::arithmetic_side_effects)]
2//! Approximation calculations
3
4use {
5    num_traits::{CheckedShl, CheckedShr, PrimInt},
6    std::cmp::Ordering,
7};
8
9/// Calculate square root of the given number
10///
11/// Code lovingly adapted from the excellent work at:
12///
13/// <https://github.com/derekdreery/integer-sqrt-rs>
14///
15/// The algorithm is based on the implementation in:
16///
17/// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
18pub fn sqrt<T: PrimInt + CheckedShl + CheckedShr>(radicand: T) -> Option<T> {
19    match radicand.cmp(&T::zero()) {
20        Ordering::Less => return None,             // fail for less than 0
21        Ordering::Equal => return Some(T::zero()), // do nothing for 0
22        _ => {}
23    }
24
25    // Compute bit, the largest power of 4 <= n
26    let max_shift: u32 = T::zero().leading_zeros() - 1;
27    let shift: u32 = (max_shift - radicand.leading_zeros()) & !1;
28    let mut bit = T::one().checked_shl(shift)?;
29
30    let mut n = radicand;
31    let mut result = T::zero();
32    while bit != T::zero() {
33        let result_with_bit = result.checked_add(&bit)?;
34        if n >= result_with_bit {
35            n = n.checked_sub(&result_with_bit)?;
36            result = result.checked_shr(1)?.checked_add(&bit)?;
37        } else {
38            result = result.checked_shr(1)?;
39        }
40        bit = bit.checked_shr(2)?;
41    }
42    Some(result)
43}
44
/// Approximate the standard normal cumulative distribution function Φ(`argument`).
///
/// The closed-form tabulation from the paper below is evaluated at
/// `|argument|`. Negative inputs are then reflected using Φ(-x) = 1 - Φ(x).
/// The approximation is accurate to about 3 digits. A NaN input yields NaN.
///
/// Code lovingly adapted from the excellent work at:
///
/// <https://www.hrpub.org/download/20140305/MS7-13401470.pdf>
///
/// The algorithm is based on the implementation in the paper above.
#[inline(never)]
pub fn f32_normal_cdf(argument: f32) -> f32 {
    // 1 / sqrt(2π), the normalising constant of the standard normal density.
    let inv_sqrt_two_pi: f32 = 1.0 / (2.0 * std::f32::consts::PI).sqrt();

    let x: f32 = argument.abs();
    let x_squared: f32 = x * x;

    // Standard normal density evaluated at |argument|.
    let density: f32 = inv_sqrt_two_pi * (-x_squared / 2.0).exp();
    // Divisor from the paper's tabulation formula.
    let scale: f32 = 0.226 + 0.64 * x + 0.33 * (x_squared + 3.0).sqrt();
    let upper_half: f32 = 1.0 - density / scale;

    // Mirror across the mean for negative arguments.
    match argument < 0.0 {
        true => 1.0 - upper_half,
        false => upper_half,
    }
}
74
#[cfg(test)]
mod tests {
    use {super::*, proptest::prelude::*};

    /// Assert that `sqrt(radicand)` is exactly the floor square root, i.e.
    /// `root^2 <= radicand < (root + 1)^2`.
    ///
    /// The previous check accepted anything between `(root - 1)^2` and
    /// `(root + 1)^2`, so off-by-one results went undetected.
    fn check_square_root(radicand: u128) {
        let root = sqrt(radicand).unwrap();
        let square = root.checked_pow(2).unwrap();
        assert!(
            square <= radicand,
            "sqrt({}) = {} is too large",
            radicand,
            root
        );
        // `(root + 1)^2` only overflows `u128` when `root == u64::MAX`. In that
        // case it exceeds every representable radicand, so there is nothing to check.
        if let Some(next_square) = root.checked_add(1).and_then(|r| r.checked_pow(2)) {
            assert!(
                radicand < next_square,
                "sqrt({}) = {} is too small",
                radicand,
                root
            );
        }
    }

    #[test]
    fn test_square_root_min_max() {
        let test_roots = [0, 1, u64::MAX as u128, u128::MAX];
        for i in test_roots.iter() {
            check_square_root(*i);
        }
    }

    #[test]
    fn test_square_root_perfect_square_boundaries() {
        let roots = [1u128, 2, 3, 1_000, u32::MAX as u128, u64::MAX as u128];
        for root in roots.iter() {
            let square = root * root;
            assert_eq!(sqrt(square), Some(*root));
            assert_eq!(sqrt(square - 1), Some(*root - 1));
        }
    }

    #[test]
    fn test_square_root_signed() {
        assert_eq!(sqrt(-1i64), None);
        assert_eq!(sqrt(i64::MIN), None);
        assert_eq!(sqrt(0i64), Some(0));
        assert_eq!(sqrt(i32::MAX), Some(46_340));
        assert_eq!(sqrt(i64::MAX), Some(3_037_000_499));
    }

    proptest! {
        #[test]
        fn test_square_root(a in 0..u64::MAX) {
            check_square_root(a as u128);
        }

        #[test]
        fn test_square_root_full_range(a in any::<u128>()) {
            check_square_root(a);
        }
    }

    fn check_normal_cdf_f32(argument: f32) {
        let result = f32_normal_cdf(argument);
        let check_result = 0.5 * (1.0 + libm::erff(argument / std::f32::consts::SQRT_2));
        let abs_difference: f32 = (result - check_result).abs();
        assert!(abs_difference <= 0.000_2);
    }

    #[test]
    fn test_normal_cdf_f32_min_max() {
        let test_arguments: [f32; 2] = [f32::MIN, f32::MAX];
        for i in test_arguments.iter() {
            check_normal_cdf_f32(*i)
        }
    }

    proptest! {
        #[test]
        fn test_normal_cdf(a in -1000..1000) {
            check_normal_cdf_f32((a as f32) * 0.005);
        }
    }
}