// ruvector_temporal_tensor/f16.rs
/// Converts an `f32` to the raw bit pattern of the nearest IEEE 754 binary16 (`f16`) value.
///
/// Rounding is round-to-nearest with ties to even (the IEEE 754 default rounding mode):
///
/// * finite magnitudes that round past the largest `f16` (65504.0) become ±infinity;
/// * magnitudes below the `f16` normal range are encoded as subnormals (`m * 2^-24`);
///   magnitudes below 2^-25 (half the smallest subnormal), including every `f32`
///   subnormal, flush to a zero of the same sign;
/// * ±infinity maps to ±infinity;
/// * NaN stays NaN: the top mantissa bits (including the quiet bit) are kept and the
///   lowest mantissa bit is forced on so the result can never collapse to infinity.
#[inline]
pub fn f32_to_f16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x7F_FFFF;

    if exp == 0xFF {
        if mant == 0 {
            // ±infinity
            return sign | 0x7C00;
        }
        // NaN: keep the high payload bits; `| 1` guarantees a non-zero f16 mantissa.
        return sign | 0x7C00 | (mant >> 13) as u16 | 1;
    }

    // Re-bias the exponent from f32 (bias 127) to f16 (bias 15).
    let exp16 = exp - 127 + 15;

    if exp16 >= 31 {
        // Finite, but beyond the largest f16 exponent: overflow to ±infinity.
        return sign | 0x7C00;
    }

    if exp16 <= 0 {
        // exp16 < -10 means |x| < 2^-25, which rounds to a signed zero.
        if exp16 < -10 {
            return sign;
        }
        // f16 subnormal: value = m * 2^-24. With the implicit leading one restored, the
        // f32 significand must be shifted right by 14 - exp16 (14..=24) bits to obtain m.
        // Rounding up from 0x3FF yields 0x400, which is exactly the smallest f16 normal.
        let significand = mant | 0x80_0000;
        let shift = (14 - exp16) as u32;
        return sign | shr_round_ties_even(significand, shift) as u16;
    }

    // Normal: pack exponent and mantissa, then drop the 13 mantissa bits f16 cannot hold.
    // A rounding carry out of the mantissa bumps the exponent, so 0x7BFF + 1 correctly
    // becomes 0x7C00 (infinity).
    let packed = ((exp16 as u32) << 23) | mant;
    sign | shr_round_ties_even(packed, 13) as u16
}

/// Shifts `value` right by `shift` bits (`1..=31`), rounding to nearest with ties to even.
#[inline]
fn shr_round_ties_even(value: u32, shift: u32) -> u32 {
    let truncated = value >> shift;
    let remainder = value & ((1u32 << shift) - 1);
    let halfway = 1u32 << (shift - 1);
    let round_up = remainder > halfway || (remainder == halfway && (truncated & 1) == 1);
    truncated + u32::from(round_up)
}

/// Expands an IEEE 754 binary16 (`f16`) bit pattern to the equal `f32` value.
///
/// Every `f16` value is exactly representable as `f32`, so this conversion is lossless:
/// zeros keep their sign, subnormals (`m * 2^-24`) are renormalized into `f32` normals,
/// ±infinity maps to ±infinity, and NaN payload bits are shifted into the top of the
/// `f32` mantissa.
#[inline]
pub fn f16_bits_to_f32(h: u16) -> f32 {
    let sign = ((h & 0x8000) as u32) << 16;
    let exp = ((h >> 10) & 0x1F) as u32;
    let mant = (h & 0x03FF) as u32;

    match exp {
        // ±0
        0 if mant == 0 => f32::from_bits(sign),
        0 => {
            // Subnormal: value = mant * 2^-24. With `top` the index of the highest set bit
            // (0..=9), value = 2^(top - 24) * (mant / 2^top), so the biased f32 exponent is
            // 127 - 24 + top and the bits below `top` become the f32 fraction.
            let top = 31 - mant.leading_zeros();
            let exp32 = 127 - 24 + top;
            let mant32 = (mant << (23 - top)) & 0x7F_FFFF;
            f32::from_bits(sign | (exp32 << 23) | mant32)
        }
        // ±infinity (mant == 0) or NaN (payload preserved in the high f32 mantissa bits).
        0x1F => f32::from_bits(sign | 0x7F80_0000 | (mant << 13)),
        // Normal: re-bias the exponent from 15 to 127 and widen the mantissa.
        _ => f32::from_bits(sign | ((exp + 127 - 15) << 23) | (mant << 13)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Relative error of `approx` with respect to the non-zero reference `exact`.
    fn relative_error(exact: f32, approx: f32) -> f32 {
        ((approx - exact) / exact).abs()
    }

    #[test]
    fn test_roundtrip_normal() {
        let samples = [0.0f32, 1.0, -1.0, 0.5, 65504.0, -65504.0, 0.0001];
        for v in samples.iter().copied() {
            let back = f16_bits_to_f32(f32_to_f16_bits(v));
            if v != 0.0 {
                let rel_err = relative_error(v, back);
                assert!(rel_err < 0.01, "v={v}, back={back}, rel_err={rel_err}");
            } else {
                assert_eq!(back, 0.0);
            }
        }
    }

    #[test]
    fn test_infinity() {
        let h = f32_to_f16_bits(f32::INFINITY);
        assert_eq!(h, 0x7C00);
        assert!(f16_bits_to_f32(h).is_infinite());
    }

    #[test]
    fn test_neg_infinity() {
        let h = f32_to_f16_bits(f32::NEG_INFINITY);
        assert_eq!(h, 0xFC00);
        let back = f16_bits_to_f32(h);
        assert!(back.is_infinite() && back < 0.0);
    }

    #[test]
    fn test_nan() {
        let h = f32_to_f16_bits(f32::NAN);
        assert!(f16_bits_to_f32(h).is_nan());
    }

    #[test]
    fn test_zero_signs() {
        let cases: [(f32, u16); 2] = [(0.0, 0x0000), (-0.0, 0x8000)];
        for &(input, expected) in cases.iter() {
            assert_eq!(f32_to_f16_bits(input), expected);
        }
    }

    #[test]
    fn test_scale_range_accuracy() {
        (-4..=4i32)
            .map(|power| 10.0f32.powi(power))
            .for_each(|v| {
                let back = f16_bits_to_f32(f32_to_f16_bits(v));
                let rel_err = relative_error(v, back);
                assert!(rel_err < 0.002, "v={v}, back={back}, rel_err={rel_err}");
            });
    }
}