Skip to main content

zenjxl_decoder_simd/
float16.rs

1// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6//! IEEE 754 half-precision (binary16) floating-point type.
7//!
8//! This is a minimal implementation providing only the operations needed for JPEG XL decoding,
9//! avoiding external dependencies like `half` which pulls in `zerocopy`.
10
/// Half-precision floating-point number (IEEE 754 binary16), stored as its raw bits.
///
/// Bit layout, most significant first: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
/// `#[repr(transparent)]` gives it the same size and alignment as `u16`.
///
/// Note that `PartialEq`, `Eq` and `Hash` are derived on the raw bits, not IEEE semantics:
/// `+0` and `-0` compare unequal, and two NaNs with the same bit pattern compare equal.
// The lowercase name intentionally mirrors the primitive float types `f32` / `f64`.
#[allow(non_camel_case_types)]
#[repr(transparent)]
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)]
pub struct f16(u16);
18
19impl f16 {
20    /// Positive zero.
21    pub const ZERO: Self = Self(0);
22
23    /// Creates an f16 from its raw bit representation.
24    #[inline]
25    pub const fn from_bits(bits: u16) -> Self {
26        Self(bits)
27    }
28
29    /// Returns the raw bit representation.
30    #[inline]
31    pub const fn to_bits(self) -> u16 {
32        self.0
33    }
34
35    /// Converts to f32.
36    #[inline]
37    pub fn to_f32(self) -> f32 {
38        let bits = self.0;
39        let sign = ((bits >> 15) & 1) as u32;
40        let exp = ((bits >> 10) & 0x1F) as u32;
41        let mant = (bits & 0x3FF) as u32;
42
43        let f32_bits = if exp == 0 {
44            if mant == 0 {
45                // Zero (signed)
46                sign << 31
47            } else {
48                // Denormal f16 -> normalized f32
49                // Find the leading 1 bit in mantissa
50                let mut m = mant;
51                let mut e = 0u32;
52                while (m & 0x400) == 0 {
53                    m <<= 1;
54                    e += 1;
55                }
56                m &= 0x3FF; // Remove the implicit leading 1
57                // f16 denormal exponent is -14 (not -15), adjust by shift count
58                let new_exp = 127 - 14 - e;
59                (sign << 31) | (new_exp << 23) | (m << 13)
60            }
61        } else if exp == 31 {
62            // Infinity or NaN
63            if mant == 0 {
64                // Infinity
65                (sign << 31) | (0xFF << 23)
66            } else {
67                // NaN - preserve some payload bits, ensure quiet NaN
68                (sign << 31) | (0xFF << 23) | (mant << 13) | 0x0040_0000
69            }
70        } else {
71            // Normal number
72            // Rebias: f16 uses bias 15, f32 uses bias 127
73            // new_exp = exp - 15 + 127 = exp + 112
74            let new_exp = exp + 112;
75            (sign << 31) | (new_exp << 23) | (mant << 13)
76        };
77
78        f32::from_bits(f32_bits)
79    }
80
81    /// Creates an f16 from an f32.
82    #[inline]
83    pub fn from_f32(f: f32) -> Self {
84        let bits = f.to_bits();
85        let sign = ((bits >> 31) & 1) as u16;
86        let exp = ((bits >> 23) & 0xFF) as i32;
87        let mant = bits & 0x007F_FFFF;
88
89        let h_bits = if exp == 0 {
90            // Zero or f32 denormal -> f16 zero (too small)
91            sign << 15
92        } else if exp == 255 {
93            // Infinity or NaN
94            if mant == 0 {
95                (sign << 15) | (0x1F << 10) // Infinity
96            } else {
97                (sign << 15) | (0x1F << 10) | 0x0200 // Quiet NaN
98            }
99        } else {
100            let unbiased = exp - 127;
101
102            if unbiased < -24 {
103                // Too small, underflow to zero
104                sign << 15
105            } else if unbiased < -14 {
106                // Denormal f16
107                let shift = (-14 - unbiased) as u32;
108                let m = ((mant | 0x0080_0000) >> (shift + 14)) as u16;
109                (sign << 15) | m
110            } else if unbiased > 15 {
111                // Overflow to infinity
112                (sign << 15) | (0x1F << 10)
113            } else {
114                // Normal f16
115                let h_exp = (unbiased + 15) as u16;
116                let h_mant = (mant >> 13) as u16;
117
118                // Round to nearest, ties to even
119                let round_bit = (mant >> 12) & 1;
120                let sticky = mant & 0x0FFF;
121                let h_mant = if round_bit == 1 && (sticky != 0 || (h_mant & 1) == 1) {
122                    h_mant + 1
123                } else {
124                    h_mant
125                };
126
127                // Handle mantissa overflow from rounding
128                if h_mant > 0x3FF {
129                    if h_exp >= 30 {
130                        // Overflow to infinity
131                        (sign << 15) | (0x1F << 10)
132                    } else {
133                        (sign << 15) | ((h_exp + 1) << 10)
134                    }
135                } else {
136                    (sign << 15) | (h_exp << 10) | h_mant
137                }
138            }
139        };
140
141        Self(h_bits)
142    }
143
144    /// Creates an f16 from an f64.
145    #[inline]
146    pub fn from_f64(f: f64) -> Self {
147        // Convert via f32 - sufficient precision for f16
148        Self::from_f32(f as f32)
149    }
150
151    /// Converts to f64.
152    #[inline]
153    pub fn to_f64(self) -> f64 {
154        self.to_f32() as f64
155    }
156
157    /// Returns true if this is neither infinite nor NaN.
158    #[inline]
159    pub fn is_finite(self) -> bool {
160        // Exponent of 31 means infinity or NaN
161        ((self.0 >> 10) & 0x1F) != 31
162    }
163
164    /// Returns the bytes in little-endian order.
165    #[inline]
166    pub const fn to_le_bytes(self) -> [u8; 2] {
167        self.0.to_le_bytes()
168    }
169
170    /// Returns the bytes in big-endian order.
171    #[inline]
172    pub const fn to_be_bytes(self) -> [u8; 2] {
173        self.0.to_be_bytes()
174    }
175}
176
177impl From<f16> for f32 {
178    #[inline]
179    fn from(f: f16) -> f32 {
180        f.to_f32()
181    }
182}
183
184impl From<f16> for f64 {
185    #[inline]
186    fn from(f: f16) -> f64 {
187        f.to_f64()
188    }
189}
190
191impl core::fmt::Debug for f16 {
192    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
193        write!(f, "{}", self.to_f32())
194    }
195}
196
197impl core::fmt::Display for f16 {
198    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
199        write!(f, "{}", self.to_f32())
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true when `a` and `b` differ by less than `1e-6`.
    fn nearly_equal(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-6
    }

    #[test]
    fn test_zero() {
        let zero = f16::ZERO;
        assert!(zero.is_finite());
        assert_eq!(zero.to_f32(), 0.0);
        assert_eq!(zero.to_bits(), 0);
    }

    #[test]
    fn test_one() {
        // 1.0: sign 0, biased exponent 15, mantissa 0 -> 0x3C00.
        let value = f16::from_bits(0x3C00);
        assert!(value.is_finite());
        assert!(nearly_equal(value.to_f32(), 1.0));
    }

    #[test]
    fn test_negative_one() {
        // -1.0: sign 1, biased exponent 15, mantissa 0 -> 0xBC00.
        assert!(nearly_equal(f16::from_bits(0xBC00).to_f32(), -1.0));
    }

    #[test]
    fn test_infinity() {
        // Exponent all ones, mantissa zero, for both signs.
        for &bits in &[0x7C00u16, 0xFC00] {
            let value = f16::from_bits(bits);
            assert!(!value.is_finite());
            assert!(value.to_f32().is_infinite());
        }
    }

    #[test]
    fn test_nan() {
        // Exponent all ones with any nonzero mantissa is NaN.
        let value = f16::from_bits(0x7C01);
        assert!(!value.is_finite());
        assert!(value.to_f32().is_nan());
    }

    #[test]
    fn test_denormal() {
        // 0x0001 is the smallest positive subnormal.
        let smallest = f16::from_bits(0x0001);
        assert!(smallest.is_finite());
        let widened = smallest.to_f32();
        assert!(widened > 0.0 && widened < 1e-6);
    }

    #[test]
    fn test_roundtrip_normal() {
        let samples: [f32; 8] = [0.5, 1.0, 2.0, 100.0, 0.001, -0.5, -1.0, -100.0];
        samples.iter().copied().for_each(|v| {
            let back = f16::from_f32(v).to_f32();
            // Half precision keeps ~11 significant bits; 0.2% relative error is ample.
            let rel_err = ((v - back) / v).abs();
            assert!(
                rel_err < 0.002,
                "Roundtrip failed for {}: got {}, rel_err {}",
                v,
                back,
                rel_err
            );
        });
    }

    #[test]
    fn test_roundtrip_special() {
        let through = |v: f32| f16::from_f32(v).to_f32();
        assert_eq!(through(0.0), 0.0);
        assert!(through(f32::INFINITY).is_infinite());
        assert!(through(f32::NEG_INFINITY).is_infinite());
        assert!(through(f32::NAN).is_nan());
    }

    #[test]
    fn test_overflow_to_infinity() {
        // The largest finite f16 is 65504; anything well beyond it saturates to infinity.
        assert!(f16::from_f32(100000.0).to_f32().is_infinite());
    }

    #[test]
    fn test_underflow_to_zero() {
        // Far below the smallest subnormal: flushes to zero.
        assert_eq!(f16::from_f32(1e-10).to_f32(), 0.0);
    }

    #[test]
    fn test_bytes() {
        let value = f16::from_bits(0x1234);
        assert_eq!(value.to_be_bytes(), [0x12, 0x34]);
        assert_eq!(value.to_le_bytes(), [0x34, 0x12]);
    }
}