x_math/
lib.rs

1#[cfg(feature = "sse")]
2use std::arch::x86_64::*;
3
// Single-precision circle constants used by the trigonometric approximations.
// Each value is bit-identical to the hex pattern noted beside it.
const PI: f32 = std::f32::consts::PI; // 0x40490fdb, pi
const PI_2: f32 = std::f32::consts::FRAC_PI_2; // 0x3fc90fdb, pi / 2
const PI_4: f32 = std::f32::consts::FRAC_PI_4; // 0x3f490fdb, pi / 4
const TAU: f32 = std::f32::consts::TAU; // 0x40c90fdb, 2 * pi
const INV_PI: f32 = std::f32::consts::FRAC_1_PI; // 0x3ea2f983, 1 / pi
9
/// Rounds `x` toward zero to an integral value.
///
/// With the `sse` feature the rounding is delegated to the `_mm_round_ss`
/// intrinsic. Otherwise the value is truncated through an `i32` round trip,
/// with two guards:
///
/// * inputs whose magnitude is at least 2^23 (which includes infinities and
///   NaN) are returned unchanged, because they have no fractional bits and
///   may not fit in an `i32`;
/// * the sign bit of `x` is copied onto the result, so `trunc(-0.5)` is
///   `-0.0` rather than `+0.0`.
#[inline]
pub fn trunc(x: f32) -> f32 {
    // NOTE(review): `_mm_round_ss` is gated on the `sse4.1` target feature in
    // `std::arch`; this assumes the `sse` crate feature is only enabled for
    // CPUs that provide it -- confirm.
    #[cfg(feature = "sse")]
    unsafe {
        _mm_cvtss_f32(_mm_round_ss(
            _mm_set_ss(x),
            _mm_set_ss(x),
            _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC,
        ))
    }
    #[cfg(not(feature = "sse"))]
    {
        const SIGN_BIT: u32 = 0x8000_0000;
        // Bit pattern of 2^23: every finite f32 at or above it is already integral.
        const TWO_POW_23_BITS: u32 = 0x4b00_0000;

        let bits = x.to_bits();
        if (bits & !SIGN_BIT) >= TWO_POW_23_BITS {
            // Large magnitude, infinity or NaN: the i32 cast would saturate (or
            // map NaN to 0), so hand the input back untouched.
            return x;
        }

        let whole = (x as i32) as f32;
        // Restore the sign so that values in (-1, 0) truncate to -0.0.
        f32::from_bits(whole.to_bits() | (bits & SIGN_BIT))
    }
}
26
27#[inline]
28pub fn floor(x: f32) -> f32 {
29    #[cfg(feature = "sse")]
30    unsafe {
31        _mm_cvtss_f32(_mm_floor_ss(_mm_set_ss(x), _mm_set_ss(x)))
32    }
33    #[cfg(not(feature = "sse"))]
34    {
35        let r = x.to_bits() >> 31;
36        let x = x - r as f32;
37        trunc(f32::from_bits(x.to_bits() - r))
38    }
39}
40
41#[inline]
42pub fn ceil(x: f32) -> f32 {
43    #[cfg(feature = "sse")]
44    unsafe {
45        _mm_cvtss_f32(_mm_ceil_ss(_mm_set_ss(x), _mm_set_ss(x)))
46    }
47    #[cfg(not(feature = "sse"))]
48    {
49        let r = 1 - (x.to_bits() >> 31);
50        let x = x + r as f32;
51        trunc(f32::from_bits(x.to_bits() - r))
52    }
53}
54
55#[inline]
56pub fn round(x: f32) -> f32 {
57    #[cfg(feature = "sse")]
58    unsafe {
59        _mm_cvtss_f32(_mm_round_ss(
60            _mm_set_ss(x),
61            _mm_set_ss(x),
62            _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC,
63        ))
64    }
65    #[cfg(not(feature = "sse"))]
66    {
67        let x = x + f32::from_bits(0x3effffff);
68        let u = x.to_bits() >> 31;
69        trunc(x - u as f32)
70    }
71}
72
73#[inline]
74pub fn modulo(x: f32, e: f32) -> f32 {
75    x - e * floor(x / e)
76}
77
78#[inline]
79pub fn fract(x: f32) -> f32 {
80    x - floor(x)
81}
82
/// Returns the magnitude of `x` by clearing the sign bit of its bit pattern.
#[inline]
pub fn abs(x: f32) -> f32 {
    // Every bit except the top (sign) bit.
    const MAGNITUDE_MASK: u32 = u32::MAX >> 1;

    let magnitude_bits = x.to_bits() & MAGNITUDE_MASK;
    f32::from_bits(magnitude_bits)
}
87
/// Returns `1.0` carrying the sign bit of `x`: `-1.0` when the sign bit is
/// set (including `-0.0`), `1.0` otherwise. The result is never zero.
#[inline]
pub fn sign(x: f32) -> f32 {
    const SIGN_BIT: u32 = 1 << 31;
    // Bit pattern of 1.0_f32.
    const ONE_BITS: u32 = 0x3f80_0000;

    let sign_only = x.to_bits() & SIGN_BIT;
    f32::from_bits(sign_only | ONE_BITS)
}
92
93#[inline]
94pub fn cos(x: f32) -> f32 {
95    // stackoverflow.com/a/77792413
96
97    let x = abs(modulo(x, TAU) - PI) - PI_2;
98    let x = x + (f32::from_bits(0xbc96e670) * x) * (x * x);
99    x + (f32::from_bits(0xbe17b083) * x) * (x * x)
100}
101
102#[inline]
103pub fn sin(x: f32) -> f32 {
104    cos(x - PI_2)
105}
106
/// Approximates the square root of `x`.
///
/// With the `sse` feature this uses the `_mm_sqrt_ss` intrinsic. Otherwise an
/// initial guess is built by adding a magic constant to the bit pattern and
/// halving it, then refined with one Newton step (`0.5 * (s + x / s)`), or
/// two when the `acc` feature is enabled.
///
/// The fallback is only meaningful for non-negative inputs. Negative inputs
/// produce an unspecified value, but never panic: the integer addition wraps
/// instead of overflowing.
#[inline]
pub fn sqrt(x: f32) -> f32 {
    #[cfg(feature = "sse")]
    unsafe {
        _mm_cvtss_f32(_mm_sqrt_ss(_mm_set_ss(x)))
    }
    #[cfg(not(feature = "sse"))]
    {
        // `wrapping_add`: with the sign bit set the sum can exceed u32::MAX,
        // which would abort overflow-checked builds with a plain `+`.
        let guess = f32::from_bits(x.to_bits().wrapping_add(0x3f769e5c) >> 1);
        let guess = 0.5 * (guess + x / guess);
        #[cfg(feature = "acc")]
        {
            // Second Newton step for extra accuracy.
            0.5 * (guess + x / guess)
        }
        #[cfg(not(feature = "acc"))]
        {
            guess
        }
    }
}
127
128#[inline]
129pub fn cbrt(x: f32) -> f32 {
130    // www.mdpi.com/1996-1073/14/4/1058
131
132    const A: f32 = f32::from_bits(0x3fe04c03);
133    const B: f32 = f32::from_bits(0x3f0266d9);
134    const C: f32 = f32::from_bits(0xbfa01f36);
135
136    let s = sign(x);
137    let x = abs(x);
138    let i = 0x548c2b4b - (x.to_bits() / 3);
139    let y = f32::from_bits(i);
140    let c = x * y * y * y;
141    let y = y * (A + c * (B * c + C));
142    let d = x * y * y;
143    let c = d - d * d * y;
144    let c = c * f32::from_bits(0x3eaaaaab) + d;
145    s * c
146}
147
/// Approximates the reciprocal square root `1 / sqrt(x)`.
///
/// With the `sse` feature this uses the `_mm_rsqrt_ss` intrinsic. Otherwise
/// it uses the magic-constant bit trick (`0x5f3759df - (bits >> 1)`) followed
/// by one Newton step, or two when the `acc` feature is enabled.
///
/// The fallback is only meaningful for positive inputs. Negative inputs
/// produce an unspecified value, but never panic: the integer subtraction
/// wraps instead of underflowing.
#[inline]
pub fn rsqrt(x: f32) -> f32 {
    #[cfg(feature = "sse")]
    unsafe {
        _mm_cvtss_f32(_mm_rsqrt_ss(_mm_set_ss(x)))
    }
    #[cfg(not(feature = "sse"))]
    {
        let half = x * 0.5;
        // `wrapping_sub`: with the sign bit set, `bits >> 1` can exceed the
        // magic constant, which would abort overflow-checked builds with `-`.
        let mut est = f32::from_bits(0x5f3759df_u32.wrapping_sub(x.to_bits() >> 1));

        est *= 1.5 - (half * est * est);
        #[cfg(feature = "acc")]
        {
            // Second Newton step for extra accuracy.
            est * (1.5 - (half * est * est))
        }
        #[cfg(not(feature = "acc"))]
        {
            est
        }
    }
}
170
/// Returns `a` when `a < b`, otherwise `b`.
///
/// Because the comparison is false whenever either operand is NaN, `b` is
/// returned in that case.
#[inline]
pub fn min(a: f32, b: f32) -> f32 {
    let a_is_smaller = a < b;
    if a_is_smaller {
        a
    } else {
        b
    }
}
175
/// Returns `a` when `a > b`, otherwise `b`.
///
/// Because the comparison is false whenever either operand is NaN, `b` is
/// returned in that case.
#[inline]
pub fn max(a: f32, b: f32) -> f32 {
    let a_is_larger = a > b;
    if a_is_larger {
        a
    } else {
        b
    }
}
180
181#[inline]
182pub fn clamp(x: f32, a: f32, b: f32) -> f32 {
183    min(max(x, a), b)
184}
185
186#[inline]
187pub fn atan2(y: f32, x: f32) -> f32 {
188    // math.stackexchange.com/a/1105038
189
190    let nx = x.to_bits() >> 31;
191    let ny = y.to_bits() & 0x80000000;
192
193    let x = abs(x);
194    let y = abs(y);
195
196    let p = (y > x) as u32;
197    let y = min(x, y) / max(x, y);
198
199    let r = y * y;
200    let r = ((f32::from_bits(0xbd3e7316) * r + f32::from_bits(0x3e232344)) * r
201        - f32::from_bits(0x3ea7be2c))
202        * r
203        * y
204        + y;
205
206    let r = r - f32::from_bits(p * 0x3fc90fdb);
207    let u = r.to_bits() ^ ((p ^ nx) << 31);
208    let r = f32::from_bits(u) + f32::from_bits(nx * 0x40490fdb);
209
210    f32::from_bits((r.to_bits() & 0x7fffffff) | ny)
211}
212
213#[inline]
214pub fn asin(x: f32) -> f32 {
215    let s = sign(x);
216    let x = abs(x);
217    let z = 1.0 - sqrt(1.0 - x * x);
218    let a = x - 0.35;
219    s * (PI_4 * (x + z + 0.12 * z * z) + f32::from_bits(0x3d07ae14)
220        - f32::from_bits(0x3e98a3d7) * a * a)
221}
222
223#[inline]
224pub fn acos(x: f32) -> f32 {
225    asin(-x) + PI_2
226}
227
/// Approximates `2^x`.
///
/// * With the `acc` feature: `x` is scaled by 2^23 so that its integer part
///   lands in the exponent field and its fraction in the mantissa field. A
///   quadratic in the fraction approximates `2^fraction`, and the integer
///   part is added back onto the exponent bits. Inputs below -126 are
///   clamped to -126 so the exponent arithmetic cannot run past the smallest
///   normal exponent.
/// * Without `acc`: `(x + bias) * 2^23` is reinterpreted directly as the
///   bits of the result. Inputs far enough below zero make the intermediate
///   negative, which the cast saturates to 0, i.e. a result of `0.0`.
#[inline]
pub fn exp2(x: f32) -> f32 {
    #[cfg(feature = "acc")]
    {
        // docs.rs/fast-math/latest/src/fast_math/exp.rs.html#32

        // Keep the result's exponent inside the normal range.
        let x = if x < -126.0 { -126.0 } else { x };

        // Cast through `i32`: a direct float -> `u32` cast saturates every
        // negative value to 0, which would make all negative exponents
        // evaluate as if `x` were 0.
        let n = (x * f32::from_bits(0x4b000000)) as i32 as u32;
        // Integer part of `x`, already positioned in the exponent field.
        let l = n & 0xff800000;
        // Fractional part of `x`, scaled by 2^23.
        let x = (n - l) as f32;

        let x = (f32::from_bits(0x27aca418) * x + f32::from_bits(0x33a85ada)) * x
            + f32::from_bits(0x3f803884);
        // For negative `x`, `l` holds a two's-complement exponent offset, so
        // the sum is meant to wrap.
        f32::from_bits(l.wrapping_add(x.to_bits()))
    }
    #[cfg(not(feature = "acc"))]
    {
        // docs.rs/fastapprox/latest/src/fastapprox/faster/mod.rs.html#21

        f32::from_bits(((x + f32::from_bits(0x42fde2a9)) * f32::from_bits(0x4b000000)) as u32)
    }
}
249
250#[inline]
251pub fn exp(x: f32) -> f32 {
252    exp2(x * f32::from_bits(0x3fb8aa3b))
253}
254
255#[inline]
256pub fn sinh(x: f32) -> f32 {
257    let a = f32::from_bits(0x3fb8aa3b) * x - 1.0;
258    let b = f32::from_bits(0xbfb8aa3b) * x - 1.0;
259    exp2(a) - exp2(b)
260}
261
262#[inline]
263pub fn cosh(x: f32) -> f32 {
264    let a = f32::from_bits(0x3fb8aa3b) * x - 1.0;
265    let b = f32::from_bits(0xbfb8aa3b) * x - 1.0;
266    exp2(a) + exp2(b)
267}
268
269#[inline]
270pub fn tanh(x: f32) -> f32 {
271    let s = sign(x);
272    let x = abs(x);
273    // couldnt figure out a branchless way
274    if x < 1.0 {
275        let z = 0.07 * x * x;
276        s * (x + (z * x + f32::from_bits(0xc08db6db)) * z)
277    } else {
278        let x = (1.05 * x - 0.1) * x + 1.09;
279        s - s / (x * x)
280    }
281}
282
283#[inline]
284pub fn tan(x: f32) -> f32 {
285    // observablehq.com/@jrus/fasttan
286
287    let x = x * INV_PI;
288    let x = 2.0 * (x - round(x));
289    let y = 1.0 - x * x;
290    x * (f32::from_bits(0xbc994764) * y
291        + f32::from_bits(0x3ea1b529)
292        + f32::from_bits(0x3fa30738) / y)
293}
294
/// Approximates the base-2 logarithm of `x`.
///
/// * With the `acc` feature: the exponent and mantissa fields are split
///   apart, the mantissa is rebuilt with an exponent chosen from its top
///   bit, and a quadratic in that value minus one is added to the adjusted
///   integer exponent.
/// * Without `acc`: the raw bit pattern is converted to a float, scaled by
///   2^-23 and shifted by a bias constant.
#[inline]
pub fn log2(x: f32) -> f32 {
    #[cfg(feature = "acc")]
    {
        // docs.rs/fast-math/latest/src/fast_math/log.rs.html#66

        let bits = x.to_bits();
        let exponent = ((bits >> 23) & 0xff) as i32;
        let mantissa = bits & 0x7fffff;

        // Top mantissa bit; selects the exponent given to the rebuilt value.
        let top = (bits >> 22) & 1;

        let m = f32::from_bits(mantissa | ((top ^ 0x7f) << 23)) - 1.0;
        let whole = (top as i32 + exponent - 127) as f32;

        whole + m * (m * f32::from_bits(0xbf213248) + f32::from_bits(0x3fbbc593))
    }
    #[cfg(not(feature = "acc"))]
    {
        // docs.rs/fastapprox/latest/src/fastapprox/faster/mod.rs.html#5

        // 0x34000000 is 2^-23 in f32.
        let scaled = x.to_bits() as f32 * f32::from_bits(0x34000000);
        scaled - f32::from_bits(0x42fde2a9)
    }
}