spl_math/
approximations.rs1#![allow(clippy::arithmetic_side_effects)]
2use {
5 num_traits::{CheckedShl, CheckedShr, PrimInt},
6 std::cmp::Ordering,
7};
8
9pub fn sqrt<T: PrimInt + CheckedShl + CheckedShr>(radicand: T) -> Option<T> {
19 match radicand.cmp(&T::zero()) {
20 Ordering::Less => return None, Ordering::Equal => return Some(T::zero()), _ => {}
23 }
24
25 let max_shift: u32 = T::zero().leading_zeros() - 1;
27 let shift: u32 = (max_shift - radicand.leading_zeros()) & !1;
28 let mut bit = T::one().checked_shl(shift)?;
29
30 let mut n = radicand;
31 let mut result = T::zero();
32 while bit != T::zero() {
33 let result_with_bit = result.checked_add(&bit)?;
34 if n >= result_with_bit {
35 n = n.checked_sub(&result_with_bit)?;
36 result = result.checked_shr(1)?.checked_add(&bit)?;
37 } else {
38 result = result.checked_shr(1)?;
39 }
40 bit = bit.checked_shr(2)?;
41 }
42 Some(result)
43}
44
/// Approximates the standard normal cumulative distribution function,
/// `Φ(argument) = P(Z <= argument)` for `Z ~ N(0, 1)`.
///
/// For `x = |argument|`, the upper tail is estimated in closed form as
///
/// `1 - Φ(x) ≈ φ(x) / (0.226 + 0.64·x + 0.33·sqrt(x² + 3))`
///
/// where `φ(x) = exp(-x²/2) / sqrt(2π)` is the standard normal density.
/// Negative arguments use the symmetry `Φ(-x) = 1 - Φ(x)`.
///
/// Very large magnitudes and infinities saturate to `1.0` (positive) or
/// `0.0` (negative), because `x²` overflows to infinity. A NaN input
/// yields NaN.
#[inline(never)]
pub fn f32_normal_cdf(argument: f32) -> f32 {
    let is_negative = argument < 0.0;
    let x = if is_negative { -argument } else { argument };
    let x_squared = x * x;

    // φ(x): the standard normal density at x.
    let density = (1.0 / (2.0 * std::f32::consts::PI).sqrt()) * (-x_squared / 2.0).exp();
    let tail_divisor = 0.226 + 0.64 * x + 0.33 * (x_squared + 3.0).sqrt();
    let cdf_at_x = 1.0 - density / tail_divisor;

    // Reflect through the symmetry of the distribution for negative inputs.
    if is_negative {
        1.0 - cdf_at_x
    } else {
        cdf_at_x
    }
}
74
#[cfg(test)]
mod tests {
    use {super::*, proptest::prelude::*};

    /// Asserts that `sqrt(radicand)` is exactly the integer square root,
    /// i.e. `root² <= radicand < (root + 1)²`.
    fn check_square_root(radicand: u128) {
        let root = sqrt(radicand).unwrap();
        // The root of a u128 is at most u64::MAX, so its square always fits.
        let root_squared = root.checked_pow(2).unwrap();
        assert!(root_squared <= radicand);
        // If `(root + 1)²` does not fit in a u128 it necessarily exceeds
        // `radicand`, so the strict upper bound only needs checking when it fits.
        if let Some(next_squared) = root.checked_add(1).and_then(|r| r.checked_pow(2)) {
            assert!(radicand < next_squared);
        }
    }

    #[test]
    fn test_square_root_min_max() {
        let test_roots = [0, u64::MAX as u128, u128::MAX];
        for radicand in test_roots.iter() {
            check_square_root(*radicand);
        }
    }

    #[test]
    fn test_square_root_exact_values() {
        assert_eq!(sqrt(0u64), Some(0));
        assert_eq!(sqrt(1u64), Some(1));
        assert_eq!(sqrt(15u64), Some(3));
        assert_eq!(sqrt(16u64), Some(4));
        assert_eq!(sqrt(17u64), Some(4));
        assert_eq!(sqrt(255u8), Some(15));
        assert_eq!(sqrt(i64::MAX), Some(3_037_000_499));
    }

    #[test]
    fn test_square_root_negative() {
        assert_eq!(sqrt(-1i64), None);
        assert_eq!(sqrt(i32::MIN), None);
    }

    proptest! {
        #[test]
        fn test_square_root(a in 0..u64::MAX) {
            check_square_root(a as u128);
        }

        #[test]
        fn test_square_root_u128(a in any::<u128>()) {
            check_square_root(a);
        }
    }

    /// Compares the approximation with the reference
    /// `Φ(x) = (1 + erf(x / √2)) / 2` computed via `libm`, allowing an
    /// absolute error of at most 2e-4.
    fn check_normal_cdf_f32(argument: f32) {
        let result = f32_normal_cdf(argument);
        let check_result = 0.5 * (1.0 + libm::erff(argument / std::f32::consts::SQRT_2));
        let abs_difference: f32 = (result - check_result).abs();
        assert!(abs_difference <= 0.000_2);
    }

    #[test]
    fn test_normal_cdf_f32_min_max() {
        let test_arguments: [f32; 2] = [f32::MIN, f32::MAX];
        for argument in test_arguments.iter() {
            check_normal_cdf_f32(*argument)
        }
    }

    proptest! {
        #[test]
        fn test_normal_cdf(a in -1000..1000) {
            // Sample arguments on a 0.005 grid covering [-5, 5).
            check_normal_cdf_f32((a as f32) * 0.005);
        }
    }
}