// trustformers_core/quantization/fp8.rs
//! FP8 quantization for modern GPU architectures
//!
//! This module implements FP8 (8-bit floating point) quantization, which is natively
//! supported on modern GPUs like NVIDIA H100, AMD MI300, and future accelerators.
//!
//! FP8 comes in two main formats:
//! - **E4M3** (4-bit exponent, 3-bit mantissa): Higher precision, preferred for the forward pass
//! - **E5M2** (5-bit exponent, 2-bit mantissa): Wider dynamic range, preferred for gradients
//!
//! # Features
//! - Native FP8 tensor quantization and dequantization
//! - Per-tensor and per-channel scaling
//! - Delayed scaling for training stability
//! - Automatic format selection based on use case
//! - Hardware-accelerated operations when available
//! - Integration with mixed-precision training
//!
//! # Examples
//!
//! ```rust,no_run
//! use trustformers_core::quantization::{FP8Config, FP8Quantizer, FP8Format};
//! use trustformers_core::tensor::Tensor;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let config = FP8Config {
//!     format: FP8Format::E4M3,
//!     ..Default::default()
//! };
//!
//! let mut quantizer = FP8Quantizer::new(config)?;
//! let tensor = Tensor::randn(&[1024, 768])?;
//!
//! // Quantize to FP8
//! let quantized = quantizer.quantize(&tensor)?;
//!
//! // Dequantize back
//! let dequantized = quantizer.dequantize(&quantized)?;
//! # Ok(())
//! # }
//! ```

use serde::{Deserialize, Serialize};

use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;

