1#![allow(clippy::many_single_char_names)]
2
3use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, LN_2, PI};
4
/// Approximates `atan(z)` with a cheap polynomial.
///
/// For `|z| <= 1` the value is
/// `π/4·z − z·(|z| − 1)·(0.2447 + 0.0663·|z|)`.
/// For `|z| > 1` the same polynomial is evaluated at `1/z` and reflected with
/// the identity `atan(z) = ±π/2 − atan(1/z)`.
///
/// The absolute error is on the order of `1e-3` rad. NaN input yields NaN and
/// `±∞` yields `±π/2`.
#[inline(always)]
pub fn atan_fast(z: f64) -> f64 {
    const C0: f64 = 0.2447;
    const C1: f64 = 0.0663;
    const PIO4: f64 = std::f64::consts::FRAC_PI_4;
    const PIO2: f64 = std::f64::consts::FRAC_PI_2;

    /// Polynomial core; only accurate for `|t| <= 1`.
    #[inline(always)]
    fn unit_interval(t: f64) -> f64 {
        let a = t.abs();
        let poly = C1.mul_add(a, C0);
        // The correction is the *product* t·(|t| − 1)·poly and it is
        // subtracted from the linear term. (|t| − 1) <= 0 here, so the
        // correction pushes the result away from the chord π/4·t and
        // vanishes at t = 0 and |t| = 1.
        let correction = t * (a - 1.0) * poly;
        PIO4.mul_add(t, -correction)
    }

    // NaN fails this comparison and falls through to the reciprocal branch,
    // where it propagates as NaN.
    if z.abs() <= 1.0 {
        unit_interval(z)
    } else {
        let base = unit_interval(1.0 / z);
        if z.is_sign_positive() {
            PIO2 - base
        } else {
            -PIO2 - base
        }
    }
}
27
/// Returns `-val` when the sign bit of `x` is set (this includes `-0.0`),
/// otherwise returns `val` unchanged.
///
/// Only the sign bit of `x` is inspected; its magnitude is ignored.
#[inline(always)]
fn flip_sign_nonnan(x: f64, val: f64) -> f64 {
    if x.is_sign_positive() {
        return val;
    }
    -val
}
36
/// Evaluates the polynomial `(π/4 + 0.273 − 0.273·|x|)·x`.
///
/// No range reduction is performed here; `atan64` only calls this with
/// arguments of magnitude at most 1 (or NaN).
#[inline(always)]
pub fn atan_raw64(x: f64) -> f64 {
    const N2: f64 = 0.273;
    let slope = FRAC_PI_4 + N2 - N2 * x.abs();
    slope * x
}
42
43#[inline(always)]
44pub fn atan64(x: f64) -> f64 {
45 if x.abs() > 1.0 {
46 debug_assert!(!x.is_nan());
47 flip_sign_nonnan(x, FRAC_PI_2) - atan_raw64(1.0 / x)
48 } else {
49 atan_raw64(x)
50 }
51}
52
/// Approximates `sin(x)` with a refined parabola.
///
/// The argument is reduced with `%` by `2π` and wrapped into `[-π, π]`. The
/// parabola `4/π·|x| − 4/π²·x²` (carrying the sign of the wrapped `x`) is then
/// sharpened with `y·(Q + (1 − Q)·|y|)`.
#[inline(always)]
pub fn fast_sin_f64(x: f64) -> f64 {
    const TWO_PI: f64 = 2.0 * PI;
    const FOUROVERPI: f64 = 1.2732395447351627;
    const FOUROVERPISQ: f64 = 0.405_284_734_569_351_1;
    const Q: f64 = 0.776_330_232_480_075;

    // `%` keeps the sign of the dividend, so the remainder lies in (-2π, 2π).
    let reduced = x % TWO_PI;
    let wrapped = if reduced < -PI {
        reduced + TWO_PI
    } else if reduced > PI {
        reduced - TWO_PI
    } else {
        reduced
    };

    let ax = wrapped.abs();
    let parabola = FOUROVERPI * ax - FOUROVERPISQ * ax * ax;
    let y = if wrapped < 0.0 { -parabola } else { parabola };

    y * (Q + (1.0 - Q) * y.abs())
}
77
/// Approximates `cos(x)` by evaluating the sine parabola at `x + π/2`.
///
/// The argument is reduced with `%` by `2π`, wrapped into `[-π, π]`, shifted
/// by `π/2` and wrapped again. The result is then computed the same way as in
/// `fast_sin_f64`: parabola `4/π·|x| − 4/π²·x²` with the sign of the shifted
/// argument, followed by `y·(Q + (1 − Q)·|y|)`.
#[inline(always)]
pub fn fast_cos_f64(x: f64) -> f64 {
    const TWO_PI: f64 = 2.0 * PI;
    const FOUROVERPI: f64 = 1.2732395447351627;
    const FOUROVERPISQ: f64 = 0.405_284_734_569_351_1;
    const Q: f64 = 0.776_330_232_480_075;

    /// Moves a value that is at most one period outside `[-π, π]` back in.
    #[inline(always)]
    fn fold_into_pi_range(v: f64) -> f64 {
        if v > PI {
            v - TWO_PI
        } else if v < -PI {
            v + TWO_PI
        } else {
            v
        }
    }

    let centered = fold_into_pi_range(x % TWO_PI);
    // cos(t) = sin(t + π/2); the shift can leave the range by at most π/2.
    let phase = fold_into_pi_range(centered + FRAC_PI_2);

    let ax = phase.abs();
    let parabola = FOUROVERPI * ax - FOUROVERPISQ * ax * ax;
    let y = if phase < 0.0 { -parabola } else { parabola };

    y * (Q + (1.0 - Q) * y.abs())
}
109
/// Returns the raw IEEE-754 bit pattern of `x` as a `u64`.
#[inline(always)]
fn to_bits_f64(x: f64) -> u64 {
    f64::to_bits(x)
}
/// Reinterprets the 64-bit pattern `u` as an `f64`; no numeric conversion is
/// performed.
#[inline(always)]
fn from_bits_f64(u: u64) -> f64 {
    let pattern = u;
    f64::from_bits(pattern)
}
118
119#[inline]
120pub fn log2_approx_f64(x: f64) -> f64 {
121 let mut y = to_bits_f64(x) as f64;
122 y *= 2.220446049250313e-16;
123 y - 1022.94269504
124}
125
126#[inline]
127pub fn ln_approx_f64(x: f64) -> f64 {
128 log2_approx_f64(x) * LN_2
129}
130
131#[inline]
132pub fn pow2_approx_f64(p: f64) -> f64 {
133 let clipp = if p < -1022.0 { -1022.0 } else { p };
134 const POW2_OFFSET: f64 = 1022.942695;
135 let v = ((1u64 << 52) as f64 * (clipp + POW2_OFFSET)) as u64;
136 from_bits_f64(v)
137}
138
139#[inline]
140pub fn pow_approx_f64(x: f64, p: f64) -> f64 {
141 pow2_approx_f64(p * log2_approx_f64(x))
142}
143
144#[inline]
145pub fn exp_approx_f64(p: f64) -> f64 {
146 const INV_LN2: f64 = std::f64::consts::LOG2_E;
147 pow2_approx_f64(INV_LN2 * p)
148}
149
150#[inline]
151pub fn sigmoid_approx_f64(x: f64) -> f64 {
152 1.0 / (1.0 + exp_approx_f64(-x))
153}
154
155#[inline]
156pub fn lambertw_approx_f64(x: f64) -> f64 {
157 if x == 0.0 {
158 return 0.0;
159 }
160
161 let mut w = if x < 1.0 {
162 x
163 } else {
164 let g = ln_approx_f64(x).max(0.0);
165 if g < 0.5 {
166 0.5
167 } else {
168 g
169 }
170 };
171
172 for _ in 0..2 {
173 let ew = exp_approx_f64(w);
174 let f = w * ew - x;
175 let fp = ew * (w + 1.0);
176 w -= f / fp;
177 }
178 w
179}
180
181#[inline]
182pub fn lambertwexpx_approx_f64(v: f64) -> f64 {
183 let mut y = 1.0_f64 + v.abs();
184 for _ in 0..5 {
185 let w = lambertw_approx_f64(y);
186 y = w * exp_approx_f64(w);
187 }
188 y
189}
190
191#[inline]
192pub fn ln_gamma_approx_f64(x: f64) -> f64 {
193 -0.0810614667_f64 - x - ln_approx_f64(x) + (0.5_f64 + x) * ln_approx_f64(1.0_f64 + x)
194}
195
196#[inline]
197pub fn digamma_approx_f64(x: f64) -> f64 {
198 let onepx = 1.0 + x;
199 -1.0 / x - 1.0 / (2.0 * onepx) + ln_approx_f64(onepx)
200}
201
202#[inline]
203pub fn erfc_approx_f64(x: f64) -> f64 {
204 const K: f64 = 3.3509633149424609;
205 2.0 / (1.0 + pow2_approx_f64(K * x))
206}
207
208#[inline]
209pub fn erf_approx_f64(x: f64) -> f64 {
210 1.0 - erfc_approx_f64(x)
211}
212
213#[inline]
214pub fn erf_inv_approx_f64(x: f64) -> f64 {
215 const INVK: f64 = 0.30004578719350504;
216 let ratio = (1.0 + x) / (1.0 - x);
217 INVK * log2_approx_f64(ratio)
218}
219
220#[inline]
221pub fn sinh_approx_f64(x: f64) -> f64 {
222 0.5 * (exp_approx_f64(x) - exp_approx_f64(-x))
223}
224
225#[inline]
226pub fn cosh_approx_f64(x: f64) -> f64 {
227 0.5 * (exp_approx_f64(x) + exp_approx_f64(-x))
228}
229
230#[inline]
231pub fn tanh_approx_f64(x: f64) -> f64 {
232 -1.0 + 2.0 / (1.0 + exp_approx_f64(-2.0 * x))
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `true` when `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        let gap = (a - b).abs();
        gap < tol
    }

    /// Sine and cosine against `std`, including angles outside `[-π, π]`.
    #[test]
    fn test_fast_sin_cos() {
        let angles: [f64; 9] = [
            0.0,
            PI * 0.25,
            PI * 0.5,
            PI * 0.75,
            PI,
            -PI * 0.5,
            -PI,
            10.0,
            -10.0,
        ];
        for ang in angles {
            let (fs, fc) = (fast_sin_f64(ang), fast_cos_f64(ang));
            let (rs, rc) = (ang.sin(), ang.cos());

            assert!(
                approx_eq(fs, rs, 0.05),
                "fast_sin_f64({ang}) => {fs} vs std => {rs}"
            );
            assert!(
                approx_eq(fc, rc, 0.05),
                "fast_cos_f64({ang}) => {fc} vs std => {rc}"
            );
        }
    }

    /// `atan64` on both sides of the `|x| = 1` branch point.
    #[test]
    fn test_atan_approx() {
        let vals: [f64; 6] = [0.0, 0.5, 1.0, 2.0, -1.0, -10.0];
        for v in vals {
            let app = atan64(v);
            let real = v.atan();
            assert!(
                approx_eq(app, real, 0.1),
                "atan64({v}) => {app}, real => {real}"
            );
        }
    }

    /// Bit-trick `log2` against `f64::log2`.
    #[test]
    fn test_log2_approx() {
        let vals: [f64; 6] = [0.125, 0.5, 1.0, 2.0, 8.0, 10.0];
        for v in vals {
            let app = log2_approx_f64(v);
            let real = v.log2();
            assert!(
                approx_eq(app, real, 0.15),
                "log2_approx_f64({v}) => {app}, real => {real}"
            );
        }
    }

    /// Natural logarithm derived from the approximate `log2`.
    #[test]
    fn test_ln_approx() {
        let vals: [f64; 6] = [0.125, 0.5, 1.0, 2.0, 8.0, 10.0];
        for v in vals {
            let app = ln_approx_f64(v);
            let real = v.ln();
            assert!(
                approx_eq(app, real, 0.2),
                "ln_approx_f64({v}) => {app}, real => {real}"
            );
        }
    }

    /// Exponential with a tolerance relative to the true value (floor of 1).
    #[test]
    fn test_exp_approx() {
        let vals: [f64; 6] = [-2.0, -1.0, 0.0, 1.0, 2.0, 5.0];
        for v in vals {
            let app = exp_approx_f64(v);
            let real = v.exp();
            let tol = 0.15 * real.abs().max(1.0);
            assert!(
                approx_eq(app, real, tol),
                "exp_approx_f64({v}) => {app}, real => {real}"
            );
        }
    }

    /// Base-2 power with a tolerance relative to the true value (floor of 1).
    #[test]
    fn test_pow2_approx() {
        let vals: [f64; 6] = [-10.0, -1.0, 0.0, 1.0, 10.0, 15.5];
        for v in vals {
            let app = pow2_approx_f64(v);
            let real = (2.0_f64).powf(v);
            let tol = 0.15 * real.abs().max(1.0);
            assert!(
                approx_eq(app, real, tol),
                "pow2_approx_f64({v}) => {app}, real => {real}"
            );
        }
    }

    /// General power over a small grid of bases and exponents.
    #[test]
    fn test_pow_approx() {
        let bases: [f64; 4] = [0.5, 1.0, 2.0, 10.0];
        let exps: [f64; 5] = [-2.0, -1.0, 0.0, 1.0, 2.0];
        for b in bases {
            for p in exps {
                let app = pow_approx_f64(b, p);
                let real = b.powf(p);
                let tol = 0.20 * real.abs().max(1.0);
                assert!(
                    approx_eq(app, real, tol),
                    "pow_approx_f64({b}^{p}) => {app}, real => {real}"
                );
            }
        }
    }

    /// Logistic function against the exact formula.
    #[test]
    fn test_sigmoid_approx() {
        let vals: [f64; 5] = [-4.0, -1.0, 0.0, 1.0, 4.0];
        for v in vals {
            let app = sigmoid_approx_f64(v);
            let real = 1.0 / (1.0 + (-v).exp());
            assert!(
                approx_eq(app, real, 0.02),
                "sigmoid_approx_f64({v}) => {app}, real => {real}"
            );
        }
    }

    /// Round trip: `erf_approx(erf_inv_approx(v))` should come back near `v`.
    #[test]
    fn test_erf_inv_approx() {
        let vals: [f64; 5] = [-0.9, -0.5, 0.0, 0.5, 0.9];
        for v in vals {
            let y_approx = erf_inv_approx_f64(v);
            let check = erf_approx_f64(y_approx);
            assert!(
                approx_eq(check, v, 0.2),
                "erf_inv_approx_f64({v}) => {y_approx}, but erf_approx_f64 => {check}"
            );
        }
    }

    /// `sinh`, `cosh` and `tanh` against `std`.
    #[test]
    fn test_hyperbolic_approx() {
        let vals: [f64; 5] = [-2.0, -1.0, 0.0, 1.0, 2.0];
        for v in vals {
            let (sh, ch, th) = (
                sinh_approx_f64(v),
                cosh_approx_f64(v),
                tanh_approx_f64(v),
            );
            let tol_s = 0.15 * v.sinh().abs().max(1.0);
            let tol_c = 0.15 * v.cosh().abs().max(1.0);
            assert!(
                approx_eq(sh, v.sinh(), tol_s),
                "sinh_approx_f64({v}) => {sh}, real => {}",
                v.sinh()
            );
            assert!(
                approx_eq(ch, v.cosh(), tol_c),
                "cosh_approx_f64({v}) => {ch}, real => {}",
                v.cosh()
            );
            assert!(
                approx_eq(th, v.tanh(), 0.15),
                "tanh_approx_f64({v}) => {th}, real => {}",
                v.tanh()
            );
        }
    }

    /// Lambert W at two points with known reference values.
    #[test]
    fn test_lambertw_approx() {
        let xvals = [1.0_f64, std::f64::consts::E];
        let real = [0.5671432904097838, 1.0];
        for (&x, &expected) in xvals.iter().zip(real.iter()) {
            let app = lambertw_approx_f64(x);
            assert!(
                approx_eq(app, expected, 0.2),
                "lambertw_approx_f64({x}) => {app}, real => {}",
                expected
            );
        }
    }

    /// Self-consistency: applying W and then `w·e^w` to the result should
    /// return approximately the same value.
    #[test]
    fn test_lambertwexpx_approx() {
        let vals: [f64; 3] = [1.0, 2.0, 3.0];
        for v in vals {
            let y = lambertwexpx_approx_f64(v);
            let wtest = lambertw_approx_f64(y);
            let check = wtest * exp_approx_f64(wtest);
            assert!(
                approx_eq(check, y, 0.3 * y.max(1.0)),
                "lambertwexpx_approx_f64({v}) => {y}, but checking => {check}"
            );
        }
    }
}