Skip to main content

ruvector_temporal_tensor/
f16.rs

//! Software IEEE 754 half-precision (f16) conversion.
//!
//! No external crate dependencies. Handles normals, denormals, infinities, and NaN.
//! Round-to-nearest with ties-to-even for normal values.

/// Convert f32 to f16 bit representation.
///
/// Handles all IEEE 754 special cases: infinity, NaN, denormals, and zero (both signs).
///
/// Rounding is round-to-nearest, ties-to-even, for both normal and denormal results.
/// - Finite values whose rounded magnitude exceeds the largest f16 (65504) saturate to
///   infinity.
/// - Values at or below half of the smallest f16 denormal (2^-25) flush to signed zero.
///   An exact 2^-25 is a tie and rounds to the even neighbour, zero.
/// - NaN inputs map to a NaN with the same sign. The top 10 payload bits are kept and the
///   lowest mantissa bit is forced on, so a NaN can never collapse into infinity.
#[inline]
pub fn f32_to_f16_bits(x: f32) -> u16 {
    let b = x.to_bits();
    let sign = ((b >> 16) & 0x8000) as u16;
    let exp = ((b >> 23) & 0xFF) as i32;
    let mant = b & 0x7F_FFFF;

    // Infinity or NaN
    if exp == 255 {
        if mant == 0 {
            return sign | 0x7C00;
        }
        // `| 1` guarantees a non-zero mantissa so the result stays a NaN.
        let nan_m = (mant >> 13) as u16;
        return sign | 0x7C00 | nan_m | 1;
    }

    // Rebias the exponent from f32 (bias 127) to f16 (bias 15).
    let exp16 = exp - 127 + 15;

    // Overflow -> Infinity
    if exp16 >= 31 {
        return sign | 0x7C00;
    }

    // Underflow -> denormal or zero
    if exp16 <= 0 {
        // Below 2^-25 not even rounding can reach the smallest denormal (2^-24).
        if exp16 < -10 {
            return sign;
        }
        // Restore the implicit leading 1, then express the value in units of 2^-24
        // (one f16 denormal ULP). `shift` is in 14..=24 here.
        let full = mant | 0x80_0000;
        let shift = (14 - exp16) as u32;
        let sub = shift_right_round_even(full, shift);
        // A rounding carry to 0x400 is exactly the smallest normal's bit pattern.
        return sign | sub as u16;
    }

    // Normal case: keep the top 10 mantissa bits and round on the 13 discarded bits.
    let rounded = shift_right_round_even(mant, 13);
    // Adding (rather than OR-ing) lets a mantissa carry (rounded == 0x400) bump the
    // exponent. When exp16 == 30, this produces infinity (0x7C00), e.g. for 65520.0.
    sign | (((exp16 as u32) << 10) + rounded) as u16
}

/// Shift `value` right by `shift` bits (`1..=31`), rounding to nearest with ties to even.
#[inline]
fn shift_right_round_even(value: u32, shift: u32) -> u32 {
    let kept = value >> shift;
    let rest = value & ((1u32 << shift) - 1);
    let half = 1u32 << (shift - 1);
    // Round up when strictly above half, or exactly half and the kept part is odd.
    if rest > half || (rest == half && kept & 1 == 1) {
        kept + 1
    } else {
        kept
    }
}

