softfloat_wrapper/
bf16.rs

1use crate::{Float, RoundingMode, F128, F16, F32, F64};
2use softfloat_sys::{float16_t, float32_t};
3use std::borrow::Borrow;
4
5/// bfloat16
6#[derive(Copy, Clone, Debug)]
7pub struct BF16(float16_t);
8
9impl BF16 {
10    /// Converts primitive `f32` to `BF16`
11    pub fn from_f32(v: f32) -> Self {
12        F32::from_bits(v.to_bits()).to_bf16(RoundingMode::TiesToEven)
13    }
14
15    /// Converts primitive `f64` to `BF16`
16    pub fn from_f64(v: f64) -> Self {
17        F64::from_bits(v.to_bits()).to_bf16(RoundingMode::TiesToEven)
18    }
19}
20
21fn to_f32(x: float16_t) -> float32_t {
22    float32_t {
23        v: (x.v as u32) << 16,
24    }
25}
26
27fn from_f32(x: float32_t) -> float16_t {
28    float16_t {
29        v: (x.v >> 16) as u16,
30    }
31}
32
33impl Float for BF16 {
34    type Payload = u16;
35
36    const EXPONENT_BIT: Self::Payload = 0xff;
37    const FRACTION_BIT: Self::Payload = 0x7f;
38    const SIGN_POS: usize = 15;
39    const EXPONENT_POS: usize = 7;
40
41    #[inline]
42    fn set_payload(&mut self, x: Self::Payload) {
43        self.0.v = x;
44    }
45
46    #[inline]
47    fn from_bits(v: Self::Payload) -> Self {
48        Self(float16_t { v })
49    }
50
51    #[inline]
52    fn to_bits(&self) -> Self::Payload {
53        self.0.v
54    }
55
56    #[inline]
57    fn bits(&self) -> Self::Payload {
58        self.to_bits()
59    }
60
61    fn add<T: Borrow<Self>>(&self, x: T, rnd: RoundingMode) -> Self {
62        rnd.set();
63        let ret = unsafe { softfloat_sys::f32_add(to_f32(self.0), to_f32(x.borrow().0)) };
64        Self(from_f32(ret))
65    }
66
67    fn sub<T: Borrow<Self>>(&self, x: T, rnd: RoundingMode) -> Self {
68        rnd.set();
69        let ret = unsafe { softfloat_sys::f32_sub(to_f32(self.0), to_f32(x.borrow().0)) };
70        Self(from_f32(ret))
71    }
72
73    fn mul<T: Borrow<Self>>(&self, x: T, rnd: RoundingMode) -> Self {
74        rnd.set();
75        let ret = unsafe { softfloat_sys::f32_mul(to_f32(self.0), to_f32(x.borrow().0)) };
76        Self(from_f32(ret))
77    }
78
79    fn fused_mul_add<T: Borrow<Self>>(&self, x: T, y: T, rnd: RoundingMode) -> Self {
80        rnd.set();
81        let ret = unsafe {
82            softfloat_sys::f32_mulAdd(to_f32(self.0), to_f32(x.borrow().0), to_f32(y.borrow().0))
83        };
84        Self(from_f32(ret))
85    }
86
87    fn div<T: Borrow<Self>>(&self, x: T, rnd: RoundingMode) -> Self {
88        rnd.set();
89        let ret = unsafe { softfloat_sys::f32_div(to_f32(self.0), to_f32(x.borrow().0)) };
90        Self(from_f32(ret))
91    }
92
93    fn rem<T: Borrow<Self>>(&self, x: T, rnd: RoundingMode) -> Self {
94        rnd.set();
95        let ret = unsafe { softfloat_sys::f32_rem(to_f32(self.0), to_f32(x.borrow().0)) };
96        Self(from_f32(ret))
97    }
98
99    fn sqrt(&self, rnd: RoundingMode) -> Self {
100        rnd.set();
101        let ret = unsafe { softfloat_sys::f32_sqrt(to_f32(self.0)) };
102        Self(from_f32(ret))
103    }
104
105    fn eq<T: Borrow<Self>>(&self, x: T) -> bool {
106        unsafe { softfloat_sys::f32_eq(to_f32(self.0), to_f32(x.borrow().0)) }
107    }
108
109    fn lt<T: Borrow<Self>>(&self, x: T) -> bool {
110        unsafe { softfloat_sys::f32_lt(to_f32(self.0), to_f32(x.borrow().0)) }
111    }
112
113    fn le<T: Borrow<Self>>(&self, x: T) -> bool {
114        unsafe { softfloat_sys::f32_le(to_f32(self.0), to_f32(x.borrow().0)) }
115    }
116
117    fn lt_quiet<T: Borrow<Self>>(&self, x: T) -> bool {
118        unsafe { softfloat_sys::f32_lt_quiet(to_f32(self.0), to_f32(x.borrow().0)) }
119    }
120
121    fn le_quiet<T: Borrow<Self>>(&self, x: T) -> bool {
122        unsafe { softfloat_sys::f32_le_quiet(to_f32(self.0), to_f32(x.borrow().0)) }
123    }
124
125    fn eq_signaling<T: Borrow<Self>>(&self, x: T) -> bool {
126        unsafe { softfloat_sys::f32_eq_signaling(to_f32(self.0), to_f32(x.borrow().0)) }
127    }
128
129    fn is_signaling_nan(&self) -> bool {
130        unsafe { softfloat_sys::f32_isSignalingNaN(to_f32(self.0)) }
131    }
132
133    fn from_u32(x: u32, rnd: RoundingMode) -> Self {
134        rnd.set();
135        let ret = unsafe { softfloat_sys::ui32_to_f32(x) };
136        Self(from_f32(ret))
137    }
138
139    fn from_u64(x: u64, rnd: RoundingMode) -> Self {
140        rnd.set();
141        let ret = unsafe { softfloat_sys::ui64_to_f32(x) };
142        Self(from_f32(ret))
143    }
144
145    fn from_i32(x: i32, rnd: RoundingMode) -> Self {
146        rnd.set();
147        let ret = unsafe { softfloat_sys::i32_to_f32(x) };
148        Self(from_f32(ret))
149    }
150
151    fn from_i64(x: i64, rnd: RoundingMode) -> Self {
152        rnd.set();
153        let ret = unsafe { softfloat_sys::i64_to_f32(x) };
154        Self(from_f32(ret))
155    }
156
157    fn to_u32(&self, rnd: RoundingMode, exact: bool) -> u32 {
158        let ret = unsafe { softfloat_sys::f32_to_ui32(to_f32(self.0), rnd.to_softfloat(), exact) };
159        ret as u32
160    }
161
162    fn to_u64(&self, rnd: RoundingMode, exact: bool) -> u64 {
163        let ret = unsafe { softfloat_sys::f32_to_ui64(to_f32(self.0), rnd.to_softfloat(), exact) };
164        ret
165    }
166
167    fn to_i32(&self, rnd: RoundingMode, exact: bool) -> i32 {
168        let ret = unsafe { softfloat_sys::f32_to_i32(to_f32(self.0), rnd.to_softfloat(), exact) };
169        ret as i32
170    }
171
172    fn to_i64(&self, rnd: RoundingMode, exact: bool) -> i64 {
173        let ret = unsafe { softfloat_sys::f32_to_i64(to_f32(self.0), rnd.to_softfloat(), exact) };
174        ret
175    }
176
177    fn to_f16(&self, rnd: RoundingMode) -> F16 {
178        rnd.set();
179        let ret = unsafe { softfloat_sys::f32_to_f16(to_f32(self.0)) };
180        F16::from_bits(ret.v)
181    }
182
183    fn to_bf16(&self, _rnd: RoundingMode) -> BF16 {
184        BF16::from_bits(self.to_bits())
185    }
186
187    fn to_f32(&self, _rnd: RoundingMode) -> F32 {
188        F32::from_bits(to_f32(self.0).v)
189    }
190
191    fn to_f64(&self, rnd: RoundingMode) -> F64 {
192        rnd.set();
193        let ret = unsafe { softfloat_sys::f32_to_f64(to_f32(self.0)) };
194        F64::from_bits(ret.v)
195    }
196
197    fn to_f128(&self, rnd: RoundingMode) -> F128 {
198        rnd.set();
199        let ret = unsafe { softfloat_sys::f32_to_f128(to_f32(self.0)) };
200        let mut v = 0u128;
201        v |= ret.v[0] as u128;
202        v |= (ret.v[1] as u128) << 64;
203        F128::from_bits(v)
204    }
205
206    fn round_to_integral(&self, rnd: RoundingMode) -> Self {
207        let ret =
208            unsafe { softfloat_sys::f32_roundToInt(to_f32(self.0), rnd.to_softfloat(), false) };
209        Self(from_f32(ret))
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExceptionFlags;
    use std::cmp::Ordering;

    /// Builds the reference `simple_soft_float` f32 whose upper 16 bits are `bits`.
    fn reference(bits: u16) -> simple_soft_float::F32 {
        simple_soft_float::F32::from_bits(u32::from(bits) << 16)
    }

    /// Upper 16 bits of a reference f32 result, i.e. the expected `BF16` pattern.
    fn upper_half(value: &simple_soft_float::F32) -> u16 {
        (*value.bits() >> 16) as u16
    }

    /// Installs a default `ExceptionFlags`, runs `op`, reads the flags back and
    /// returns `(result of op, is_invalid())`.
    fn run_checking_invalid(op: impl FnOnce() -> bool) -> (bool, bool) {
        let mut flags = ExceptionFlags::default();
        flags.set();
        let outcome = op();
        flags.get();
        (outcome, flags.is_invalid())
    }

    #[test]
    fn bf16_add() {
        let (lhs, rhs): (u16, u16) = (0x1234, 0x7654);
        let got = BF16::from_bits(lhs).add(BF16::from_bits(rhs), RoundingMode::TiesToEven);
        let want = reference(lhs).add(
            &reference(rhs),
            Some(simple_soft_float::RoundingMode::TiesToEven),
            None,
        );
        assert_eq!(got.to_bits(), upper_half(&want));
    }

    #[test]
    fn bf16_sub() {
        let (lhs, rhs): (u16, u16) = (0x1234, 0x7654);
        let got = BF16::from_bits(lhs).sub(BF16::from_bits(rhs), RoundingMode::TiesToEven);
        let want = reference(lhs).sub(
            &reference(rhs),
            Some(simple_soft_float::RoundingMode::TiesToEven),
            None,
        );
        assert_eq!(got.to_bits(), upper_half(&want));
    }

    #[test]
    fn bf16_mul() {
        let (lhs, rhs): (u16, u16) = (0x1234, 0x7654);
        let got = BF16::from_bits(lhs).mul(BF16::from_bits(rhs), RoundingMode::TiesToEven);
        let want = reference(lhs).mul(
            &reference(rhs),
            Some(simple_soft_float::RoundingMode::TiesToEven),
            None,
        );
        assert_eq!(got.to_bits(), upper_half(&want));
    }

    #[test]
    fn bf16_fused_mul_add() {
        let (first, second, third): (u16, u16, u16) = (0x1234, 0x1234, 0x1234);
        let got = BF16::from_bits(first).fused_mul_add(
            BF16::from_bits(second),
            BF16::from_bits(third),
            RoundingMode::TiesToEven,
        );
        let want = reference(first).fused_mul_add(
            &reference(second),
            &reference(third),
            Some(simple_soft_float::RoundingMode::TiesToEven),
            None,
        );
        assert_eq!(got.to_bits(), upper_half(&want));
    }

    #[test]
    fn bf16_div() {
        let (lhs, rhs): (u16, u16) = (0x7654, 0x1234);
        let got = BF16::from_bits(lhs).div(BF16::from_bits(rhs), RoundingMode::TiesToEven);
        let want = reference(lhs).div(
            &reference(rhs),
            Some(simple_soft_float::RoundingMode::TiesToEven),
            None,
        );
        assert_eq!(got.to_bits(), upper_half(&want));
    }

    #[test]
    fn bf16_rem() {
        let (lhs, rhs): (u16, u16) = (0x7654, 0x1234);
        let got = BF16::from_bits(lhs).rem(BF16::from_bits(rhs), RoundingMode::TiesToEven);
        let want = reference(lhs).ieee754_remainder(
            &reference(rhs),
            Some(simple_soft_float::RoundingMode::TiesToEven),
            None,
        );
        assert_eq!(got.to_bits(), upper_half(&want));
    }

    #[test]
    fn bf16_sqrt() {
        let operand: u16 = 0x7654;
        let got = BF16::from_bits(operand).sqrt(RoundingMode::TiesToEven);
        let want = reference(operand).sqrt(Some(simple_soft_float::RoundingMode::TiesToEven), None);
        assert_eq!(got.to_bits(), upper_half(&want));
    }

    #[test]
    fn bf16_compare() {
        // (left bits, right bits, expected ordering of left relative to right)
        let cases = [
            (0x7654, 0x1234, Ordering::Greater),
            (0x1234, 0x7654, Ordering::Less),
            (0x1234, 0x1234, Ordering::Equal),
        ];
        for (lhs, rhs, expected) in cases.iter() {
            let order = BF16::from_bits(*lhs).compare(BF16::from_bits(*rhs));
            assert_eq!(order, Some(*expected));
        }
    }

    #[test]
    fn bf16_signaling() {
        let signaling = BF16::from_bits(0x7f81);
        let quiet = BF16::from_bits(0x7fc1);
        assert!(signaling.is_signaling_nan());
        assert!(!quiet.is_signaling_nan());

        // Each tuple is (comparison result, invalid flag raised).
        assert_eq!(
            run_checking_invalid(|| signaling.eq(signaling)),
            (false, true)
        );
        assert_eq!(run_checking_invalid(|| quiet.eq(quiet)), (false, false));
        assert_eq!(
            run_checking_invalid(|| signaling.eq_signaling(signaling)),
            (false, true)
        );
        assert_eq!(
            run_checking_invalid(|| quiet.eq_signaling(quiet)),
            (false, true)
        );
    }

    #[test]
    fn from_f32() {
        let converted = BF16::from_f32(0.1);
        assert_eq!(converted.to_bits(), 0x3dcc);
    }

    #[test]
    fn from_f64() {
        let converted = BF16::from_f64(0.1);
        assert_eq!(converted.to_bits(), 0x3dcc);
    }
}