Skip to main content

vector_ta/utilities/
math_functions.rs

1#![allow(clippy::many_single_char_names)]
2
3use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, LN_2, PI};
4
/// Fast arctangent approximation, accurate to roughly 1.5e-3 rad everywhere.
///
/// On `[-1, 1]` it evaluates
/// `atan(z) ≈ π/4·z + z·(1 − |z|)·(0.2447 + 0.0663·|z|)`.
/// Outside that interval it uses `atan(z) = ±π/2 − atan(1/z)`, so the
/// polynomial is only ever evaluated on `[-1, 1]`.
///
/// `±∞` map to `±π/2`, `±1` map exactly to `±π/4`, and `NaN` propagates.
#[inline(always)]
pub fn atan_fast(z: f64) -> f64 {
    // Polynomial kernel, valid for |t| <= 1. The correction term
    // t·(1 − |t|)·(C0 + C1·|t|) vanishes at t = 0 and |t| = 1, where π/4·t is
    // already exact; it must be *multiplied* by the C0 + C1·|t| factor.
    #[inline(always)]
    fn kernel(t: f64) -> f64 {
        const C0: f64 = 0.2447;
        const C1: f64 = 0.0663;
        let a = t.abs();
        let poly = C1.mul_add(a, C0);
        (t * (1.0 - a)).mul_add(poly, FRAC_PI_4 * t)
    }

    if z.abs() <= 1.0 {
        kernel(z)
    } else {
        // |1/z| < 1 here (1/z is ±0 for ±∞), so the kernel stays in range.
        let base = kernel(1.0 / z);
        if z.is_sign_positive() {
            FRAC_PI_2 - base
        } else {
            -FRAC_PI_2 - base
        }
    }
}
27
/// Returns `val` when `x` has a clear sign bit and `-val` when it is set.
///
/// Only the sign bit of `x` is consulted, so `-0.0` counts as negative.
#[inline(always)]
fn flip_sign_nonnan(x: f64, val: f64) -> f64 {
    if x.is_sign_positive() {
        val
    } else {
        -val
    }
}
36
/// Raw arctangent approximation `atan(x) ≈ x·(π/4 + 0.273·(1 − |x|))`.
///
/// Intended for `|x| <= 1`; callers fold larger arguments themselves.
#[inline(always)]
pub fn atan_raw64(x: f64) -> f64 {
    const N2: f64 = 0.273;
    let slope = FRAC_PI_4 + N2 - N2 * x.abs();
    x * slope
}
42
43#[inline(always)]
44pub fn atan64(x: f64) -> f64 {
45    if x.abs() > 1.0 {
46        debug_assert!(!x.is_nan());
47        flip_sign_nonnan(x, FRAC_PI_2) - atan_raw64(1.0 / x)
48    } else {
49        atan_raw64(x)
50    }
51}
52
/// One full turn, used to reduce trigonometric arguments.
const TWO_PI: f64 = 2.0 * PI;

/// Shifts `x` by at most one period (`±2π`). For inputs in `(-3π, 3π)` the
/// result lies in `[-π, π]`.
#[inline(always)]
fn fold_once_to_pi(x: f64) -> f64 {
    if x < -PI {
        x + TWO_PI
    } else if x > PI {
        x - TWO_PI
    } else {
        x
    }
}

/// Parabolic sine kernel for `x` in `[-π, π]`.
///
/// It first computes the parabola `y = 4/π·x − 4/π²·x·|x|`, then refines it
/// as `y·(Q + (1 − Q)·|y|)`.
#[inline(always)]
fn sin_kernel_f64(x: f64) -> f64 {
    const FOUROVERPI: f64 = 1.2732395447351627;
    const FOUROVERPISQ: f64 = 0.405_284_734_569_351_1;
    const Q: f64 = 0.776_330_232_480_075;

    let ax = x.abs();
    // Evaluate on |x|, then restore the sign (sine is odd).
    let mut y = FOUROVERPI * ax - FOUROVERPISQ * ax * ax;
    if x < 0.0 {
        y = -y;
    }
    y * (Q + (1.0 - Q) * y.abs())
}

