Skip to main content

sklears_simd/
half_precision.rs

1//! Half-precision floating point operations (FP16/BF16)
2//!
3//! This module provides SIMD-optimized operations for half-precision floating point
4//! formats, essential for modern AI/ML workloads.
5
6#[cfg(feature = "no-std")]
7use core::fmt;
8#[cfg(not(feature = "no-std"))]
9use std::fmt;
10
/// IEEE 754 binary16 ("half precision", FP16) value, stored as its raw bit pattern.
///
/// Bit layout: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
/// The derived `PartialEq` compares raw bits, so `+0.0` and `-0.0` are unequal and
/// two NaNs compare equal only when their bit patterns match.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct F16(pub u16);

/// Brain floating point (BF16) value, stored as its raw bit pattern.
///
/// Bit layout: 1 sign bit, 8 exponent bits (bias 127), 7 mantissa bits — the upper
/// half of an IEEE 754 `f32`. The derived `PartialEq` compares raw bits.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BF16(pub u16);
18
19impl F16 {
20    /// Create a new F16 from a u16 bit representation
21    pub fn from_bits(bits: u16) -> Self {
22        F16(bits)
23    }
24
25    /// Get the bit representation
26    pub fn to_bits(self) -> u16 {
27        self.0
28    }
29
30    /// Convert from f32 to FP16
31    pub fn from_f32(value: f32) -> Self {
32        // IEEE 754 half-precision conversion
33        let bits = value.to_bits();
34        let sign = (bits >> 31) as u16;
35        let exp = ((bits >> 23) & 0xFF) as i32;
36        let mant = bits & 0x7FFFFF;
37
38        if exp == 0 && mant == 0 {
39            // Zero
40            F16(sign << 15)
41        } else if exp == 0xFF {
42            // Infinity or NaN
43            let new_mant = if mant == 0 { 0 } else { 0x3FF };
44            F16((sign << 15) | 0x7C00 | new_mant)
45        } else {
46            // Normal numbers
47            let new_exp = exp - 127 + 15;
48            if new_exp <= 0 {
49                // Underflow to zero or denormal
50                if new_exp < -10 {
51                    F16(sign << 15)
52                } else {
53                    let new_mant = (mant | 0x800000) >> (14 - new_exp);
54                    F16((sign << 15) | ((new_mant + 0x1000) >> 13) as u16)
55                }
56            } else if new_exp >= 31 {
57                // Overflow to infinity
58                F16((sign << 15) | 0x7C00)
59            } else {
60                // Normal case
61                let new_mant = ((mant + 0x1000) >> 13) as u16;
62                F16((sign << 15) | ((new_exp as u16) << 10) | new_mant)
63            }
64        }
65    }
66
67    /// Convert FP16 to f32
68    pub fn to_f32(self) -> f32 {
69        let bits = self.0;
70        let sign = (bits >> 15) as u32;
71        let exp = ((bits >> 10) & 0x1F) as u32;
72        let mant = (bits & 0x3FF) as u32;
73
74        if exp == 0 && mant == 0 {
75            // Zero
76            f32::from_bits(sign << 31)
77        } else if exp == 0 {
78            // Denormal
79            let mut new_mant = mant;
80            let mut new_exp = 0;
81            while (new_mant & 0x400) == 0 {
82                new_mant <<= 1;
83                new_exp += 1;
84            }
85            new_mant &= 0x3FF;
86            new_exp = 127 - 15 - new_exp;
87            f32::from_bits((sign << 31) | (new_exp << 23) | (new_mant << 13))
88        } else if exp == 31 {
89            // Infinity or NaN
90            let new_mant = if mant == 0 { 0 } else { 0x7FFFFF };
91            f32::from_bits((sign << 31) | 0x7F800000 | new_mant)
92        } else {
93            // Normal
94            let new_exp = exp + 127 - 15;
95            f32::from_bits((sign << 31) | (new_exp << 23) | (mant << 13))
96        }
97    }
98
99    /// Check if the value is finite
100    pub fn is_finite(self) -> bool {
101        (self.0 & 0x7C00) != 0x7C00
102    }
103
104    /// Check if the value is infinite
105    pub fn is_infinite(self) -> bool {
106        (self.0 & 0x7FFF) == 0x7C00
107    }
108
109    /// Check if the value is NaN
110    pub fn is_nan(self) -> bool {
111        (self.0 & 0x7C00) == 0x7C00 && (self.0 & 0x3FF) != 0
112    }
113}
114
115impl BF16 {
116    /// Create a new BF16 from a u16 bit representation
117    pub fn from_bits(bits: u16) -> Self {
118        BF16(bits)
119    }
120
121    /// Get the bit representation
122    pub fn to_bits(self) -> u16 {
123        self.0
124    }
125
126    /// Convert from f32 to BF16
127    pub fn from_f32(value: f32) -> Self {
128        // BFloat16 is simply the top 16 bits of IEEE 754 f32
129        let bits = value.to_bits();
130        let _truncated = (bits >> 16) as u16;
131
132        // Round to nearest even (banker's rounding)
133        let rounding_bias = 0x7FFF + ((bits >> 16) & 1);
134        let rounded = ((bits + rounding_bias) >> 16) as u16;
135
136        BF16(rounded)
137    }
138
139    /// Convert BF16 to f32
140    pub fn to_f32(self) -> f32 {
141        // BFloat16 to f32 is just shifting left by 16 bits
142        f32::from_bits((self.0 as u32) << 16)
143    }
144
145    /// Check if the value is finite
146    pub fn is_finite(self) -> bool {
147        (self.0 & 0x7F80) != 0x7F80
148    }
149
150    /// Check if the value is infinite
151    pub fn is_infinite(self) -> bool {
152        (self.0 & 0x7FFF) == 0x7F80
153    }
154
155    /// Check if the value is NaN
156    pub fn is_nan(self) -> bool {
157        (self.0 & 0x7F80) == 0x7F80 && (self.0 & 0x7F) != 0
158    }
159}
160
161impl fmt::Display for F16 {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        write!(f, "{}", self.to_f32())
164    }
165}
166
167impl fmt::Display for BF16 {
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        write!(f, "{}", self.to_f32())
170    }
171}
172
173/// SIMD operations for half-precision formats
174pub mod simd {
175    use super::*;
176
177    /// Convert slice of f32 to FP16 with SIMD optimization
178    pub fn f32_to_f16_slice(input: &[f32], output: &mut [F16]) {
179        assert_eq!(input.len(), output.len());
180
181        // Process in chunks for better vectorization
182        const CHUNK_SIZE: usize = 8;
183        let chunks = input.len() / CHUNK_SIZE;
184
185        for i in 0..chunks {
186            let start = i * CHUNK_SIZE;
187            let end = start + CHUNK_SIZE;
188
189            // Convert chunk of f32 to FP16
190            for j in start..end {
191                output[j] = F16::from_f32(input[j]);
192            }
193        }
194
195        // Handle remaining elements
196        for i in (chunks * CHUNK_SIZE)..input.len() {
197            output[i] = F16::from_f32(input[i]);
198        }
199    }
200
201    /// Convert slice of FP16 to f32 with SIMD optimization
202    pub fn f16_to_f32_slice(input: &[F16], output: &mut [f32]) {
203        assert_eq!(input.len(), output.len());
204
205        const CHUNK_SIZE: usize = 8;
206        let chunks = input.len() / CHUNK_SIZE;
207
208        for i in 0..chunks {
209            let start = i * CHUNK_SIZE;
210            let end = start + CHUNK_SIZE;
211
212            for j in start..end {
213                output[j] = input[j].to_f32();
214            }
215        }
216
217        for i in (chunks * CHUNK_SIZE)..input.len() {
218            output[i] = input[i].to_f32();
219        }
220    }
221
222    /// Convert slice of f32 to BF16 with SIMD optimization
223    pub fn f32_to_bf16_slice(input: &[f32], output: &mut [BF16]) {
224        assert_eq!(input.len(), output.len());
225
226        const CHUNK_SIZE: usize = 8;
227        let chunks = input.len() / CHUNK_SIZE;
228
229        for i in 0..chunks {
230            let start = i * CHUNK_SIZE;
231            let end = start + CHUNK_SIZE;
232
233            for j in start..end {
234                output[j] = BF16::from_f32(input[j]);
235            }
236        }
237
238        for i in (chunks * CHUNK_SIZE)..input.len() {
239            output[i] = BF16::from_f32(input[i]);
240        }
241    }
242
243    /// Convert slice of BF16 to f32 with SIMD optimization
244    pub fn bf16_to_f32_slice(input: &[BF16], output: &mut [f32]) {
245        assert_eq!(input.len(), output.len());
246
247        const CHUNK_SIZE: usize = 8;
248        let chunks = input.len() / CHUNK_SIZE;
249
250        for i in 0..chunks {
251            let start = i * CHUNK_SIZE;
252            let end = start + CHUNK_SIZE;
253
254            for j in start..end {
255                output[j] = input[j].to_f32();
256            }
257        }
258
259        for i in (chunks * CHUNK_SIZE)..input.len() {
260            output[i] = input[i].to_f32();
261        }
262    }
263
264    /// Element-wise addition for FP16 vectors
265    pub fn add_f16(a: &[F16], b: &[F16], result: &mut [F16]) {
266        assert_eq!(a.len(), b.len());
267        assert_eq!(a.len(), result.len());
268
269        for i in 0..a.len() {
270            let sum = a[i].to_f32() + b[i].to_f32();
271            result[i] = F16::from_f32(sum);
272        }
273    }
274
275    /// Element-wise multiplication for FP16 vectors
276    pub fn mul_f16(a: &[F16], b: &[F16], result: &mut [F16]) {
277        assert_eq!(a.len(), b.len());
278        assert_eq!(a.len(), result.len());
279
280        for i in 0..a.len() {
281            let product = a[i].to_f32() * b[i].to_f32();
282            result[i] = F16::from_f32(product);
283        }
284    }
285
286    /// Element-wise addition for BF16 vectors
287    pub fn add_bf16(a: &[BF16], b: &[BF16], result: &mut [BF16]) {
288        assert_eq!(a.len(), b.len());
289        assert_eq!(a.len(), result.len());
290
291        for i in 0..a.len() {
292            let sum = a[i].to_f32() + b[i].to_f32();
293            result[i] = BF16::from_f32(sum);
294        }
295    }
296
297    /// Element-wise multiplication for BF16 vectors
298    pub fn mul_bf16(a: &[BF16], b: &[BF16], result: &mut [BF16]) {
299        assert_eq!(a.len(), b.len());
300        assert_eq!(a.len(), result.len());
301
302        for i in 0..a.len() {
303            let product = a[i].to_f32() * b[i].to_f32();
304            result[i] = BF16::from_f32(product);
305        }
306    }
307
308    /// Dot product for FP16 vectors
309    pub fn dot_f16(a: &[F16], b: &[F16]) -> f32 {
310        assert_eq!(a.len(), b.len());
311
312        let mut sum = 0.0f32;
313        for i in 0..a.len() {
314            sum += a[i].to_f32() * b[i].to_f32();
315        }
316        sum
317    }
318
319    /// Dot product for BF16 vectors
320    pub fn dot_bf16(a: &[BF16], b: &[BF16]) -> f32 {
321        assert_eq!(a.len(), b.len());
322
323        let mut sum = 0.0f32;
324        for i in 0..a.len() {
325            sum += a[i].to_f32() * b[i].to_f32();
326        }
327        sum
328    }
329
330    /// Matrix multiplication for FP16 matrices (A * B = C)
331    pub fn matmul_f16(a: &[F16], b: &[F16], c: &mut [F16], m: usize, n: usize, k: usize) {
332        assert_eq!(a.len(), m * k);
333        assert_eq!(b.len(), k * n);
334        assert_eq!(c.len(), m * n);
335
336        for i in 0..m {
337            for j in 0..n {
338                let mut sum = 0.0f32;
339                for l in 0..k {
340                    sum += a[i * k + l].to_f32() * b[l * n + j].to_f32();
341                }
342                c[i * n + j] = F16::from_f32(sum);
343            }
344        }
345    }
346
347    /// Matrix multiplication for BF16 matrices (A * B = C)
348    pub fn matmul_bf16(a: &[BF16], b: &[BF16], c: &mut [BF16], m: usize, n: usize, k: usize) {
349        assert_eq!(a.len(), m * k);
350        assert_eq!(b.len(), k * n);
351        assert_eq!(c.len(), m * n);
352
353        for i in 0..m {
354            for j in 0..n {
355                let mut sum = 0.0f32;
356                for l in 0..k {
357                    sum += a[i * k + l].to_f32() * b[l * n + j].to_f32();
358                }
359                c[i * n + j] = BF16::from_f32(sum);
360            }
361        }
362    }
363}
364
365/// Constants for half-precision formats
366pub mod constants {
367    use super::*;
368
369    pub const F16_ZERO: F16 = F16(0);
370    pub const F16_ONE: F16 = F16(0x3C00);
371    pub const F16_NEG_ONE: F16 = F16(0xBC00);
372    pub const F16_INFINITY: F16 = F16(0x7C00);
373    pub const F16_NEG_INFINITY: F16 = F16(0xFC00);
374    pub const F16_NAN: F16 = F16(0x7E00);
375    pub const F16_MAX: F16 = F16(0x7BFF);
376    pub const F16_MIN: F16 = F16(0x0400);
377    pub const F16_EPSILON: F16 = F16(0x1400);
378
379    pub const BF16_ZERO: BF16 = BF16(0);
380    pub const BF16_ONE: BF16 = BF16(0x3F80);
381    pub const BF16_NEG_ONE: BF16 = BF16(0xBF80);
382    pub const BF16_INFINITY: BF16 = BF16(0x7F80);
383    pub const BF16_NEG_INFINITY: BF16 = BF16(0xFF80);
384    pub const BF16_NAN: BF16 = BF16(0x7FC0);
385    pub const BF16_MAX: BF16 = BF16(0x7F7F);
386    pub const BF16_MIN: BF16 = BF16(0x0080);
387    pub const BF16_EPSILON: BF16 = BF16(0x3C00);
388}
389
// The `no-std` build has no `std` test harness support here, so these tests only
// compile in std builds.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::constants::*;
    use super::*;

    #[test]
    fn test_f16_conversion() {
        let val = std::f32::consts::PI;
        let back_to_f32 = F16::from_f32(val).to_f32();

        // FP16 keeps only 10 mantissa bits, so some loss is expected.
        assert!((val - back_to_f32).abs() < 0.01);
    }

    #[test]
    fn test_bf16_conversion() {
        let val = std::f32::consts::PI;
        let back_to_f32 = BF16::from_f32(val).to_f32();

        // BF16 keeps only 7 mantissa bits (less precision than FP16, but f32's range).
        assert!((val - back_to_f32).abs() < 0.01);
    }

    #[test]
    fn test_f16_constants() {
        assert_eq!(F16_ZERO.to_f32(), 0.0);
        assert_eq!(F16_ONE.to_f32(), 1.0);
        assert_eq!(F16_NEG_ONE.to_f32(), -1.0);
        assert!(F16_INFINITY.is_infinite());
        assert!(F16_NAN.is_nan());
    }

    #[test]
    fn test_bf16_constants() {
        assert_eq!(BF16_ZERO.to_f32(), 0.0);
        assert_eq!(BF16_ONE.to_f32(), 1.0);
        assert_eq!(BF16_NEG_ONE.to_f32(), -1.0);
        assert!(BF16_INFINITY.is_infinite());
        assert!(BF16_NAN.is_nan());
    }

    #[test]
    fn test_f16_special_values() {
        assert!(F16::from_f32(f32::INFINITY).is_infinite());
        assert!(F16::from_f32(f32::NEG_INFINITY).is_infinite());
        assert!(F16::from_f32(f32::NAN).is_nan());
    }

    #[test]
    fn test_bf16_special_values() {
        assert!(BF16::from_f32(f32::INFINITY).is_infinite());
        assert!(BF16::from_f32(f32::NEG_INFINITY).is_infinite());
        assert!(BF16::from_f32(f32::NAN).is_nan());
    }

    #[test]
    fn test_simd_f32_to_f16_conversion() {
        let input: [f32; 8] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = [F16::from_bits(0); 8];

        simd::f32_to_f16_slice(&input, &mut output);

        for (&expected, converted) in input.iter().zip(&output) {
            assert!((expected - converted.to_f32()).abs() < 0.01);
        }
    }

    #[test]
    fn test_simd_f32_to_bf16_conversion() {
        let input: [f32; 8] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = [BF16::from_bits(0); 8];

        simd::f32_to_bf16_slice(&input, &mut output);

        for (&expected, converted) in input.iter().zip(&output) {
            assert!((expected - converted.to_f32()).abs() < 0.01);
        }
    }

    #[test]
    fn test_f16_arithmetic() {
        let a = [F16::from_f32(1.0), F16::from_f32(2.0), F16::from_f32(3.0)];
        let b = [F16::from_f32(4.0), F16::from_f32(5.0), F16::from_f32(6.0)];
        let mut result = [F16::from_bits(0); 3];

        simd::add_f16(&a, &b, &mut result);

        let expected: [f32; 3] = [5.0, 7.0, 9.0];
        for (r, &e) in result.iter().zip(&expected) {
            assert!((r.to_f32() - e).abs() < 0.01);
        }
    }

    #[test]
    fn test_bf16_arithmetic() {
        let a = [BF16::from_f32(1.0), BF16::from_f32(2.0), BF16::from_f32(3.0)];
        let b = [BF16::from_f32(4.0), BF16::from_f32(5.0), BF16::from_f32(6.0)];
        let mut result = [BF16::from_bits(0); 3];

        simd::add_bf16(&a, &b, &mut result);

        let expected: [f32; 3] = [5.0, 7.0, 9.0];
        for (r, &e) in result.iter().zip(&expected) {
            assert!((r.to_f32() - e).abs() < 0.01);
        }
    }

    #[test]
    fn test_f16_dot_product() {
        let a = [F16::from_f32(1.0), F16::from_f32(2.0), F16::from_f32(3.0)];
        let b = [F16::from_f32(4.0), F16::from_f32(5.0), F16::from_f32(6.0)];

        let result = simd::dot_f16(&a, &b);
        let expected = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0; // 32.0

        assert!((result - expected).abs() < 0.1);
    }

    #[test]
    fn test_bf16_dot_product() {
        let a = [BF16::from_f32(1.0), BF16::from_f32(2.0), BF16::from_f32(3.0)];
        let b = [BF16::from_f32(4.0), BF16::from_f32(5.0), BF16::from_f32(6.0)];

        let result = simd::dot_bf16(&a, &b);
        let expected = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0; // 32.0

        assert!((result - expected).abs() < 0.1);
    }

    #[test]
    fn test_f16_matrix_multiplication() {
        // 2x2 row-major: [[1, 2], [3, 4]] * [[5, 6], [7, 8]]
        let a: Vec<F16> = [1.0, 2.0, 3.0, 4.0].iter().map(|&v| F16::from_f32(v)).collect();
        let b: Vec<F16> = [5.0, 6.0, 7.0, 8.0].iter().map(|&v| F16::from_f32(v)).collect();
        let mut c = vec![F16::from_bits(0); 4];

        simd::matmul_f16(&a, &b, &mut c, 2, 2, 2);

        // Expected result: [[19, 22], [43, 50]]
        let expected: [f32; 4] = [19.0, 22.0, 43.0, 50.0];
        for (got, &want) in c.iter().zip(&expected) {
            assert!((got.to_f32() - want).abs() < 0.1);
        }
    }

    #[test]
    fn test_bf16_matrix_multiplication() {
        // 2x2 row-major: [[1, 2], [3, 4]] * [[5, 6], [7, 8]]
        let a: Vec<BF16> = [1.0, 2.0, 3.0, 4.0].iter().map(|&v| BF16::from_f32(v)).collect();
        let b: Vec<BF16> = [5.0, 6.0, 7.0, 8.0].iter().map(|&v| BF16::from_f32(v)).collect();
        let mut c = vec![BF16::from_bits(0); 4];

        simd::matmul_bf16(&a, &b, &mut c, 2, 2, 2);

        // Expected result: [[19, 22], [43, 50]]
        let expected: [f32; 4] = [19.0, 22.0, 43.0, 50.0];
        for (got, &want) in c.iter().zip(&expected) {
            assert!((got.to_f32() - want).abs() < 0.1);
        }
    }

    #[test]
    fn test_large_vector_conversion() {
        let size = 1024;
        let input: Vec<f32> = (0..size).map(|i| i as f32 * 0.1).collect();
        let mut f16_output = vec![F16::from_bits(0); size];
        let mut bf16_output = vec![BF16::from_bits(0); size];

        simd::f32_to_f16_slice(&input, &mut f16_output);
        simd::f32_to_bf16_slice(&input, &mut bf16_output);

        for ((&x, h), b) in input.iter().zip(&f16_output).zip(&bf16_output) {
            let f16_error = (x - h.to_f32()).abs();
            let bf16_error = (x - b.to_f32()).abs();

            // Relative tolerance for larger values, absolute tolerance for small values.
            let tolerance = if x.abs() > 1.0 {
                x.abs() * 0.01 // 1% relative error for larger values
            } else {
                0.01 // Absolute tolerance for small values
            };

            assert!(
                f16_error < tolerance,
                "F16 error {:.6} > tolerance {:.6} for input {:.6}",
                f16_error,
                tolerance,
                x
            );
            assert!(
                bf16_error < tolerance,
                "BF16 error {:.6} > tolerance {:.6} for input {:.6}",
                bf16_error,
                tolerance,
                x
            );
        }
    }

    #[test]
    fn test_precision_comparison() {
        let test_values: [f32; 13] = [
            0.0,
            1.0,
            -1.0,
            0.5,
            -0.5,
            std::f32::consts::PI,
            std::f32::consts::E,
            std::f32::consts::SQRT_2,
            1.73205,
            0.1,
            0.01,
            0.001,
            0.0001,
        ];

        for &val in &test_values {
            let f16_error = (val - F16::from_f32(val).to_f32()).abs();
            let bf16_error = (val - BF16::from_f32(val).to_f32()).abs();

            // Both should be reasonably close (tiny values are exempt).
            assert!(f16_error < 0.01 || val.abs() < 0.01);
            assert!(bf16_error < 0.01 || val.abs() < 0.01);
        }
    }
}