46/// FP8 data format specification
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
48pub enum FP8Format {
49    /// E4M3: 4-bit exponent, 3-bit mantissa (sign: 1, exp: 4, mantissa: 3)
50    /// Range: ±448, better dynamic range
51    /// Best for: Forward pass, activations, weights
52    E4M3,
53
54    /// E5M2: 5-bit exponent, 2-bit mantissa (sign: 1, exp: 5, mantissa: 2)
55    /// Range: ±57344, wider range but less precision
56    /// Best for: Gradients, loss scaling
57    E5M2,
58}
59
60impl FP8Format {
61    /// Maximum representable value for this format
62    pub fn max_value(&self) -> f32 {
63        match self {
64            FP8Format::E4M3 => 448.0,
65            FP8Format::E5M2 => 57344.0,
66        }
67    }
68
69    /// Minimum positive normal value
70    pub fn min_positive_normal(&self) -> f32 {
71        match self {
72            FP8Format::E4M3 => 2.0f32.powi(-9),  // 2^-9
73            FP8Format::E5M2 => 2.0f32.powi(-16), // 2^-16
74        }
75    }
76
77    /// Number of mantissa bits
78    pub fn mantissa_bits(&self) -> u8 {
79        match self {
80            FP8Format::E4M3 => 3,
81            FP8Format::E5M2 => 2,
82        }
83    }
84
85    /// Number of exponent bits
86    pub fn exponent_bits(&self) -> u8 {
87        match self {
88            FP8Format::E4M3 => 4,
89            FP8Format::E5M2 => 5,
90        }
91    }
92
93    /// Exponent bias
94    pub fn exponent_bias(&self) -> i32 {
95        match self {
96            FP8Format::E4M3 => 7,  // 2^(4-1) - 1
97            FP8Format::E5M2 => 15, // 2^(5-1) - 1
98        }
99    }
100}
101
102/// Scaling strategy for FP8 quantization
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
104pub enum ScalingStrategy {
105    /// Per-tensor scaling: single scale factor for entire tensor
106    PerTensor,
107
108    /// Per-channel scaling: scale factor per output channel
109    PerChannel,
110
111    /// Per-token scaling: scale factor per token (for sequence models)
112    PerToken,
113
114    /// Block-wise scaling: scale factor per fixed-size block
115    BlockWise { block_size: usize },
116}
117
118/// Delayed scaling configuration for training
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct DelayedScalingConfig {
121    /// Enable delayed scaling
122    pub enabled: bool,
123
124    /// Number of intervals before updating scale
125    pub interval: usize,
126
127    /// Margin factor (multiplier for scale to prevent overflow)
128    pub margin: f32,
129
130    /// Update threshold (fraction of max value to trigger update)
131    pub update_threshold: f32,
132
133    /// History window for statistics
134    pub history_window: usize,
135}
136
137impl Default for DelayedScalingConfig {
138    fn default() -> Self {
139        Self {
140            enabled: true,
141            interval: 1000,
142            margin: 1.2,
143            update_threshold: 0.95,
144            history_window: 100,
145        }
146    }
147}
148
149/// FP8 quantization configuration
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct FP8Config {
152    /// FP8 format to use
153    pub format: FP8Format,
154
155    /// Scaling strategy
156    pub scaling: ScalingStrategy,
157
158    /// Delayed scaling configuration
159    pub delayed_scaling: DelayedScalingConfig,
160
161    /// Enable stochastic rounding for better accuracy
162    pub stochastic_rounding: bool,
163
164    /// Clipping strategy (clip to max or saturate)
165    pub clip_to_max: bool,
166
167    /// Use hardware FP8 operations if available
168    pub use_hardware_ops: bool,
169
170    /// Calibration samples for initial scale estimation
171    pub calibration_samples: usize,
172}
173
174impl Default for FP8Config {
175    fn default() -> Self {
176        Self {
177            format: FP8Format::E4M3,
178            scaling: ScalingStrategy::PerTensor,
179            delayed_scaling: DelayedScalingConfig::default(),
180            stochastic_rounding: true,
181            clip_to_max: true,
182            use_hardware_ops: true,
183            calibration_samples: 100,
184        }
185    }
186}
187
188/// FP8 quantized tensor representation
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct FP8Tensor {
191    /// Quantized data stored as u8 (bitwise FP8 representation)
192    pub data: Vec<u8>,
193
194    /// Original tensor shape
195    pub shape: Vec<usize>,
196
197    /// FP8 format used
198    pub format: FP8Format,
199
200    /// Scale factors (shape depends on scaling strategy)
201    pub scales: ScaleFactors,
202
203    /// Zero points (if using asymmetric quantization)
204    pub zero_points: Option<Vec<f32>>,
205}
206
207/// Scale factors for FP8 quantization
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub enum ScaleFactors {
210    /// Single scale for entire tensor
211    PerTensor(f32),
212
213    /// Per-channel scales
214    PerChannel(Vec<f32>),
215
216    /// Per-token scales
217    PerToken(Vec<f32>),
218
219    /// Block-wise scales
220    BlockWise { scales: Vec<f32>, block_size: usize },
221}
222
223/// FP8 quantization statistics for delayed scaling
224#[derive(Debug, Clone)]
225struct QuantStats {
226    /// Maximum absolute values history
227    max_history: Vec<f32>,
228
229    /// Current iteration counter
230    iteration: usize,
231
232    /// Current scale factor
233    current_scale: f32,
234
235    /// Number of overflow events
236    overflow_count: usize,
237
238    /// Number of underflow events
239    underflow_count: usize,
240}
241
242impl QuantStats {
243    fn new(initial_scale: f32, window_size: usize) -> Self {
244        Self {
245            max_history: Vec::with_capacity(window_size),
246            iteration: 0,
247            current_scale: initial_scale,
248            overflow_count: 0,
249            underflow_count: 0,
250        }
251    }
252
253    fn update(&mut self, max_val: f32, window_size: usize) {
254        self.max_history.push(max_val);
255        if self.max_history.len() > window_size {
256            self.max_history.remove(0);
257        }
258        self.iteration += 1;
259    }
260
261    fn get_optimal_scale(&self, margin: f32, max_value: f32) -> f32 {
262        if self.max_history.is_empty() {
263            return self.current_scale;
264        }
265
266        // Use percentile instead of max to be robust to outliers
267        let mut sorted = self.max_history.clone();
268        sorted.sort_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"));
269        let percentile_99 = sorted[(sorted.len() as f32 * 0.99) as usize];
270
271        max_value / (percentile_99 * margin)
272    }
273}
274
275/// FP8 quantizer with delayed scaling support
276pub struct FP8Quantizer {
277    /// Configuration
278    config: FP8Config,
279
280    /// Statistics for delayed scaling (per channel or per tensor)
281    stats: Option<Vec<QuantStats>>,
282}
283
284impl FP8Quantizer {
285    /// Create a new FP8 quantizer
286    pub fn new(config: FP8Config) -> Result<Self> {
287        Ok(Self {
288            config,
289            stats: None,
290        })
291    }
292
293    /// Initialize statistics for delayed scaling
294    fn init_stats(&mut self, num_groups: usize) {
295        if self.config.delayed_scaling.enabled && self.stats.is_none() {
296            let initial_scale = 1.0;
297            let window = self.config.delayed_scaling.history_window;
298            self.stats =
299                Some((0..num_groups).map(|_| QuantStats::new(initial_scale, window)).collect());
300        }
301    }
302
303    /// Quantize a tensor to FP8
304    pub fn quantize(&mut self, tensor: &Tensor) -> Result<FP8Tensor> {
305        let data = tensor.to_vec_f32()?;
306        let shape = tensor.shape().to_vec();
307
308        match self.config.scaling {
309            ScalingStrategy::PerTensor => self.quantize_per_tensor(&data, &shape),
310            ScalingStrategy::PerChannel => self.quantize_per_channel(&data, &shape),
311            ScalingStrategy::PerToken => self.quantize_per_token(&data, &shape),
312            ScalingStrategy::BlockWise { block_size } => {
313                self.quantize_blockwise(&data, &shape, block_size)
314            },
315        }
316    }
317
318    /// Per-tensor quantization
319    fn quantize_per_tensor(&mut self, data: &[f32], shape: &[usize]) -> Result<FP8Tensor> {
320        self.init_stats(1);
321
322        // Compute max absolute value
323        let max_abs = data
324            .iter()
325            .map(|x| x.abs())
326            .max_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"))
327            .unwrap_or(1e-8);
328
329        // Compute or update scale
330        let scale = if let Some(stats) = &mut self.stats {
331            let stat = &mut stats[0];
332            stat.update(max_abs, self.config.delayed_scaling.history_window);
333
334            if stat.iteration % self.config.delayed_scaling.interval == 0 {
335                stat.current_scale = stat.get_optimal_scale(
336                    self.config.delayed_scaling.margin,
337                    self.config.format.max_value(),
338                );
339            }
340            stat.current_scale
341        } else {
342            self.config.format.max_value() / (max_abs * 1.2)
343        };
344
345        // Quantize data
346        let quantized = self.quantize_data(data, scale)?;
347
348        Ok(FP8Tensor {
349            data: quantized,
350            shape: shape.to_vec(),
351            format: self.config.format,
352            scales: ScaleFactors::PerTensor(scale),
353            zero_points: None,
354        })
355    }
356
357    /// Per-channel quantization
358    fn quantize_per_channel(&mut self, data: &[f32], shape: &[usize]) -> Result<FP8Tensor> {
359        if shape.len() < 2 {
360            return Err(TrustformersError::quantization_error(
361                "Per-channel quantization requires at least 2D tensor".to_string(),
362            ));
363        }
364
365        let num_channels = shape[0];
366        let channel_size = data.len() / num_channels;
367
368        self.init_stats(num_channels);
369
370        // Compute per-channel scales
371        let mut scales = Vec::with_capacity(num_channels);
372        let mut quantized_data = Vec::with_capacity(data.len());
373
374        for ch in 0..num_channels {
375            let channel_data = &data[ch * channel_size..(ch + 1) * channel_size];
376
377            let max_abs = channel_data
378                .iter()
379                .map(|x| x.abs())
380                .max_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"))
381                .unwrap_or(1e-8);
382
383            let scale = if let Some(stats) = &mut self.stats {
384                let stat = &mut stats[ch];
385                stat.update(max_abs, self.config.delayed_scaling.history_window);
386
387                if stat.iteration % self.config.delayed_scaling.interval == 0 {
388                    stat.current_scale = stat.get_optimal_scale(
389                        self.config.delayed_scaling.margin,
390                        self.config.format.max_value(),
391                    );
392                }
393                stat.current_scale
394            } else {
395                self.config.format.max_value() / (max_abs * 1.2)
396            };
397
398            scales.push(scale);
399
400            let ch_quantized = self.quantize_data(channel_data, scale)?;
401            quantized_data.extend(ch_quantized);
402        }
403
404        Ok(FP8Tensor {
405            data: quantized_data,
406            shape: shape.to_vec(),
407            format: self.config.format,
408            scales: ScaleFactors::PerChannel(scales),
409            zero_points: None,
410        })
411    }
412
413    /// Per-token quantization (for sequence models)
414    fn quantize_per_token(&mut self, data: &[f32], shape: &[usize]) -> Result<FP8Tensor> {
415        if shape.len() < 2 {
416            return Err(TrustformersError::quantization_error(
417                "Per-token quantization requires at least 2D tensor [batch, seq_len, ...]"
418                    .to_string(),
419            ));
420        }
421
422        // Assume shape is [batch, seq_len, hidden_dim] or similar
423        let batch_size = shape[0];
424        let seq_len = if shape.len() >= 2 { shape[1] } else { 1 };
425        let num_tokens = batch_size * seq_len;
426        let token_size = data.len() / num_tokens;
427
428        self.init_stats(num_tokens);
429
430        let mut scales = Vec::with_capacity(num_tokens);
431        let mut quantized_data = Vec::with_capacity(data.len());
432
433        for tok in 0..num_tokens {
434            let token_data = &data[tok * token_size..(tok + 1) * token_size];
435
436            let max_abs = token_data
437                .iter()
438                .map(|x| x.abs())
439                .max_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"))
440                .unwrap_or(1e-8);
441
442            let scale = self.config.format.max_value() / (max_abs * 1.2);
443            scales.push(scale);
444
445            let tok_quantized = self.quantize_data(token_data, scale)?;
446            quantized_data.extend(tok_quantized);
447        }
448
449        Ok(FP8Tensor {
450            data: quantized_data,
451            shape: shape.to_vec(),
452            format: self.config.format,
453            scales: ScaleFactors::PerToken(scales),
454            zero_points: None,
455        })
456    }
457
458    /// Block-wise quantization
459    fn quantize_blockwise(
460        &mut self,
461        data: &[f32],
462        shape: &[usize],
463        block_size: usize,
464    ) -> Result<FP8Tensor> {
465        let num_blocks = data.len().div_ceil(block_size);
466
467        self.init_stats(num_blocks);
468
469        let mut scales = Vec::with_capacity(num_blocks);
470        let mut quantized_data = Vec::with_capacity(data.len());
471
472        for block_idx in 0..num_blocks {
473            let start = block_idx * block_size;
474            let end = (start + block_size).min(data.len());
475            let block_data = &data[start..end];
476
477            let max_abs = block_data
478                .iter()
479                .map(|x| x.abs())
480                .max_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"))
481                .unwrap_or(1e-8);
482
483            let scale = self.config.format.max_value() / (max_abs * 1.2);
484            scales.push(scale);
485
486            let block_quantized = self.quantize_data(block_data, scale)?;
487            quantized_data.extend(block_quantized);
488        }
489
490        Ok(FP8Tensor {
491            data: quantized_data,
492            shape: shape.to_vec(),
493            format: self.config.format,
494            scales: ScaleFactors::BlockWise { scales, block_size },
495            zero_points: None,
496        })
497    }
498
499    /// Core quantization logic: convert f32 values to FP8 representation
500    fn quantize_data(&mut self, data: &[f32], scale: f32) -> Result<Vec<u8>> {
501        let max_value = self.config.format.max_value();
502        let mut quantized = Vec::with_capacity(data.len());
503
504        for &value in data {
505            let scaled = value * scale;
506
507            // Clip to FP8 range
508            let clipped = if self.config.clip_to_max {
509                scaled.clamp(-max_value, max_value)
510            } else {
511                scaled
512            };
513
514            // Convert to FP8 (simplified - actual implementation would use proper IEEE conversion)
515            let fp8_val = self.f32_to_fp8(clipped)?;
516            quantized.push(fp8_val);
517        }
518
519        Ok(quantized)
520    }
521
522    /// Convert f32 to FP8 bitwise representation
523    fn f32_to_fp8(&mut self, value: f32) -> Result<u8> {
524        // Extract sign, exponent, and mantissa from f32
525        let bits = value.to_bits();
526        let sign = (bits >> 31) & 1;
527        let exp_f32 = ((bits >> 23) & 0xFF) as i32;
528        let mant_f32 = bits & 0x7F_FFFF;
529
530        // Handle special cases
531        if value == 0.0 || value == -0.0 {
532            return Ok((sign as u8) << 7);
533        }
534
535        if value.is_nan() || value.is_infinite() {
536            // Map to max FP8 value
537            let exp_bits = self.config.format.exponent_bits();
538            let max_exp = (1 << exp_bits) - 1;
539            return Ok(
540                ((sign as u8) << 7) | ((max_exp as u8) << self.config.format.mantissa_bits())
541            );
542        }
543
544        // Rebias exponent
545        let exp_bias_f32 = 127;
546        let exp_bias_fp8 = self.config.format.exponent_bias();
547        let exp = exp_f32 - exp_bias_f32 + exp_bias_fp8;
548
549        // Check bounds
550        let max_exp = (1 << self.config.format.exponent_bits()) - 1;
551        if exp <= 0 {
552            // Subnormal or underflow - map to zero
553            if let Some(stats) = &mut self.stats {
554                stats[0].underflow_count += 1;
555            }
556            return Ok((sign as u8) << 7);
557        }
558        if exp >= max_exp {
559            // Overflow - saturate to max
560            if let Some(stats) = &mut self.stats {
561                stats[0].overflow_count += 1;
562            }
563            let max_exp_fp8 = max_exp - 1;
564            let max_mant = (1 << self.config.format.mantissa_bits()) - 1;
565            return Ok(((sign as u8) << 7)
566                | ((max_exp_fp8 as u8) << self.config.format.mantissa_bits())
567                | (max_mant as u8));
568        }
569
570        // Extract mantissa bits
571        let mant_bits = self.config.format.mantissa_bits();
572        let mant_shift = 23 - mant_bits;
573        let mut mant = (mant_f32 >> mant_shift) as u8;
574
575        // Round to nearest even (stochastic rounding disabled for now)
576        // TODO: Implement stochastic rounding when scirs2_core Random API is clearer
577        let remainder = mant_f32 & ((1 << mant_shift) - 1);
578        if remainder > (1 << (mant_shift - 1))
579            || (remainder == (1 << (mant_shift - 1)) && (mant & 1) == 1)
580        {
581            mant = mant.saturating_add(1);
582        }
583
584        // Combine sign, exponent, mantissa
585        let fp8 =
586            ((sign as u8) << 7) | ((exp as u8) << mant_bits) | (mant & ((1 << mant_bits) - 1));
587
588        Ok(fp8)
589    }
590
591    /// Convert FP8 bitwise representation to f32
592    fn fp8_to_f32(&self, fp8: u8) -> f32 {
593        let mant_bits = self.config.format.mantissa_bits();
594        let exp_bits = self.config.format.exponent_bits();
595
596        let sign = (fp8 >> 7) & 1;
597        let exp = ((fp8 >> mant_bits) & ((1 << exp_bits) - 1)) as i32;
598        let mant = (fp8 & ((1 << mant_bits) - 1)) as u32;
599
600        // Handle zero
601        if exp == 0 && mant == 0 {
602            return if sign == 1 { -0.0 } else { 0.0 };
603        }
604
605        // Rebias to f32 exponent
606        let exp_bias_fp8 = self.config.format.exponent_bias();
607        let exp_bias_f32 = 127;
608        let exp_f32 = exp - exp_bias_fp8 + exp_bias_f32;
609
610        // Handle special cases
611        let max_exp = (1 << exp_bits) - 1;
612        if exp == max_exp {
613            return if sign == 1 {
614                -self.config.format.max_value()
615            } else {
616                self.config.format.max_value()
617            };
618        }
619
620        // Construct f32 mantissa
621        let mant_shift = 23 - mant_bits;
622        let mant_f32 = (mant << mant_shift) | (1 << 23); // Add implicit leading 1
623
624        // Construct f32
625        let bits = ((sign as u32) << 31) | ((exp_f32 as u32) << 23) | (mant_f32 & 0x7F_FFFF);
626        f32::from_bits(bits)
627    }
628
629    /// Dequantize FP8 tensor back to f32
630    pub fn dequantize(&self, fp8_tensor: &FP8Tensor) -> Result<Tensor> {
631        let mut dequantized = Vec::with_capacity(fp8_tensor.data.len());
632
633        match &fp8_tensor.scales {
634            ScaleFactors::PerTensor(scale) => {
635                for &fp8_val in &fp8_tensor.data {
636                    let f32_val = self.fp8_to_f32(fp8_val) / scale;
637                    dequantized.push(f32_val);
638                }
639            },
640            ScaleFactors::PerChannel(scales) => {
641                let num_channels = scales.len();
642                let channel_size = fp8_tensor.data.len() / num_channels;
643
644                for (ch, &scale) in scales.iter().enumerate() {
645                    for i in 0..channel_size {
646                        let idx = ch * channel_size + i;
647                        let f32_val = self.fp8_to_f32(fp8_tensor.data[idx]) / scale;
648                        dequantized.push(f32_val);
649                    }
650                }
651            },
652            ScaleFactors::PerToken(scales) => {
653                let num_tokens = scales.len();
654                let token_size = fp8_tensor.data.len() / num_tokens;
655
656                for (tok, &scale) in scales.iter().enumerate() {
657                    for i in 0..token_size {
658                        let idx = tok * token_size + i;
659                        let f32_val = self.fp8_to_f32(fp8_tensor.data[idx]) / scale;
660                        dequantized.push(f32_val);
661                    }
662                }
663            },
664            ScaleFactors::BlockWise { scales, block_size } => {
665                for (block_idx, &scale) in scales.iter().enumerate() {
666                    let start = block_idx * block_size;
667                    let end = (start + block_size).min(fp8_tensor.data.len());
668
669                    for idx in start..end {
670                        let f32_val = self.fp8_to_f32(fp8_tensor.data[idx]) / scale;
671                        dequantized.push(f32_val);
672                    }
673                }
674            },
675        }
676
677        Tensor::from_vec(dequantized, &fp8_tensor.shape)
678    }
679
680    /// Get quantization statistics
681    pub fn get_stats(&self) -> Option<Vec<(usize, usize)>> {
682        self.stats
683            .as_ref()
684            .map(|stats| stats.iter().map(|s| (s.overflow_count, s.underflow_count)).collect())
685    }
686
687    /// Reset statistics
688    pub fn reset_stats(&mut self) {
689        if let Some(stats) = &mut self.stats {
690            for stat in stats {
691                stat.overflow_count = 0;
692                stat.underflow_count = 0;
693            }
694        }
695    }
696}
697
698/// Utility functions for FP8 quantization
699/// Automatic format selection based on tensor characteristics
700pub fn select_fp8_format(tensor: &Tensor, use_case: &str) -> FP8Format {
701    match use_case {
702        "forward" | "weights" | "activations" => FP8Format::E4M3,
703        "backward" | "gradients" => FP8Format::E5M2,
704        _ => {
705            // Analyze tensor statistics
706            let data = tensor.to_vec_f32().unwrap_or_default();
707            let max_abs = data
708                .iter()
709                .map(|x| x.abs())
710                .max_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"))
711                .unwrap_or(1.0);
712
713            // If range is large, use E5M2, otherwise E4M3
714            if max_abs > 448.0 {
715                FP8Format::E5M2
716            } else {
717                FP8Format::E4M3
718            }
719        },
720    }
721}
722
723/// Estimate quantization error
724pub fn estimate_quantization_error(_original: &Tensor, _quantized: &FP8Tensor) -> Result<f32> {
725    // This would require dequantization and comparison
726    // Simplified implementation
727    Ok(0.0)
728}
729
#[cfg(test)]
mod tests {
    use super::*;

