zenjxl_decoder_simd/
/// IEEE 754 binary16 ("half precision") value stored as its raw 16-bit pattern.
///
/// Bit layout, most significant bit first: 1 sign bit, 5 exponent bits
/// (bias 15) and 10 mantissa bits.
///
/// Equality and hashing are derived from the raw bits, so they are bitwise
/// rather than numeric: `+0.0` and `-0.0` compare unequal, and a NaN compares
/// equal to another NaN with the same bit pattern.
// The lower-case name deliberately mirrors the primitive float types.
#[allow(non_camel_case_types)]
#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct f16(u16);

impl f16 {
    /// Positive zero (all bits clear).
    pub const ZERO: Self = Self(0);

    /// Wraps a raw binary16 bit pattern without any validation or conversion.
    #[inline]
    pub const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns the raw binary16 bit pattern.
    #[inline]
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Widens the value to `f32`.
    ///
    /// Zeros, subnormals, normals and infinities are converted exactly.
    /// NaNs keep their sign and payload (moved to the top mantissa bits) and
    /// always have the `f32` quiet bit set.
    #[inline]
    pub fn to_f32(self) -> f32 {
        let bits = self.0;
        let sign = u32::from(bits >> 15) << 31;
        let exp = u32::from((bits >> 10) & 0x1F);
        let mant = u32::from(bits & 0x3FF);

        let magnitude: u32 = match (exp, mant) {
            // Signed zero.
            (0, 0) => 0,
            // Subnormal: the value is `mant * 2^-24`. Normalise by moving the
            // highest set mantissa bit up to bit 10 (the implicit-one position).
            (0, _) => {
                // `mant` is non-zero and below 2^10, so it has between 22 and
                // 31 leading zeros as a u32; bit 10 corresponds to 21.
                let shift = mant.leading_zeros() - 21;
                let frac = (mant << shift) & 0x3FF; // drop the now-implicit one
                // Biased f32 exponent: 127 - 14 - shift.
                ((113 - shift) << 23) | (frac << 13)
            }
            // Infinity.
            (31, 0) => 0xFF << 23,
            // NaN: keep the payload and force the quiet bit.
            (31, _) => (0xFF << 23) | 0x0040_0000 | (mant << 13),
            // Normal: rebias the exponent (127 - 15 = 112), widen the mantissa.
            _ => ((exp + 112) << 23) | (mant << 13),
        };

        f32::from_bits(sign | magnitude)
    }

    /// Narrows an `f32` to half precision.
    ///
    /// * Zeros, `f32` subnormals and magnitudes below 2^-24 become a zero
    ///   carrying the input's sign.
    /// * Magnitudes in the binary16 subnormal range (2^-24 up to, but not
    ///   including, 2^-14) are truncated toward zero.
    /// * Magnitudes in the binary16 normal range are rounded to nearest,
    ///   ties to even; a result that rounds past the largest finite value
    ///   becomes infinity.
    /// * Larger magnitudes and infinities become infinity with the same sign.
    /// * NaN becomes the quiet NaN `0x7E00` with the input's sign; the payload
    ///   is discarded.
    #[inline]
    pub fn from_f32(f: f32) -> Self {
        const INFINITY_BITS: u16 = 0x1F << 10;
        const QUIET_NAN_BITS: u16 = INFINITY_BITS | 0x0200;

        let bits = f.to_bits();
        let sign = ((bits >> 31) as u16) << 15;
        let exp = ((bits >> 23) & 0xFF) as i32;
        let mant = bits & 0x007F_FFFF;

        let magnitude: u16 = if exp == 0 {
            // Zero or f32 subnormal: far below the smallest half subnormal.
            0
        } else if exp == 255 {
            if mant == 0 {
                INFINITY_BITS
            } else {
                QUIET_NAN_BITS
            }
        } else {
            let unbiased = exp - 127;

            if unbiased < -24 {
                // Below the smallest half subnormal (2^-24).
                0
            } else if unbiased < -14 {
                // Half subnormal: the result is `m * 2^-24`, where `m` is the
                // 24-bit significand (implicit one restored) scaled by
                // 2^(unbiased + 1). With `shift = -14 - unbiased` (1..=10)
                // that is a right shift by `shift + 13` (14..=23).
                let shift = (-14 - unbiased) as u32;
                ((mant | 0x0080_0000) >> (shift + 13)) as u16
            } else if unbiased > 15 {
                // Exponent too large for binary16.
                INFINITY_BITS
            } else {
                let h_exp = (unbiased + 15) as u16;
                let truncated = (mant >> 13) as u16;

                // Round to nearest, ties to even, using the 13 discarded bits.
                let round_bit = ((mant >> 12) & 1) == 1;
                let sticky = (mant & 0x0FFF) != 0;
                let round_up = round_bit && (sticky || (truncated & 1) == 1);
                let h_mant = if round_up { truncated + 1 } else { truncated };

                if h_mant > 0x3FF {
                    // Mantissa carried out: bump the exponent, mantissa is 0.
                    if h_exp >= 30 {
                        INFINITY_BITS
                    } else {
                        (h_exp + 1) << 10
                    }
                } else {
                    (h_exp << 10) | h_mant
                }
            }
        };

        Self(sign | magnitude)
    }

