wasmtime_internal_math/
lib.rs

1//! A minimal helper crate for implementing float-related operations for
2//! WebAssembly in terms of the native platform primitives.
3//!
4//! > **⚠️ Warning ⚠️**: this crate is an internal-only crate for the Wasmtime
5//! > project and is not intended for general use. APIs are not strictly
6//! > reviewed for safety and usage outside of Wasmtime may have bugs. If
7//! > you're interested in using this feel free to file an issue on the
8//! > Wasmtime repository to start a discussion about doing so, but otherwise
9//! > be aware that your usage of this crate is not supported.
10//!
11//! This crate is intended to assist with solving the portability issues such
12//! as:
13//!
14//! * Functions like `f32::trunc` are not available in `#![no_std]` targets.
15//! * The `f32::trunc` function is likely faster than the `libm` fallback.
16//! * Behavior of `f32::trunc` differs across platforms, for example it's
17//!   different on Windows and glibc on Linux. Additionally riscv64's
18//!   implementation of `libm` seems to have different NaN behavior than other
19//!   platforms.
20//! * Some wasm functions are in the Rust standard library, but not stable yet.
21//!
//! These functions are needed in several places throughout the codebase, so
//! they're implemented once here rather than duplicated at each use site.
25
26#![no_std]
27
28#[cfg(feature = "std")]
29extern crate std;
30
/// Returns the bounds for guarding a trapping f32-to-int conversion.
///
/// This function returns two floats, a lower bound and an upper bound, which
/// can be used to test whether a WebAssembly f32-to-int conversion should
/// trap. Both bounds are exclusive: the float being converted must be strictly
/// greater than the lower bound and strictly less than the upper bound for the
/// conversion to proceed. Otherwise the input is out of range for the target
/// integer type and a trapping conversion must trap. NaN compares false
/// against both bounds, so it is rejected as well.
///
/// The `signed` argument indicates whether a conversion to a signed integer is
/// happening. If `false`, a conversion to an unsigned integer is happening. The
/// `out_bits` argument indicates how many bits are in the integer being
/// converted to.
///
/// # Panics
///
/// Panics if `out_bits` is not one of 8, 16, 32 or 64.
pub const fn f32_cvt_to_int_bounds(signed: bool, out_bits: u32) -> (f32, f32) {
    match (signed, out_bits) {
        (true, 8) => (i8::MIN as f32 - 1.0, i8::MAX as f32 + 1.0),
        (true, 16) => (i16::MIN as f32 - 1.0, i16::MAX as f32 + 1.0),
        // `i32::MIN - 1` is not representable in f32 (it rounds back to
        // `i32::MIN`), so the lower bound is the next f32 below it:
        // -(2^31 + 2^8). The upper bound 2^31 is exact.
        (true, 32) => (-2147483904.0, 2147483648.0),
        // Likewise the next f32 below `i64::MIN` is -(2^63 + 2^40).
        (true, 64) => (-9223373136366403584.0, 9223372036854775808.0),
        // Any input in (-1, 0) truncates to zero, so -1 is the exclusive
        // lower bound for every unsigned width.
        (false, 8) => (-1.0, u8::MAX as f32 + 1.0),
        (false, 16) => (-1.0, u16::MAX as f32 + 1.0),
        (false, 32) => (-1.0, 4294967296.0),
        (false, 64) => (-1.0, 18446744073709551616.0),
        _ => panic!("unsupported integer width: out_bits must be 8, 16, 32 or 64"),
    }
}
56
/// Same as [`f32_cvt_to_int_bounds`] but used for f64-to-int conversions.
///
/// The returned `(lower, upper)` pair are exclusive bounds: the f64 being
/// converted must be strictly between them for the conversion to proceed.
/// NaN compares false against both bounds and is therefore rejected too.
///
/// # Panics
///
/// Panics if `out_bits` is not one of 8, 16, 32 or 64.
pub const fn f64_cvt_to_int_bounds(signed: bool, out_bits: u32) -> (f64, f64) {
    match (signed, out_bits) {
        (true, 8) => (i8::MIN as f64 - 1.0, i8::MAX as f64 + 1.0),
        (true, 16) => (i16::MIN as f64 - 1.0, i16::MAX as f64 + 1.0),
        // f64 represents every 32-bit integer exactly, so `MIN - 1` and
        // `MAX + 1` can be used directly.
        (true, 32) => (-2147483649.0, 2147483648.0),
        // `i64::MIN - 1` is not representable in f64; the next f64 below
        // `i64::MIN` is -(2^63 + 2^11).
        (true, 64) => (-9223372036854777856.0, 9223372036854775808.0),
        // Any input in (-1, 0) truncates to zero, so -1 is the exclusive
        // lower bound for every unsigned width.
        (false, 8) => (-1.0, u8::MAX as f64 + 1.0),
        (false, 16) => (-1.0, u16::MAX as f64 + 1.0),
        (false, 32) => (-1.0, 4294967296.0),
        (false, 64) => (-1.0, 18446744073709551616.0),
        _ => panic!("unsupported integer width: out_bits must be 8, 16, 32 or 64"),
    }
}
71
/// Floating-point operations with WebAssembly semantics, implemented on top of
/// native primitives where they are trusted and `libm` otherwise.
///
/// Implemented for `f32` and `f64` in this crate.
pub trait WasmFloat {
    /// Rounds toward zero (wasm `trunc`). Any NaN input yields the type's
    /// `NAN` constant.
    fn wasm_trunc(self) -> Self;
    /// Returns `self` with the sign of `sign` (wasm `copysign`).
    fn wasm_copysign(self, sign: Self) -> Self;
    /// Rounds toward negative infinity (wasm `floor`). Any NaN input yields
    /// the type's `NAN` constant.
    fn wasm_floor(self) -> Self;
    /// Rounds toward positive infinity (wasm `ceil`). Any NaN input yields
    /// the type's `NAN` constant.
    fn wasm_ceil(self) -> Self;
    /// Square root (wasm `sqrt`).
    fn wasm_sqrt(self) -> Self;
    /// Absolute value (wasm `abs`).
    fn wasm_abs(self) -> Self;
    /// Rounds to the nearest integer, breaking ties toward the even integer
    /// (wasm `nearest`). Any NaN input yields the type's `NAN` constant.
    fn wasm_nearest(self) -> Self;
    /// Returns the smaller operand (wasm `min`). `-0.0` is treated as less
    /// than `+0.0`, and if either operand is NaN the result is NaN.
    fn wasm_minimum(self, other: Self) -> Self;
    /// Returns the larger operand (wasm `max`). `+0.0` is treated as greater
    /// than `-0.0`, and if either operand is NaN the result is NaN.
    fn wasm_maximum(self, other: Self) -> Self;
    /// Fused multiply-add: computes `self * b + c` with a single rounding.
    fn wasm_mul_add(self, b: Self, c: Self) -> Self;
}
84
85impl WasmFloat for f32 {
86    #[inline]
87    fn wasm_trunc(self) -> f32 {
88        if self.is_nan() {
89            return f32::NAN;
90        }
91        #[cfg(feature = "std")]
92        if !cfg!(windows) && !cfg!(target_arch = "riscv64") {
93            return self.trunc();
94        }
95        libm::truncf(self)
96    }
97    #[inline]
98    fn wasm_copysign(self, sign: f32) -> f32 {
99        #[cfg(feature = "std")]
100        if true {
101            return self.copysign(sign);
102        }
103        libm::copysignf(self, sign)
104    }
105    #[inline]
106    fn wasm_floor(self) -> f32 {
107        if self.is_nan() {
108            return f32::NAN;
109        }
110        #[cfg(feature = "std")]
111        if !cfg!(target_arch = "riscv64") {
112            return self.floor();
113        }
114        libm::floorf(self)
115    }
116    #[inline]
117    fn wasm_ceil(self) -> f32 {
118        if self.is_nan() {
119            return f32::NAN;
120        }
121        #[cfg(feature = "std")]
122        if !cfg!(target_arch = "riscv64") {
123            return self.ceil();
124        }
125        libm::ceilf(self)
126    }
127    #[inline]
128    fn wasm_sqrt(self) -> f32 {
129        #[cfg(feature = "std")]
130        if true {
131            return self.sqrt();
132        }
133        libm::sqrtf(self)
134    }
135    #[inline]
136    fn wasm_abs(self) -> f32 {
137        #[cfg(feature = "std")]
138        if true {
139            return self.abs();
140        }
141        libm::fabsf(self)
142    }
143    #[inline]
144    fn wasm_nearest(self) -> f32 {
145        if self.is_nan() {
146            return f32::NAN;
147        }
148        #[cfg(feature = "std")]
149        if !cfg!(windows) && !cfg!(target_arch = "riscv64") {
150            return self.round_ties_even();
151        }
152        let round = libm::roundf(self);
153        if libm::fabsf(self - round) != 0.5 {
154            return round;
155        }
156        match round % 2.0 {
157            1.0 => libm::floorf(self),
158            -1.0 => libm::ceilf(self),
159            _ => round,
160        }
161    }
162    #[inline]
163    fn wasm_maximum(self, other: f32) -> f32 {
164        // FIXME: replace this with `a.maximum(b)` when rust-lang/rust#91079 is
165        // stabilized
166        if self > other {
167            self
168        } else if other > self {
169            other
170        } else if self == other {
171            if self.is_sign_positive() && other.is_sign_negative() {
172                self
173            } else {
174                other
175            }
176        } else {
177            self + other
178        }
179    }
180    #[inline]
181    fn wasm_minimum(self, other: f32) -> f32 {
182        // FIXME: replace this with `self.minimum(other)` when
183        // rust-lang/rust#91079 is stabilized
184        if self < other {
185            self
186        } else if other < self {
187            other
188        } else if self == other {
189            if self.is_sign_negative() && other.is_sign_positive() {
190                self
191            } else {
192                other
193            }
194        } else {
195            self + other
196        }
197    }
198    #[inline]
199    fn wasm_mul_add(self, b: f32, c: f32) -> f32 {
200        // The MinGW implementation of `fma` differs from other platforms, so
201        // favor `libm` there instead.
202        #[cfg(feature = "std")]
203        if !(cfg!(windows) && cfg!(target_env = "gnu")) {
204            return self.mul_add(b, c);
205        }
206        libm::fmaf(self, b, c)
207    }
208}
209
210impl WasmFloat for f64 {
211    #[inline]
212    fn wasm_trunc(self) -> f64 {
213        if self.is_nan() {
214            return f64::NAN;
215        }
216        #[cfg(feature = "std")]
217        if !cfg!(windows) && !cfg!(target_arch = "riscv64") {
218            return self.trunc();
219        }
220        libm::trunc(self)
221    }
222    #[inline]
223    fn wasm_copysign(self, sign: f64) -> f64 {
224        #[cfg(feature = "std")]
225        if true {
226            return self.copysign(sign);
227        }
228        libm::copysign(self, sign)
229    }
230    #[inline]
231    fn wasm_floor(self) -> f64 {
232        if self.is_nan() {
233            return f64::NAN;
234        }
235        #[cfg(feature = "std")]
236        if !cfg!(target_arch = "riscv64") {
237            return self.floor();
238        }
239        libm::floor(self)
240    }
241    #[inline]
242    fn wasm_ceil(self) -> f64 {
243        if self.is_nan() {
244            return f64::NAN;
245        }
246        #[cfg(feature = "std")]
247        if !cfg!(target_arch = "riscv64") {
248            return self.ceil();
249        }
250        libm::ceil(self)
251    }
252    #[inline]
253    fn wasm_sqrt(self) -> f64 {
254        #[cfg(feature = "std")]
255        if true {
256            return self.sqrt();
257        }
258        libm::sqrt(self)
259    }
260    #[inline]
261    fn wasm_abs(self) -> f64 {
262        #[cfg(feature = "std")]
263        if true {
264            return self.abs();
265        }
266        libm::fabs(self)
267    }
268    #[inline]
269    fn wasm_nearest(self) -> f64 {
270        if self.is_nan() {
271            return f64::NAN;
272        }
273        #[cfg(feature = "std")]
274        if !cfg!(windows) && !cfg!(target_arch = "riscv64") {
275            return self.round_ties_even();
276        }
277        let round = libm::round(self);
278        if libm::fabs(self - round) != 0.5 {
279            return round;
280        }
281        match round % 2.0 {
282            1.0 => libm::floor(self),
283            -1.0 => libm::ceil(self),
284            _ => round,
285        }
286    }
287    #[inline]
288    fn wasm_maximum(self, other: f64) -> f64 {
289        // FIXME: replace this with `a.maximum(b)` when rust-lang/rust#91079 is
290        // stabilized
291        if self > other {
292            self
293        } else if other > self {
294            other
295        } else if self == other {
296            if self.is_sign_positive() && other.is_sign_negative() {
297                self
298            } else {
299                other
300            }
301        } else {
302            self + other
303        }
304    }
305    #[inline]
306    fn wasm_minimum(self, other: f64) -> f64 {
307        // FIXME: replace this with `self.minimum(other)` when
308        // rust-lang/rust#91079 is stabilized
309        if self < other {
310            self
311        } else if other < self {
312            other
313        } else if self == other {
314            if self.is_sign_negative() && other.is_sign_positive() {
315                self
316            } else {
317                other
318            }
319        } else {
320            self + other
321        }
322    }
323    #[inline]
324    fn wasm_mul_add(self, b: f64, c: f64) -> f64 {
325        // The MinGW implementation of `fma` differs from other platforms, so
326        // favor `libm` there instead.
327        #[cfg(feature = "std")]
328        if !(cfg!(windows) && cfg!(target_env = "gnu")) {
329            return self.mul_add(b, c);
330        }
331        libm::fma(self, b, c)
332    }
333}