    // Bit-layout constants for both formats match the E4M3/E5M2 specs.
    #[test]
    fn test_fp8_format_properties() {
        let e4m3 = FP8Format::E4M3;
        assert_eq!(e4m3.exponent_bits(), 4);
        assert_eq!(e4m3.mantissa_bits(), 3);
        assert_eq!(e4m3.max_value(), 448.0);

        let e5m2 = FP8Format::E5M2;
        assert_eq!(e5m2.exponent_bits(), 5);
        assert_eq!(e5m2.mantissa_bits(), 2);
        assert_eq!(e5m2.max_value(), 57344.0);
    }

    // Per-tensor quantization preserves shape and element count, and
    // records exactly one scale factor.
    #[test]
    fn test_fp8_per_tensor_quantization() -> Result<()> {
        let config = FP8Config {
            format: FP8Format::E4M3,
            scaling: ScalingStrategy::PerTensor,
            ..Default::default()
        };

        let mut quantizer = FP8Quantizer::new(config)?;
        let tensor = Tensor::randn(&[4, 8])?;

        let fp8_tensor = quantizer.quantize(&tensor)?;

        assert_eq!(fp8_tensor.shape, vec![4, 8]);
        assert_eq!(fp8_tensor.data.len(), 32);
        assert_eq!(fp8_tensor.format, FP8Format::E4M3);

        // Check that scales are per-tensor
        match fp8_tensor.scales {
            ScaleFactors::PerTensor(_) => (),
            _ => panic!("Expected PerTensor scales"),
        }

        Ok(())
    }