    /// Narrows an `f64` to half precision.
    ///
    /// The value is first cast to `f32` and then passed to [`f16::from_f32`],
    /// so it is rounded twice.
    #[inline]
    pub fn from_f64(f: f64) -> Self {
        Self::from_f32(f as f32)
    }

    /// Widens the value to `f64` by way of [`f16::to_f32`].
    #[inline]
    pub fn to_f64(self) -> f64 {
        f64::from(self.to_f32())
    }

    /// Returns `true` unless the value is an infinity or a NaN, i.e. unless
    /// the exponent field is all ones.
    #[inline]
    pub fn is_finite(self) -> bool {
        ((self.0 >> 10) & 0x1F) != 31
    }

    /// Returns the raw bit pattern as two bytes in little-endian order.
    #[inline]
    pub const fn to_le_bytes(self) -> [u8; 2] {
        self.0.to_le_bytes()
    }

    /// Returns the raw bit pattern as two bytes in big-endian order.
    #[inline]
    pub const fn to_be_bytes(self) -> [u8; 2] {
        self.0.to_be_bytes()
    }
}
176
177impl From<f16> for f32 {
178 #[inline]
179 fn from(f: f16) -> f32 {
180 f.to_f32()
181 }
182}
183
184impl From<f16> for f64 {
185 #[inline]
186 fn from(f: f16) -> f64 {
187 f.to_f64()
188 }
189}
190
191impl core::fmt::Debug for f16 {
192 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
193 write!(f, "{}", self.to_f32())
194 }
195}
196
197impl core::fmt::Display for f16 {
198 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
199 write!(f, "{}", self.to_f32())
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// `ZERO` is the all-clear bit pattern, widens to `0.0` and is finite.
    #[test]
    fn test_zero() {
        assert_eq!(f16::ZERO.to_bits(), 0);
        assert_eq!(f16::ZERO.to_f32(), 0.0);
        assert!(f16::ZERO.is_finite());
    }

    /// Bit pattern 0x3C00 widens to 1.0.
    #[test]
    fn test_one() {
        let half = f16::from_bits(0x3C00);
        let distance = (half.to_f32() - 1.0).abs();
        assert!(distance < 1e-6);
        assert!(half.is_finite());
    }

    /// Bit pattern 0xBC00 widens to -1.0.
    #[test]
    fn test_negative_one() {
        let widened = f16::from_bits(0xBC00).to_f32();
        assert!((widened - (-1.0)).abs() < 1e-6);
    }

    /// Both infinities widen to an infinite `f32` and are not finite.
    #[test]
    fn test_infinity() {
        // Positive infinity first, then negative infinity.
        for pattern in [0x7C00u16, 0xFC00].iter().copied() {
            let half = f16::from_bits(pattern);
            assert!(half.to_f32().is_infinite());
            assert!(!half.is_finite());
        }
    }

    /// An all-ones exponent with a non-zero mantissa is a NaN.
    #[test]
    fn test_nan() {
        let half = f16::from_bits(0x7C01);
        assert!(half.to_f32().is_nan());
        assert!(!half.is_finite());
    }

    /// The smallest subnormal widens to a tiny positive value.
    #[test]
    fn test_denormal() {
        let half = f16::from_bits(0x0001);
        let widened = half.to_f32();
        assert!(widened > 0.0);
        assert!(widened < 1e-6);
        assert!(half.is_finite());
    }

    /// Normal-range values survive an f32 -> f16 -> f32 trip within 0.2 %.
    #[test]
    fn test_roundtrip_normal() {
        const SAMPLES: [f32; 8] = [0.5, 1.0, 2.0, 100.0, 0.001, -0.5, -1.0, -100.0];
        for original in SAMPLES.iter().copied() {
            let back = f16::from_f32(original).to_f32();
            let rel_err = ((original - back) / original).abs();
            assert!(
                rel_err < 0.002,
                "Roundtrip failed for {}: got {}, rel_err {}",
                original,
                back,
                rel_err
            );
        }
    }

    /// Zero, the infinities and NaN keep their class across a round trip.
    #[test]
    fn test_roundtrip_special() {
        let roundtrip = |value: f32| f16::from_f32(value).to_f32();

        assert_eq!(roundtrip(0.0), 0.0);

        assert!(roundtrip(f32::INFINITY).is_infinite());
        assert!(roundtrip(f32::NEG_INFINITY).is_infinite());

        assert!(roundtrip(f32::NAN).is_nan());
    }

    /// A value beyond the half range narrows to infinity.
    #[test]
    fn test_overflow_to_infinity() {
        let widened = f16::from_f32(100000.0).to_f32();
        assert!(widened.is_infinite());
    }

    /// A value far below the smallest subnormal narrows to zero.
    #[test]
    fn test_underflow_to_zero() {
        let widened = f16::from_f32(1e-10).to_f32();
        assert_eq!(widened, 0.0);
    }

    /// Byte accessors expose the raw bits in the requested endianness.
    #[test]
    fn test_bytes() {
        let half = f16::from_bits(0x1234);
        let (little, big) = (half.to_le_bytes(), half.to_be_bytes());
        assert_eq!(little, [0x34, 0x12]);
        assert_eq!(big, [0x12, 0x34]);
    }
}