1use crate::{Float, RoundingMode, F128, F16, F32, F64};
2use softfloat_sys::{float16_t, float32_t};
3use std::borrow::Borrow;
4
5#[derive(Copy, Clone, Debug)]
7pub struct BF16(float16_t);
8
9impl BF16 {
10 pub fn from_f32(v: f32) -> Self {
12 F32::from_bits(v.to_bits()).to_bf16(RoundingMode::TiesToEven)
13 }
14
15 pub fn from_f64(v: f64) -> Self {
17 F64::from_bits(v.to_bits()).to_bf16(RoundingMode::TiesToEven)
18 }
19}
20
21fn to_f32(x: float16_t) -> float32_t {
22 float32_t {
23 v: (x.v as u32) << 16,
24 }
25}
26
27fn from_f32(x: float32_t) -> float16_t {
28 float16_t {
29 v: (x.v >> 16) as u16,
30 }
31}
32
33impl Float for BF16 {
34 type Payload = u16;
35
36 const EXPONENT_BIT: Self::Payload = 0xff;
37 const FRACTION_BIT: Self::Payload = 0x7f;
38 const SIGN_POS: usize = 15;
39 const EXPONENT_POS: usize = 7;
40
41 #[inline]
42 fn set_payload(&mut self, x: Self::Payload) {
43 self.0.v = x;
44 }
45
46 #[inline]
47 fn from_bits(v: Self::Payload) -> Self {
48 Self(float16_t { v })
49 }
50
51 #[inline]
52 fn to_bits(&self) -> Self::Payload {
53 self.0.v
54 }
55
56 #[inline]
57 fn bits(&self) -> Self::Payload {
58 self.to_bits()
59 }
60
61 fn add<T: Borrow<Self>>(&self, x: T, rnd: RoundingMode) -> Self {
62 rnd.set();
63 let ret = unsafe { softfloat_sys::f32_add(to_f32(self.0), to_f32(x.borrow().0)) };
64 Self(from_f32(ret))
65 }
66
67 fn sub<T: Borrow<Self>>(&self, x: T, rnd: RoundingMode) -> Self {
68 rnd.set();
69 let ret = unsafe { softfloat_sys::f32_sub(to_f32(self.0), to_f32(x.borrow().0)) };
70 Self(from_f32(ret))
71 }
72
73 fn mul<T: Borrow<Self>>(&self, x: T, rnd: RoundingMode) -> Self {
74 rnd.set();
75 let ret = unsafe { softfloat_sys::f32_mul(to_f32(self.0), to_f32(x.borrow().0)) };
76 Self(from_f32(ret))
77 }
78
79 fn fused_mul_add<T: Borrow<Self>>(&self, x: T, y: T, rnd: RoundingMode) -> Self {
80 rnd.set();
81 let ret = unsafe {
82 softfloat_sys::f32_mulAdd(to_f32(self.0), to_f32(x.borrow().0), to_f32(y.borrow().0))
83 };
84 Self(from_f32(ret))
85 }
86
87 fn div<T: Borrow<Self>>(&self, x: T, rnd: RoundingMode) -> Self {
88 rnd.set();
89 let ret = unsafe { softfloat_sys::f32_div(to_f32(self.0), to_f32(x.borrow().0)) };
90 Self(from_f32(ret))
91 }
92
93 fn rem<T: Borrow<Self>>(&self, x: T, rnd: RoundingMode) -> Self {
94 rnd.set();
95 let ret = unsafe { softfloat_sys::f32_rem(to_f32(self.0), to_f32(x.borrow().0)) };
96 Self(from_f32(ret))
97 }
98
99 fn sqrt(&self, rnd: RoundingMode) -> Self {
100 rnd.set();
101 let ret = unsafe { softfloat_sys::f32_sqrt(to_f32(self.0)) };
102 Self(from_f32(ret))
103 }
104
105 fn eq<T: Borrow<Self>>(&self, x: T) -> bool {
106 unsafe { softfloat_sys::f32_eq(to_f32(self.0), to_f32(x.borrow().0)) }
107 }
108
109 fn lt<T: Borrow<Self>>(&self, x: T) -> bool {
110 unsafe { softfloat_sys::f32_lt(to_f32(self.0), to_f32(x.borrow().0)) }
111 }
112
113 fn le<T: Borrow<Self>>(&self, x: T) -> bool {
114 unsafe { softfloat_sys::f32_le(to_f32(self.0), to_f32(x.borrow().0)) }
115 }
116
117 fn lt_quiet<T: Borrow<Self>>(&self, x: T) -> bool {
118 unsafe { softfloat_sys::f32_lt_quiet(to_f32(self.0), to_f32(x.borrow().0)) }
119 }
120
121 fn le_quiet<T: Borrow<Self>>(&self, x: T) -> bool {
122 unsafe { softfloat_sys::f32_le_quiet(to_f32(self.0), to_f32(x.borrow().0)) }
123 }
124
125 fn eq_signaling<T: Borrow<Self>>(&self, x: T) -> bool {
126 unsafe { softfloat_sys::f32_eq_signaling(to_f32(self.0), to_f32(x.borrow().0)) }
127 }
128
129 fn is_signaling_nan(&self) -> bool {
130 unsafe { softfloat_sys::f32_isSignalingNaN(to_f32(self.0)) }
131 }
132
133 fn from_u32(x: u32, rnd: RoundingMode) -> Self {
134 rnd.set();
135 let ret = unsafe { softfloat_sys::ui32_to_f32(x) };
136 Self(from_f32(ret))
137 }
138
139 fn from_u64(x: u64, rnd: RoundingMode) -> Self {
140 rnd.set();
141 let ret = unsafe { softfloat_sys::ui64_to_f32(x) };
142 Self(from_f32(ret))
143 }
144
145 fn from_i32(x: i32, rnd: RoundingMode) -> Self {
146 rnd.set();
147 let ret = unsafe { softfloat_sys::i32_to_f32(x) };
148 Self(from_f32(ret))
149 }
150
151 fn from_i64(x: i64, rnd: RoundingMode) -> Self {
152 rnd.set();
153 let ret = unsafe { softfloat_sys::i64_to_f32(x) };
154 Self(from_f32(ret))
155 }
156
157 fn to_u32(&self, rnd: RoundingMode, exact: bool) -> u32 {
158 let ret = unsafe { softfloat_sys::f32_to_ui32(to_f32(self.0), rnd.to_softfloat(), exact) };
159 ret as u32
160 }
161
162 fn to_u64(&self, rnd: RoundingMode, exact: bool) -> u64 {
163 let ret = unsafe { softfloat_sys::f32_to_ui64(to_f32(self.0), rnd.to_softfloat(), exact) };
164 ret
165 }
166
167 fn to_i32(&self, rnd: RoundingMode, exact: bool) -> i32 {
168 let ret = unsafe { softfloat_sys::f32_to_i32(to_f32(self.0), rnd.to_softfloat(), exact) };
169 ret as i32
170 }
171
172 fn to_i64(&self, rnd: RoundingMode, exact: bool) -> i64 {
173 let ret = unsafe { softfloat_sys::f32_to_i64(to_f32(self.0), rnd.to_softfloat(), exact) };
174 ret
175 }
176
177 fn to_f16(&self, rnd: RoundingMode) -> F16 {
178 rnd.set();
179 let ret = unsafe { softfloat_sys::f32_to_f16(to_f32(self.0)) };
180 F16::from_bits(ret.v)
181 }
182
183 fn to_bf16(&self, _rnd: RoundingMode) -> BF16 {
184 BF16::from_bits(self.to_bits())
185 }
186
187 fn to_f32(&self, _rnd: RoundingMode) -> F32 {
188 F32::from_bits(to_f32(self.0).v)
189 }
190
191 fn to_f64(&self, rnd: RoundingMode) -> F64 {
192 rnd.set();
193 let ret = unsafe { softfloat_sys::f32_to_f64(to_f32(self.0)) };
194 F64::from_bits(ret.v)
195 }
196
197 fn to_f128(&self, rnd: RoundingMode) -> F128 {
198 rnd.set();
199 let ret = unsafe { softfloat_sys::f32_to_f128(to_f32(self.0)) };
200 let mut v = 0u128;
201 v |= ret.v[0] as u128;
202 v |= (ret.v[1] as u128) << 64;
203 F128::from_bits(v)
204 }
205
206 fn round_to_integral(&self, rnd: RoundingMode) -> Self {
207 let ret =
208 unsafe { softfloat_sys::f32_roundToInt(to_f32(self.0), rnd.to_softfloat(), false) };
209 Self(from_f32(ret))
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExceptionFlags;
    use std::cmp::Ordering;

    /// Reference model: the same BF16 bits widened into a `simple_soft_float` f32.
    fn reference(bits: u16) -> simple_soft_float::F32 {
        simple_soft_float::F32::from_bits(u32::from(bits) << 16)
    }

    /// High 16 bits of a reference result (truncation, as `BF16` narrows).
    fn truncate(x: &simple_soft_float::F32) -> u16 {
        (*x.bits() >> 16) as u16
    }

    /// Rounding mode handed to the reference model in every arithmetic test.
    fn ties_to_even() -> Option<simple_soft_float::RoundingMode> {
        Some(simple_soft_float::RoundingMode::TiesToEven)
    }

    /// Installs default `ExceptionFlags` (presumably clearing them), then runs `op`.
    ///
    /// Returns `op`'s result together with whether the `invalid` flag was
    /// reported afterwards.
    fn with_invalid_flag(op: impl FnOnce() -> bool) -> (bool, bool) {
        let mut flags = ExceptionFlags::default();
        flags.set();
        let result = op();
        flags.get();
        (result, flags.is_invalid())
    }

    #[test]
    fn bf16_add() {
        let (a, b) = (0x1234, 0x7654);
        let d0 = BF16::from_bits(a).add(BF16::from_bits(b), RoundingMode::TiesToEven);
        let d1 = reference(a).add(&reference(b), ties_to_even(), None);
        assert_eq!(d0.to_bits(), truncate(&d1));
    }

    #[test]
    fn bf16_sub() {
        let (a, b) = (0x1234, 0x7654);
        let d0 = BF16::from_bits(a).sub(BF16::from_bits(b), RoundingMode::TiesToEven);
        let d1 = reference(a).sub(&reference(b), ties_to_even(), None);
        assert_eq!(d0.to_bits(), truncate(&d1));
    }

    #[test]
    fn bf16_mul() {
        let (a, b) = (0x1234, 0x7654);
        let d0 = BF16::from_bits(a).mul(BF16::from_bits(b), RoundingMode::TiesToEven);
        let d1 = reference(a).mul(&reference(b), ties_to_even(), None);
        assert_eq!(d0.to_bits(), truncate(&d1));
    }

    #[test]
    fn bf16_fused_mul_add() {
        let (a, b, c) = (0x1234, 0x1234, 0x1234);
        let d0 = BF16::from_bits(a).fused_mul_add(
            BF16::from_bits(b),
            BF16::from_bits(c),
            RoundingMode::TiesToEven,
        );
        let d1 = reference(a).fused_mul_add(&reference(b), &reference(c), ties_to_even(), None);
        assert_eq!(d0.to_bits(), truncate(&d1));
    }

    #[test]
    fn bf16_div() {
        let (a, b) = (0x7654, 0x1234);
        let d0 = BF16::from_bits(a).div(BF16::from_bits(b), RoundingMode::TiesToEven);
        let d1 = reference(a).div(&reference(b), ties_to_even(), None);
        assert_eq!(d0.to_bits(), truncate(&d1));
    }

    #[test]
    fn bf16_rem() {
        let (a, b) = (0x7654, 0x1234);
        let d0 = BF16::from_bits(a).rem(BF16::from_bits(b), RoundingMode::TiesToEven);
        let d1 = reference(a).ieee754_remainder(&reference(b), ties_to_even(), None);
        assert_eq!(d0.to_bits(), truncate(&d1));
    }

    #[test]
    fn bf16_sqrt() {
        let a = 0x7654;
        let d0 = BF16::from_bits(a).sqrt(RoundingMode::TiesToEven);
        let d1 = reference(a).sqrt(ties_to_even(), None);
        assert_eq!(d0.to_bits(), truncate(&d1));
    }

    #[test]
    fn bf16_compare() {
        let cases: [(u16, u16, Ordering); 3] = [
            (0x7654, 0x1234, Ordering::Greater),
            (0x1234, 0x7654, Ordering::Less),
            (0x1234, 0x1234, Ordering::Equal),
        ];
        for &(a, b, expected) in cases.iter() {
            let d = BF16::from_bits(a).compare(BF16::from_bits(b));
            assert_eq!(d, Some(expected), "compare({:#06x}, {:#06x})", a, b);
        }
    }

    #[test]
    fn bf16_signaling() {
        let snan = BF16::from_bits(0x7f81);
        let qnan = BF16::from_bits(0x7fc1);
        assert!(snan.is_signaling_nan());
        assert!(!qnan.is_signaling_nan());

        // Quiet equality: NaN != NaN; only a signaling NaN raises `invalid`.
        assert_eq!(with_invalid_flag(|| snan.eq(snan)), (false, true));
        assert_eq!(with_invalid_flag(|| qnan.eq(qnan)), (false, false));

        // Signaling equality raises `invalid` for any NaN operand.
        assert_eq!(with_invalid_flag(|| snan.eq_signaling(snan)), (false, true));
        assert_eq!(with_invalid_flag(|| qnan.eq_signaling(qnan)), (false, true));
    }

    #[test]
    fn bf16_from_f32() {
        assert_eq!(BF16::from_f32(0.1).to_bits(), 0x3dcc);
    }

    #[test]
    fn bf16_from_f64() {
        assert_eq!(BF16::from_f64(0.1).to_bits(), 0x3dcc);
    }
}