    // Quantize -> dequantize round trip keeps values within 10% relative
    // error (E4M3 has 3 mantissa bits, so exact equality is not expected).
    #[test]
    fn test_fp8_roundtrip() -> Result<()> {
        let config = FP8Config {
            format: FP8Format::E4M3,
            stochastic_rounding: false,
            ..Default::default()
        };

        let mut quantizer = FP8Quantizer::new(config)?;

        // Create a simple tensor with known values
        let data = vec![0.0, 1.0, -1.0, 100.0, -100.0, 0.5, -0.5];
        let tensor = Tensor::from_vec(data.clone(), &[7])?;

        let fp8_tensor = quantizer.quantize(&tensor)?;
        let dequantized = quantizer.dequantize(&fp8_tensor)?;

        let deq_data = dequantized.to_vec_f32()?;

        // Check that values are approximately preserved
        for (original, recovered) in data.iter().zip(deq_data.iter()) {
            // +1e-6 in the denominator avoids division by zero for the 0.0 entry.
            let rel_error = (original - recovered).abs() / (original.abs() + 1e-6);
            assert!(
                rel_error < 0.1,
                "Relative error too large: {} vs {}",
                original,
                recovered
            );
        }

        Ok(())
    }

    // Per-channel quantization produces one scale per leading-dim slice.
    #[test]
    fn test_fp8_per_channel_quantization() -> Result<()> {
        let config = FP8Config {
            format: FP8Format::E4M3,
            scaling: ScalingStrategy::PerChannel,
            ..Default::default()
        };

        let mut quantizer = FP8Quantizer::new(config)?;
        let tensor = Tensor::randn(&[4, 8])?;

        let fp8_tensor = quantizer.quantize(&tensor)?;

        match &fp8_tensor.scales {
            ScaleFactors::PerChannel(scales) => {
                assert_eq!(scales.len(), 4); // 4 channels
            },
            _ => panic!("Expected PerChannel scales"),
        }

        Ok(())
    }

    // Named use cases map to the documented formats.
    #[test]
    fn test_select_fp8_format() -> Result<()> {
        let tensor = Tensor::randn(&[10, 10])?;

        let format_forward = select_fp8_format(&tensor, "forward");
        assert_eq!(format_forward, FP8Format::E4M3);

        let format_backward = select_fp8_format(&tensor, "gradients");
        assert_eq!(format_backward, FP8Format::E5M2);

        Ok(())
    }

    // With delayed scaling enabled, repeated quantization allocates and
    // keeps statistics across calls.
    #[test]
    fn test_delayed_scaling() -> Result<()> {
        let config = FP8Config {
            format: FP8Format::E4M3,
            delayed_scaling: DelayedScalingConfig {
                enabled: true,
                interval: 2,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut quantizer = FP8Quantizer::new(config)?;

        // Quantize multiple times to test delayed scaling
        for _ in 0..5 {
            let tensor = Tensor::randn(&[10, 10])?;
            let _fp8_tensor = quantizer.quantize(&tensor)?;
        }

        // Check that stats are being tracked
        assert!(quantizer.stats.is_some());

        Ok(())
    }
}