/// Convert f16 bit representation to f32.
///
/// Every f16 value is exactly representable as an f32, so this conversion is lossless.
/// That includes denormals, signed zeros, infinities and NaNs.
///
/// Denormals are normalized: the highest set mantissa bit is shifted into the
/// implicit-one position, then the f32 bits are built.
///
/// NaN payloads are preserved in the top bits of the f32 mantissa.
#[inline]
pub fn f16_bits_to_f32(h: u16) -> f32 {
    let sign = ((h & 0x8000) as u32) << 16;
    let exp = ((h >> 10) & 0x1F) as u32;
    let mant = (h & 0x03FF) as u32;

    match exp {
        // Zero or denormal: value = mant * 2^-24.
        0 => {
            if mant == 0 {
                return f32::from_bits(sign);
            }
            // Left shift that moves the highest set bit of `mant` (< 2^10) up to bit 10.
            let shift = mant.leading_zeros() - 21;
            let mant32 = ((mant << shift) & 0x03FF) << 13;
            // The highest set bit sits at position p = 10 - shift, so the value is
            // 1.f * 2^(p - 24). The biased f32 exponent is 127 + p - 24 = 113 - shift.
            let exp32 = 113 - shift;
            f32::from_bits(sign | (exp32 << 23) | mant32)
        }
        // Infinity (mant == 0) or NaN (payload preserved).
        0x1F => f32::from_bits(sign | 0x7F80_0000 | (mant << 13)),
        // Normal: rebias the exponent from 15 to 127.
        _ => f32::from_bits(sign | ((exp + 127 - 15) << 23) | (mant << 13)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Exact f32 bit patterns for powers of two used below.
    const TWO_POW_M14: u32 = 0x3880_0000;
    const TWO_POW_M15: u32 = 0x3800_0000;
    const TWO_POW_M24: u32 = 0x3380_0000;

    #[test]
    fn test_roundtrip_normal() {
        for &v in &[0.0f32, 1.0, -1.0, 0.5, 65504.0, -65504.0, 0.0001] {
            let h = f32_to_f16_bits(v);
            let back = f16_bits_to_f32(h);
            if v == 0.0 {
                assert_eq!(back, 0.0);
            } else {
                let rel_err = ((back - v) / v).abs();
                assert!(rel_err < 0.01, "v={v}, back={back}, rel_err={rel_err}");
            }
        }
    }

    #[test]
    fn test_infinity() {
        let h = f32_to_f16_bits(f32::INFINITY);
        assert_eq!(h, 0x7C00);
        assert!(f16_bits_to_f32(h).is_infinite());
    }

    #[test]
    fn test_neg_infinity() {
        let h = f32_to_f16_bits(f32::NEG_INFINITY);
        assert_eq!(h, 0xFC00);
        let back = f16_bits_to_f32(h);
        assert!(back.is_infinite() && back < 0.0);
    }

    #[test]
    fn test_nan() {
        let h = f32_to_f16_bits(f32::NAN);
        assert!(f16_bits_to_f32(h).is_nan());
    }

    #[test]
    fn test_nan_keeps_sign() {
        let h = f32_to_f16_bits(-f32::NAN);
        assert_eq!(h & 0x8000, 0x8000);
        assert_eq!(h & 0x7C00, 0x7C00);
        // A zero mantissa would turn the NaN into an infinity.
        assert_ne!(h & 0x03FF, 0);
    }

    #[test]
    fn test_zero_signs() {
        assert_eq!(f32_to_f16_bits(0.0f32), 0x0000);
        assert_eq!(f32_to_f16_bits(-0.0f32), 0x8000);
    }

    #[test]
    fn test_scale_range_accuracy() {
        for exp in -4..=4i32 {
            let v = 10.0f32.powi(exp);
            let h = f32_to_f16_bits(v);
            let back = f16_bits_to_f32(h);
            let rel_err = ((back - v) / v).abs();
            assert!(rel_err < 0.002, "v={v}, back={back}, rel_err={rel_err}");
        }
    }

    #[test]
    fn test_exact_encodings() {
        assert_eq!(f32_to_f16_bits(1.0), 0x3C00);
        assert_eq!(f32_to_f16_bits(-2.0), 0xC000);
        assert_eq!(f32_to_f16_bits(0.5), 0x3800);
        assert_eq!(f32_to_f16_bits(65504.0), 0x7BFF);
        // Rounding past the largest finite f16 carries into the exponent -> infinity.
        assert_eq!(f32_to_f16_bits(65520.0), 0x7C00);
        assert_eq!(f32_to_f16_bits(1.0e6), 0x7C00);
        assert_eq!(f32_to_f16_bits(-1.0e6), 0xFC00);
    }

    #[test]
    fn test_denormal_encoding() {
        assert_eq!(f32_to_f16_bits(f32::from_bits(TWO_POW_M14)), 0x0400); // smallest normal
        assert_eq!(f32_to_f16_bits(f32::from_bits(TWO_POW_M15)), 0x0200);
        assert_eq!(f32_to_f16_bits(f32::from_bits(TWO_POW_M24)), 0x0001); // smallest denormal
        assert_eq!(f32_to_f16_bits(-f32::from_bits(TWO_POW_M24)), 0x8001);
        assert_eq!(f32_to_f16_bits(1.0e-30), 0x0000);
        assert_eq!(f32_to_f16_bits(-1.0e-30), 0x8000);
        // f32 denormal inputs flush to zero.
        assert_eq!(f32_to_f16_bits(f32::from_bits(1)), 0x0000);
    }

    #[test]
    fn test_exact_decodings() {
        assert_eq!(f16_bits_to_f32(0x3C00), 1.0);
        assert_eq!(f16_bits_to_f32(0xC000), -2.0);
        assert_eq!(f16_bits_to_f32(0x7BFF), 65504.0);
        assert_eq!(f16_bits_to_f32(0x0400).to_bits(), TWO_POW_M14);
        assert_eq!(f16_bits_to_f32(0x0000).to_bits(), 0);
        assert_eq!(f16_bits_to_f32(0x8000).to_bits(), 0x8000_0000);
    }
}