/// Fast sine approximation with absolute error on the order of 1e-3.
///
/// The argument is reduced by `% 2π` and folded into `[-π, π]`. Infinite or
/// NaN inputs yield NaN. Accuracy drifts for very large `|x|` because `2π` is
/// only an `f64` approximation.
#[inline(always)]
pub fn fast_sin_f64(x: f64) -> f64 {
    sin_kernel_f64(fold_once_to_pi(x % TWO_PI))
}

/// Fast cosine approximation via `cos(x) = sin(x + π/2)`.
///
/// It shares the sine kernel and range reduction, so its absolute error is
/// also on the order of 1e-3. Infinite or NaN inputs yield NaN.
#[inline(always)]
pub fn fast_cos_f64(x: f64) -> f64 {
    let reduced = fold_once_to_pi(x % TWO_PI);
    // reduced + π/2 lies in [-π/2, 3π/2]; one more fold brings it to [-π, π].
    sin_kernel_f64(fold_once_to_pi(reduced + FRAC_PI_2))
}
109
/// Returns the raw IEEE-754 bit pattern of `value`.
#[inline(always)]
fn to_bits_f64(value: f64) -> u64 {
    f64::to_bits(value)
}
/// Reinterprets a raw IEEE-754 bit pattern as an `f64`.
#[inline(always)]
fn from_bits_f64(bits: u64) -> f64 {
    <f64>::from_bits(bits)
}
118
119#[inline]
120pub fn log2_approx_f64(x: f64) -> f64 {
121    let mut y = to_bits_f64(x) as f64;
122    y *= 2.220446049250313e-16;
123    y - 1022.94269504
124}
125
126#[inline]
127pub fn ln_approx_f64(x: f64) -> f64 {
128    log2_approx_f64(x) * LN_2
129}
130
131#[inline]
132pub fn pow2_approx_f64(p: f64) -> f64 {
133    let clipp = if p < -1022.0 { -1022.0 } else { p };
134    const POW2_OFFSET: f64 = 1022.942695;
135    let v = ((1u64 << 52) as f64 * (clipp + POW2_OFFSET)) as u64;
136    from_bits_f64(v)
137}
138
139#[inline]
140pub fn pow_approx_f64(x: f64, p: f64) -> f64 {
141    pow2_approx_f64(p * log2_approx_f64(x))
142}
143
144#[inline]
145pub fn exp_approx_f64(p: f64) -> f64 {
146    const INV_LN2: f64 = std::f64::consts::LOG2_E;
147    pow2_approx_f64(INV_LN2 * p)
148}
149
150#[inline]
151pub fn sigmoid_approx_f64(x: f64) -> f64 {
152    1.0 / (1.0 + exp_approx_f64(-x))
153}
154
155#[inline]
156pub fn lambertw_approx_f64(x: f64) -> f64 {
157    if x == 0.0 {
158        return 0.0;
159    }
160
161    let mut w = if x < 1.0 {
162        x
163    } else {
164        let g = ln_approx_f64(x).max(0.0);
165        if g < 0.5 {
166            0.5
167        } else {
168            g
169        }
170    };
171
172    for _ in 0..2 {
173        let ew = exp_approx_f64(w);
174        let f = w * ew - x;
175        let fp = ew * (w + 1.0);
176        w -= f / fp;
177    }
178    w
179}
180
181#[inline]
182pub fn lambertwexpx_approx_f64(v: f64) -> f64 {
183    let mut y = 1.0_f64 + v.abs();
184    for _ in 0..5 {
185        let w = lambertw_approx_f64(y);
186        y = w * exp_approx_f64(w);
187    }
188    y
189}
190
191#[inline]
192pub fn ln_gamma_approx_f64(x: f64) -> f64 {
193    -0.0810614667_f64 - x - ln_approx_f64(x) + (0.5_f64 + x) * ln_approx_f64(1.0_f64 + x)
194}
195
196#[inline]
197pub fn digamma_approx_f64(x: f64) -> f64 {
198    let onepx = 1.0 + x;
199    -1.0 / x - 1.0 / (2.0 * onepx) + ln_approx_f64(onepx)
200}
201
202#[inline]
203pub fn erfc_approx_f64(x: f64) -> f64 {
204    const K: f64 = 3.3509633149424609;
205    2.0 / (1.0 + pow2_approx_f64(K * x))
206}
207
208#[inline]
209pub fn erf_approx_f64(x: f64) -> f64 {
210    1.0 - erfc_approx_f64(x)
211}
212
213#[inline]
214pub fn erf_inv_approx_f64(x: f64) -> f64 {
215    const INVK: f64 = 0.30004578719350504;
216    let ratio = (1.0 + x) / (1.0 - x);
217    INVK * log2_approx_f64(ratio)
218}
219
220#[inline]
221pub fn sinh_approx_f64(x: f64) -> f64 {
222    0.5 * (exp_approx_f64(x) - exp_approx_f64(-x))
223}
224
225#[inline]
226pub fn cosh_approx_f64(x: f64) -> f64 {
227    0.5 * (exp_approx_f64(x) + exp_approx_f64(-x))
228}
229
230#[inline]
231pub fn tanh_approx_f64(x: f64) -> f64 {
232    -1.0 + 2.0 / (1.0 + exp_approx_f64(-2.0 * x))
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Absolute-tolerance comparison: `true` iff `|a - b| < tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        let gap = (a - b).abs();
        gap < tol
    }

    /// Sine and cosine against `std`, at quarter-turn points and at ±10
    /// (outside `[-π, π]`, so range reduction is exercised).
    #[test]
    fn test_fast_sin_cos() {
        let angles: [f64; 9] = [
            0.0,
            PI * 0.25,
            PI * 0.5,
            PI * 0.75,
            PI,
            -PI * 0.5,
            -PI,
            10.0,
            -10.0,
        ];
        for ang in angles {
            let (fs, rs) = (fast_sin_f64(ang), ang.sin());
            let (fc, rc) = (fast_cos_f64(ang), ang.cos());
            assert!(
                approx_eq(fs, rs, 0.05),
                "fast_sin_f64({ang}) => {fs} vs std => {rs}"
            );
            assert!(
                approx_eq(fc, rc, 0.05),
                "fast_cos_f64({ang}) => {fc} vs std => {rc}"
            );
        }
    }

    /// `atan64` inside and outside the unit interval, both signs.
    #[test]
    fn test_atan_approx() {
        let vals: [f64; 6] = [0.0, 0.5, 1.0, 2.0, -1.0, -10.0];
        for v in vals {
            let (app, real) = (atan64(v), v.atan());
            assert!(
                approx_eq(app, real, 0.1),
                "atan64({v}) => {app}, real => {real}"
            );
        }
    }

    /// Bit-trick `log2` on values spanning several binades.
    #[test]
    fn test_log2_approx() {
        let vals: [f64; 6] = [0.125, 0.5, 1.0, 2.0, 8.0, 10.0];
        for v in vals {
            let (app, real) = (log2_approx_f64(v), v.log2());
            assert!(
                approx_eq(app, real, 0.15),
                "log2_approx_f64({v}) => {app}, real => {real}"
            );
        }
    }

    /// Natural log derived from the `log2` approximation.
    #[test]
    fn test_ln_approx() {
        let vals: [f64; 6] = [0.125, 0.5, 1.0, 2.0, 8.0, 10.0];
        for v in vals {
            let (app, real) = (ln_approx_f64(v), v.ln());
            assert!(
                approx_eq(app, real, 0.2),
                "ln_approx_f64({v}) => {app}, real => {real}"
            );
        }
    }

    /// `exp`, with a relative tolerance once the magnitude exceeds 1.
    #[test]
    fn test_exp_approx() {
        let vals: [f64; 6] = [-2.0, -1.0, 0.0, 1.0, 2.0, 5.0];
        for v in vals {
            let (app, real) = (exp_approx_f64(v), v.exp());
            let tol = 0.15 * real.abs().max(1.0);
            assert!(
                approx_eq(app, real, tol),
                "exp_approx_f64({v}) => {app}, real => {real}"
            );
        }
    }

    /// `2^p`, including a fractional exponent.
    #[test]
    fn test_pow2_approx() {
        let vals: [f64; 6] = [-10.0, -1.0, 0.0, 1.0, 10.0, 15.5];
        for v in vals {
            let (app, real) = (pow2_approx_f64(v), 2.0_f64.powf(v));
            let tol = 0.15 * real.abs().max(1.0);
            assert!(
                approx_eq(app, real, tol),
                "pow2_approx_f64({v}) => {app}, real => {real}"
            );
        }
    }

    /// `b^p` over a small grid of bases and integer exponents.
    #[test]
    fn test_pow_approx() {
        let bases: [f64; 4] = [0.5, 1.0, 2.0, 10.0];
        let exps: [f64; 5] = [-2.0, -1.0, 0.0, 1.0, 2.0];
        for b in bases {
            for p in exps {
                let (app, real) = (pow_approx_f64(b, p), b.powf(p));
                let tol = 0.20 * real.abs().max(1.0);
                assert!(
                    approx_eq(app, real, tol),
                    "pow_approx_f64({b}^{p}) => {app}, real => {real}"
                );
            }
        }
    }

    /// Sigmoid against the exact logistic function.
    #[test]
    fn test_sigmoid_approx() {
        let vals: [f64; 5] = [-4.0, -1.0, 0.0, 1.0, 4.0];
        for v in vals {
            let app = sigmoid_approx_f64(v);
            let real = 1.0 / (1.0 + (-v).exp());
            assert!(
                approx_eq(app, real, 0.02),
                "sigmoid_approx_f64({v}) => {app}, real => {real}"
            );
        }
    }

    /// Round trip: `erf(erf_inv(v))` should land close to `v`.
    #[test]
    fn test_erf_inv_approx() {
        let vals: [f64; 5] = [-0.9, -0.5, 0.0, 0.5, 0.9];
        for v in vals {
            let y_approx = erf_inv_approx_f64(v);
            let check = erf_approx_f64(y_approx);
            assert!(
                approx_eq(check, v, 0.2),
                "erf_inv_approx_f64({v}) => {y_approx}, but erf_approx_f64 => {check}"
            );
        }
    }

    /// `sinh`, `cosh` and `tanh` against `std`.
    #[test]
    fn test_hyperbolic_approx() {
        let vals: [f64; 5] = [-2.0, -1.0, 0.0, 1.0, 2.0];
        for v in vals {
            let (sh, ch, th) = (
                sinh_approx_f64(v),
                cosh_approx_f64(v),
                tanh_approx_f64(v),
            );
            let tol_s = 0.15 * v.sinh().abs().max(1.0);
            let tol_c = 0.15 * v.cosh().abs().max(1.0);
            assert!(
                approx_eq(sh, v.sinh(), tol_s),
                "sinh_approx_f64({v}) => {sh}, real => {}",
                v.sinh()
            );
            assert!(
                approx_eq(ch, v.cosh(), tol_c),
                "cosh_approx_f64({v}) => {ch}, real => {}",
                v.cosh()
            );
            assert!(
                approx_eq(th, v.tanh(), 0.15),
                "tanh_approx_f64({v}) => {th}, real => {}",
                v.tanh()
            );
        }
    }

    /// Lambert W at two points with known values: W(1) = Ω and W(e) = 1.
    #[test]
    fn test_lambertw_approx() {
        let xvals: [f64; 2] = [1.0_f64, std::f64::consts::E];
        let real: [f64; 2] = [0.5671432904097838, 1.0];
        for (&x, &expected) in xvals.iter().zip(real.iter()) {
            let app = lambertw_approx_f64(x);
            assert!(
                approx_eq(app, expected, 0.2),
                "lambertw_approx_f64({x}) => {app}, real => {}",
                expected
            );
        }
    }

    /// Self-consistency only: `W(y)·e^{W(y)}` should reproduce the output `y`.
    #[test]
    fn test_lambertwexpx_approx() {
        let vals: [f64; 3] = [1.0, 2.0, 3.0];
        for v in vals {
            let y = lambertwexpx_approx_f64(v);
            let wtest = lambertw_approx_f64(y);
            let check = wtest * exp_approx_f64(wtest);
            assert!(
                approx_eq(check, y, 0.3 * y.max(1.0)),
                "lambertwexpx_approx_f64({v}) => {y}, but checking => {check}"
            );
        }
    }
}