1use crate::errors::{Result, TrustformersError};
43use crate::tensor::Tensor;
44use serde::{Deserialize, Serialize};
45
/// 8-bit floating-point formats used for low-precision training/inference.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FP8Format {
    /// 4 exponent bits, 3 mantissa bits — more precision, smaller range.
    /// Typically used for forward-pass weights and activations.
    E4M3,

    /// 5 exponent bits, 2 mantissa bits — less precision, wider range.
    /// Typically used for gradients, which need the extra dynamic range.
    E5M2,
}
59
60impl FP8Format {
61 pub fn max_value(&self) -> f32 {
63 match self {
64 FP8Format::E4M3 => 448.0,
65 FP8Format::E5M2 => 57344.0,
66 }
67 }
68
69 pub fn min_positive_normal(&self) -> f32 {
71 match self {
72 FP8Format::E4M3 => 2.0f32.powi(-9), FP8Format::E5M2 => 2.0f32.powi(-16), }
75 }
76
77 pub fn mantissa_bits(&self) -> u8 {
79 match self {
80 FP8Format::E4M3 => 3,
81 FP8Format::E5M2 => 2,
82 }
83 }
84
85 pub fn exponent_bits(&self) -> u8 {
87 match self {
88 FP8Format::E4M3 => 4,
89 FP8Format::E5M2 => 5,
90 }
91 }
92
93 pub fn exponent_bias(&self) -> i32 {
95 match self {
96 FP8Format::E4M3 => 7, FP8Format::E5M2 => 15, }
99 }
100}
101
/// Granularity at which quantization scale factors are computed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScalingStrategy {
    /// One scale factor for the entire tensor.
    PerTensor,

    /// One scale factor per output channel (first tensor dimension).
    PerChannel,

    /// One scale factor per token (flattened batch x seq_len positions).
    PerToken,

    /// One scale factor per contiguous block of `block_size` elements.
    BlockWise { block_size: usize },
}
117
/// Configuration for delayed (history-based) scale-factor updates,
/// where scales are re-derived periodically from observed amax values
/// instead of per-call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DelayedScalingConfig {
    /// Whether delayed scaling is active; when false, scales are
    /// recomputed from the current tensor on every quantize call.
    pub enabled: bool,

    /// Recompute the scale every this many quantization steps.
    pub interval: usize,

    /// Safety-margin multiplier applied to the observed amax percentile.
    pub margin: f32,

    /// Threshold parameter for scale updates.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub update_threshold: f32,

    /// Number of recent amax observations kept per scaling group.
    pub history_window: usize,
}
136
137impl Default for DelayedScalingConfig {
138 fn default() -> Self {
139 Self {
140 enabled: true,
141 interval: 1000,
142 margin: 1.2,
143 update_threshold: 0.95,
144 history_window: 100,
145 }
146 }
147}
148
/// Top-level configuration for FP8 quantization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FP8Config {
    /// Which FP8 bit layout to use (E4M3 or E5M2).
    pub format: FP8Format,

    /// Scale-factor granularity.
    pub scaling: ScalingStrategy,

    /// Delayed (history-based) scaling settings.
    pub delayed_scaling: DelayedScalingConfig,

    /// Whether to use stochastic rounding.
    /// NOTE(review): not read by the conversion code in this file —
    /// `f32_to_fp8` always rounds to nearest-even; confirm intended.
    pub stochastic_rounding: bool,

    /// Clamp scaled values to ±max_value before conversion.
    pub clip_to_max: bool,

    /// Prefer hardware FP8 ops when available.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub use_hardware_ops: bool,

    /// Number of samples used for calibration.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub calibration_samples: usize,
}
173
174impl Default for FP8Config {
175 fn default() -> Self {
176 Self {
177 format: FP8Format::E4M3,
178 scaling: ScalingStrategy::PerTensor,
179 delayed_scaling: DelayedScalingConfig::default(),
180 stochastic_rounding: true,
181 clip_to_max: true,
182 use_hardware_ops: true,
183 calibration_samples: 100,
184 }
185 }
186}
187
/// A quantized tensor: raw FP8 bytes plus the metadata needed to decode them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FP8Tensor {
    /// One encoded FP8 byte per element, in row-major order.
    pub data: Vec<u8>,

    /// Logical shape of the tensor.
    pub shape: Vec<usize>,

    /// Bit layout the bytes were encoded with.
    pub format: FP8Format,

    /// Scale factors; dequantization divides decoded values by these.
    pub scales: ScaleFactors,

    /// Optional zero points.
    /// NOTE(review): always `None` in this file — confirm intended use.
    pub zero_points: Option<Vec<f32>>,
}
206
/// Scale factors stored alongside quantized data, mirroring
/// [`ScalingStrategy`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScaleFactors {
    /// Single scale applied to every element.
    PerTensor(f32),

    /// One scale per channel (first dimension).
    PerChannel(Vec<f32>),

    /// One scale per token position.
    PerToken(Vec<f32>),

    /// One scale per `block_size`-element block.
    BlockWise { scales: Vec<f32>, block_size: usize },
}
222
/// Running amax statistics for one scaling group, used by delayed scaling.
#[derive(Debug, Clone)]
struct QuantStats {
    /// Sliding window of observed per-step max |x| values.
    max_history: Vec<f32>,

    /// Number of `update` calls so far.
    iteration: usize,

    /// Scale currently applied when quantizing this group.
    current_scale: f32,

    /// Count of values clamped to the format maximum during conversion.
    overflow_count: usize,

    /// Count of values flushed to zero during conversion.
    underflow_count: usize,
}

impl QuantStats {
    /// Create stats with an initial scale and a pre-sized history window.
    fn new(initial_scale: f32, window_size: usize) -> Self {
        Self {
            max_history: Vec::with_capacity(window_size),
            iteration: 0,
            current_scale: initial_scale,
            overflow_count: 0,
            underflow_count: 0,
        }
    }

    /// Record one step's amax, keeping at most `window_size` entries.
    fn update(&mut self, max_val: f32, window_size: usize) {
        self.max_history.push(max_val);
        if self.max_history.len() > window_size {
            // Drop the oldest observation; windows are small, so the O(n)
            // shift is acceptable.
            self.max_history.remove(0);
        }
        self.iteration += 1;
    }

    /// Scale that maps the 99th-percentile amax (times `margin`) onto
    /// `max_value`.
    ///
    /// Falls back to the current scale when the history is empty or the
    /// percentile is zero/non-finite — previously an all-zero history
    /// produced an infinite scale, and NaN entries panicked in the sort.
    fn get_optimal_scale(&self, margin: f32, max_value: f32) -> f32 {
        if self.max_history.is_empty() {
            return self.current_scale;
        }

        let mut sorted = self.max_history.clone();
        // total_cmp gives a total order over f32, so NaN can no longer panic.
        sorted.sort_by(|a, b| a.total_cmp(b));
        // Index is always in bounds: 0.99 * len < len.
        let percentile_99 = sorted[(sorted.len() as f32 * 0.99) as usize];

        if !(percentile_99.is_finite() && percentile_99 > 0.0) {
            // Degenerate history (zeros, NaN, inf): keep the existing scale
            // rather than producing inf/NaN.
            return self.current_scale;
        }

        max_value / (percentile_99 * margin)
    }
}
274
/// Converts f32 tensors to FP8 and back, optionally tracking delayed-scaling
/// statistics per scaling group.
pub struct FP8Quantizer {
    // Format / scaling-strategy configuration, fixed at construction.
    config: FP8Config,

    // One `QuantStats` per scaling group (tensor/channel/token/block).
    // `None` until delayed scaling is enabled and `init_stats` has run.
    stats: Option<Vec<QuantStats>>,
}
283
impl FP8Quantizer {
    /// Create a quantizer; delayed-scaling stats are allocated lazily on
    /// the first quantize call.
    pub fn new(config: FP8Config) -> Result<Self> {
        Ok(Self {
            config,
            stats: None,
        })
    }

    /// Allocate one stats slot per scaling group the first time delayed
    /// scaling is used. Later calls are no-ops, even if `num_groups`
    /// differs — NOTE(review): switching scaling strategies on a live
    /// quantizer would reuse a stats vector of the wrong length; confirm
    /// callers never do this.
    fn init_stats(&mut self, num_groups: usize) {
        if self.config.delayed_scaling.enabled && self.stats.is_none() {
            let initial_scale = 1.0;
            let window = self.config.delayed_scaling.history_window;
            self.stats =
                Some((0..num_groups).map(|_| QuantStats::new(initial_scale, window)).collect());
        }
    }

    /// Quantize `tensor` to FP8 bytes, dispatching on the configured
    /// scaling strategy.
    pub fn quantize(&mut self, tensor: &Tensor) -> Result<FP8Tensor> {
        let data = tensor.to_vec_f32()?;
        let shape = tensor.shape().to_vec();

        match self.config.scaling {
            ScalingStrategy::PerTensor => self.quantize_per_tensor(&data, &shape),
            ScalingStrategy::PerChannel => self.quantize_per_channel(&data, &shape),
            ScalingStrategy::PerToken => self.quantize_per_token(&data, &shape),
            ScalingStrategy::BlockWise { block_size } => {
                self.quantize_blockwise(&data, &shape, block_size)
            },
        }
    }

    /// Quantize with a single scale for the whole tensor.
    ///
    /// With delayed scaling enabled, the scale is only re-derived every
    /// `interval` iterations from the amax history; otherwise it is
    /// computed directly from this tensor's amax with a 1.2x margin.
    fn quantize_per_tensor(&mut self, data: &[f32], shape: &[usize]) -> Result<FP8Tensor> {
        self.init_stats(1);

        // amax of the tensor; 1e-8 floor avoids a division by zero below.
        // NOTE(review): `partial_cmp(..).expect(..)` panics on NaN input.
        let max_abs = data
            .iter()
            .map(|x| x.abs())
            .max_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"))
            .unwrap_or(1e-8);

        let scale = if let Some(stats) = &mut self.stats {
            let stat = &mut stats[0];
            stat.update(max_abs, self.config.delayed_scaling.history_window);

            // Refresh the scale only on interval boundaries (delayed scaling).
            if stat.iteration % self.config.delayed_scaling.interval == 0 {
                stat.current_scale = stat.get_optimal_scale(
                    self.config.delayed_scaling.margin,
                    self.config.format.max_value(),
                );
            }
            stat.current_scale
        } else {
            // Immediate scaling: map amax * 1.2 onto the format maximum.
            self.config.format.max_value() / (max_abs * 1.2)
        };

        let quantized = self.quantize_data(data, scale)?;

        Ok(FP8Tensor {
            data: quantized,
            shape: shape.to_vec(),
            format: self.config.format,
            scales: ScaleFactors::PerTensor(scale),
            zero_points: None,
        })
    }

    /// Quantize with one scale per channel, where "channel" is the first
    /// dimension of `shape` and channels are assumed contiguous in `data`.
    fn quantize_per_channel(&mut self, data: &[f32], shape: &[usize]) -> Result<FP8Tensor> {
        if shape.len() < 2 {
            return Err(TrustformersError::quantization_error(
                "Per-channel quantization requires at least 2D tensor".to_string(),
            ));
        }

        let num_channels = shape[0];
        let channel_size = data.len() / num_channels;

        self.init_stats(num_channels);

        let mut scales = Vec::with_capacity(num_channels);
        let mut quantized_data = Vec::with_capacity(data.len());

        for ch in 0..num_channels {
            let channel_data = &data[ch * channel_size..(ch + 1) * channel_size];

            // Per-channel amax with the same floor/panic caveats as
            // the per-tensor path.
            let max_abs = channel_data
                .iter()
                .map(|x| x.abs())
                .max_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"))
                .unwrap_or(1e-8);

            let scale = if let Some(stats) = &mut self.stats {
                let stat = &mut stats[ch];
                stat.update(max_abs, self.config.delayed_scaling.history_window);

                if stat.iteration % self.config.delayed_scaling.interval == 0 {
                    stat.current_scale = stat.get_optimal_scale(
                        self.config.delayed_scaling.margin,
                        self.config.format.max_value(),
                    );
                }
                stat.current_scale
            } else {
                self.config.format.max_value() / (max_abs * 1.2)
            };

            scales.push(scale);

            let ch_quantized = self.quantize_data(channel_data, scale)?;
            quantized_data.extend(ch_quantized);
        }

        Ok(FP8Tensor {
            data: quantized_data,
            shape: shape.to_vec(),
            format: self.config.format,
            scales: ScaleFactors::PerChannel(scales),
            zero_points: None,
        })
    }

    /// Quantize with one scale per token (batch * seq_len positions).
    /// Unlike the per-tensor/per-channel paths, token scales are always
    /// computed immediately — the delayed-scaling stats initialized here
    /// are not consulted for the scale.
    fn quantize_per_token(&mut self, data: &[f32], shape: &[usize]) -> Result<FP8Tensor> {
        if shape.len() < 2 {
            return Err(TrustformersError::quantization_error(
                "Per-token quantization requires at least 2D tensor [batch, seq_len, ...]"
                    .to_string(),
            ));
        }

        let batch_size = shape[0];
        // The `else` arm is unreachable here given the check above, but kept
        // for safety.
        let seq_len = if shape.len() >= 2 { shape[1] } else { 1 };
        let num_tokens = batch_size * seq_len;
        let token_size = data.len() / num_tokens;

        self.init_stats(num_tokens);

        let mut scales = Vec::with_capacity(num_tokens);
        let mut quantized_data = Vec::with_capacity(data.len());

        for tok in 0..num_tokens {
            let token_data = &data[tok * token_size..(tok + 1) * token_size];

            let max_abs = token_data
                .iter()
                .map(|x| x.abs())
                .max_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"))
                .unwrap_or(1e-8);

            // Immediate scale with 1.2x headroom (no delayed scaling here).
            let scale = self.config.format.max_value() / (max_abs * 1.2);
            scales.push(scale);

            let tok_quantized = self.quantize_data(token_data, scale)?;
            quantized_data.extend(tok_quantized);
        }

        Ok(FP8Tensor {
            data: quantized_data,
            shape: shape.to_vec(),
            format: self.config.format,
            scales: ScaleFactors::PerToken(scales),
            zero_points: None,
        })
    }

    /// Quantize with one scale per `block_size`-element block; the final
    /// block may be shorter. As with per-token, scales are computed
    /// immediately rather than via delayed scaling.
    fn quantize_blockwise(
        &mut self,
        data: &[f32],
        shape: &[usize],
        block_size: usize,
    ) -> Result<FP8Tensor> {
        let num_blocks = data.len().div_ceil(block_size);

        self.init_stats(num_blocks);

        let mut scales = Vec::with_capacity(num_blocks);
        let mut quantized_data = Vec::with_capacity(data.len());

        for block_idx in 0..num_blocks {
            let start = block_idx * block_size;
            // Clamp so the last (possibly partial) block stays in bounds.
            let end = (start + block_size).min(data.len());
            let block_data = &data[start..end];

            let max_abs = block_data
                .iter()
                .map(|x| x.abs())
                .max_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"))
                .unwrap_or(1e-8);

            let scale = self.config.format.max_value() / (max_abs * 1.2);
            scales.push(scale);

            let block_quantized = self.quantize_data(block_data, scale)?;
            quantized_data.extend(block_quantized);
        }

        Ok(FP8Tensor {
            data: quantized_data,
            shape: shape.to_vec(),
            format: self.config.format,
            scales: ScaleFactors::BlockWise { scales, block_size },
            zero_points: None,
        })
    }

    /// Scale each value, optionally clamp it to the format range, and
    /// encode it as an FP8 byte.
    fn quantize_data(&mut self, data: &[f32], scale: f32) -> Result<Vec<u8>> {
        let max_value = self.config.format.max_value();
        let mut quantized = Vec::with_capacity(data.len());

        for &value in data {
            let scaled = value * scale;

            let clipped = if self.config.clip_to_max {
                scaled.clamp(-max_value, max_value)
            } else {
                scaled
            };

            let fp8_val = self.f32_to_fp8(clipped)?;
            quantized.push(fp8_val);
        }

        Ok(quantized)
    }

    /// Encode one (already-scaled) f32 as an FP8 byte:
    /// sign(1) | exponent(format) | mantissa(format).
    ///
    /// Rounding is to nearest, ties to even. Values whose rebased exponent
    /// underflows are flushed to signed zero (this includes subnormal f32
    /// inputs, whose raw exponent field is 0); values whose exponent
    /// saturates are clamped to the largest encoding this scheme emits
    /// (exponent max_exp-1 with a full mantissa).
    fn f32_to_fp8(&mut self, value: f32) -> Result<u8> {
        // Decompose the IEEE-754 single: sign, biased exponent, mantissa.
        let bits = value.to_bits();
        let sign = (bits >> 31) & 1;
        let exp_f32 = ((bits >> 23) & 0xFF) as i32;
        let mant_f32 = bits & 0x7F_FFFF;

        // Signed zero encodes as just the sign bit.
        // (0.0 == -0.0 in IEEE-754, so the second comparison is redundant.)
        if value == 0.0 || value == -0.0 {
            return Ok((sign as u8) << 7);
        }

        // NaN/inf: emit all-ones exponent with zero mantissa.
        if value.is_nan() || value.is_infinite() {
            let exp_bits = self.config.format.exponent_bits();
            let max_exp = (1 << exp_bits) - 1;
            return Ok(
                ((sign as u8) << 7) | ((max_exp as u8) << self.config.format.mantissa_bits())
            );
        }

        // Rebase the exponent from the f32 bias (127) to the FP8 bias.
        let exp_bias_f32 = 127;
        let exp_bias_fp8 = self.config.format.exponent_bias();
        let exp = exp_f32 - exp_bias_f32 + exp_bias_fp8;

        let max_exp = (1 << self.config.format.exponent_bits()) - 1;
        if exp <= 0 {
            // Underflow: flush to signed zero (no subnormal FP8 support).
            // NOTE(review): over/underflow counts always go to stats[0],
            // even for per-channel/token/block scaling — confirm intended.
            if let Some(stats) = &mut self.stats {
                stats[0].underflow_count += 1;
            }
            return Ok((sign as u8) << 7);
        }
        if exp >= max_exp {
            // Overflow: clamp to the largest encoding this scheme produces
            // (exponent max_exp - 1, full mantissa). The all-ones exponent
            // is reserved above for NaN/inf.
            if let Some(stats) = &mut self.stats {
                stats[0].overflow_count += 1;
            }
            let max_exp_fp8 = max_exp - 1;
            let max_mant = (1 << self.config.format.mantissa_bits()) - 1;
            return Ok(((sign as u8) << 7)
                | ((max_exp_fp8 as u8) << self.config.format.mantissa_bits())
                | (max_mant as u8));
        }

        // Truncate the 23-bit mantissa down to the FP8 mantissa width.
        let mant_bits = self.config.format.mantissa_bits();
        let mant_shift = 23 - mant_bits;
        let mut mant = (mant_f32 >> mant_shift) as u8;

        // Round to nearest, ties to even, using the discarded bits.
        // NOTE(review): when the mantissa is already all ones, rounding up
        // wraps it to 0 via the mask below without incrementing the
        // exponent, so such values round *down* — confirm acceptable.
        let remainder = mant_f32 & ((1 << mant_shift) - 1);
        if remainder > (1 << (mant_shift - 1))
            || (remainder == (1 << (mant_shift - 1)) && (mant & 1) == 1)
        {
            mant = mant.saturating_add(1);
        }

        let fp8 =
            ((sign as u8) << 7) | ((exp as u8) << mant_bits) | (mant & ((1 << mant_bits) - 1));

        Ok(fp8)
    }

    /// Decode one FP8 byte back to f32 (inverse of `f32_to_fp8`, before
    /// the scale is divided out by `dequantize`).
    fn fp8_to_f32(&self, fp8: u8) -> f32 {
        let mant_bits = self.config.format.mantissa_bits();
        let exp_bits = self.config.format.exponent_bits();

        // Unpack sign | exponent | mantissa.
        let sign = (fp8 >> 7) & 1;
        let exp = ((fp8 >> mant_bits) & ((1 << exp_bits) - 1)) as i32;
        let mant = (fp8 & ((1 << mant_bits) - 1)) as u32;

        if exp == 0 && mant == 0 {
            return if sign == 1 { -0.0 } else { 0.0 };
        }

        // Rebase the exponent back to the f32 bias.
        // NOTE(review): exp == 0 with mant != 0 (a subnormal encoding)
        // falls through and is decoded as if normal — the encoder never
        // produces such bytes, but externally-built data would decode
        // incorrectly; confirm acceptable.
        let exp_bias_fp8 = self.config.format.exponent_bias();
        let exp_bias_f32 = 127;
        let exp_f32 = exp - exp_bias_fp8 + exp_bias_f32;

        // All-ones exponent (NaN/inf marker from the encoder) decodes as
        // the signed format maximum.
        let max_exp = (1 << exp_bits) - 1;
        if exp == max_exp {
            return if sign == 1 {
                -self.config.format.max_value()
            } else {
                self.config.format.max_value()
            };
        }

        // Widen the mantissa and restore the implicit leading 1, then
        // reassemble the IEEE-754 single-precision bit pattern.
        let mant_shift = 23 - mant_bits;
        let mant_f32 = (mant << mant_shift) | (1 << 23);
        let bits = ((sign as u32) << 31) | ((exp_f32 as u32) << 23) | (mant_f32 & 0x7F_FFFF);
        f32::from_bits(bits)
    }

    /// Decode an [`FP8Tensor`] back to an f32 [`Tensor`], dividing each
    /// decoded value by its group's scale factor.
    pub fn dequantize(&self, fp8_tensor: &FP8Tensor) -> Result<Tensor> {
        let mut dequantized = Vec::with_capacity(fp8_tensor.data.len());

        match &fp8_tensor.scales {
            ScaleFactors::PerTensor(scale) => {
                for &fp8_val in &fp8_tensor.data {
                    let f32_val = self.fp8_to_f32(fp8_val) / scale;
                    dequantized.push(f32_val);
                }
            },
            ScaleFactors::PerChannel(scales) => {
                // Channel layout mirrors quantize_per_channel: contiguous
                // equal-sized channel slices.
                let num_channels = scales.len();
                let channel_size = fp8_tensor.data.len() / num_channels;

                for (ch, &scale) in scales.iter().enumerate() {
                    for i in 0..channel_size {
                        let idx = ch * channel_size + i;
                        let f32_val = self.fp8_to_f32(fp8_tensor.data[idx]) / scale;
                        dequantized.push(f32_val);
                    }
                }
            },
            ScaleFactors::PerToken(scales) => {
                let num_tokens = scales.len();
                let token_size = fp8_tensor.data.len() / num_tokens;

                for (tok, &scale) in scales.iter().enumerate() {
                    for i in 0..token_size {
                        let idx = tok * token_size + i;
                        let f32_val = self.fp8_to_f32(fp8_tensor.data[idx]) / scale;
                        dequantized.push(f32_val);
                    }
                }
            },
            ScaleFactors::BlockWise { scales, block_size } => {
                for (block_idx, &scale) in scales.iter().enumerate() {
                    let start = block_idx * block_size;
                    // The final block may be shorter than block_size.
                    let end = (start + block_size).min(fp8_tensor.data.len());

                    for idx in start..end {
                        let f32_val = self.fp8_to_f32(fp8_tensor.data[idx]) / scale;
                        dequantized.push(f32_val);
                    }
                }
            },
        }

        Tensor::from_vec(dequantized, &fp8_tensor.shape)
    }

    /// Per-group (overflow_count, underflow_count) pairs, or `None` when
    /// delayed scaling never initialized stats.
    pub fn get_stats(&self) -> Option<Vec<(usize, usize)>> {
        self.stats
            .as_ref()
            .map(|stats| stats.iter().map(|s| (s.overflow_count, s.underflow_count)).collect())
    }

    /// Zero the overflow/underflow counters (amax history is kept).
    pub fn reset_stats(&mut self) {
        if let Some(stats) = &mut self.stats {
            for stat in stats {
                stat.overflow_count = 0;
                stat.underflow_count = 0;
            }
        }
    }
}
697
698pub fn select_fp8_format(tensor: &Tensor, use_case: &str) -> FP8Format {
701 match use_case {
702 "forward" | "weights" | "activations" => FP8Format::E4M3,
703 "backward" | "gradients" => FP8Format::E5M2,
704 _ => {
705 let data = tensor.to_vec_f32().unwrap_or_default();
707 let max_abs = data
708 .iter()
709 .map(|x| x.abs())
710 .max_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"))
711 .unwrap_or(1.0);
712
713 if max_abs > 448.0 {
715 FP8Format::E5M2
716 } else {
717 FP8Format::E4M3
718 }
719 },
720 }
721}
722
723pub fn estimate_quantization_error(_original: &Tensor, _quantized: &FP8Tensor) -> Result<f32> {
725 Ok(0.0)
728}
729
#[cfg(test)]
mod tests {
    use super::*;

    // Bit-layout constants for both formats match the OCP FP8 definitions.
    #[test]
    fn test_fp8_format_properties() {
        let e4m3 = FP8Format::E4M3;
        assert_eq!(e4m3.exponent_bits(), 4);
        assert_eq!(e4m3.mantissa_bits(), 3);
        assert_eq!(e4m3.max_value(), 448.0);

        let e5m2 = FP8Format::E5M2;
        assert_eq!(e5m2.exponent_bits(), 5);
        assert_eq!(e5m2.mantissa_bits(), 2);
        assert_eq!(e5m2.max_value(), 57344.0);
    }

    // Per-tensor quantization preserves shape/element count and produces a
    // single scale factor.
    #[test]
    fn test_fp8_per_tensor_quantization() -> Result<()> {
        let config = FP8Config {
            format: FP8Format::E4M3,
            scaling: ScalingStrategy::PerTensor,
            ..Default::default()
        };

        let mut quantizer = FP8Quantizer::new(config)?;
        let tensor = Tensor::randn(&[4, 8])?;

        let fp8_tensor = quantizer.quantize(&tensor)?;

        assert_eq!(fp8_tensor.shape, vec![4, 8]);
        assert_eq!(fp8_tensor.data.len(), 32);
        assert_eq!(fp8_tensor.format, FP8Format::E4M3);

        match fp8_tensor.scales {
            ScaleFactors::PerTensor(_) => (),
            _ => panic!("Expected PerTensor scales"),
        }

        Ok(())
    }

    // Quantize -> dequantize keeps values within 10% relative error.
    // Note: delayed scaling is enabled by default, so the first call uses
    // the initial scale of 1.0.
    #[test]
    fn test_fp8_roundtrip() -> Result<()> {
        let config = FP8Config {
            format: FP8Format::E4M3,
            stochastic_rounding: false,
            ..Default::default()
        };

        let mut quantizer = FP8Quantizer::new(config)?;

        let data = vec![0.0, 1.0, -1.0, 100.0, -100.0, 0.5, -0.5];
        let tensor = Tensor::from_vec(data.clone(), &[7])?;

        let fp8_tensor = quantizer.quantize(&tensor)?;
        let dequantized = quantizer.dequantize(&fp8_tensor)?;

        let deq_data = dequantized.to_vec_f32()?;

        for (original, recovered) in data.iter().zip(deq_data.iter()) {
            // 1e-6 in the denominator avoids dividing by zero for the 0.0 entry.
            let rel_error = (original - recovered).abs() / (original.abs() + 1e-6);
            assert!(
                rel_error < 0.1,
                "Relative error too large: {} vs {}",
                original,
                recovered
            );
        }

        Ok(())
    }

    // Per-channel quantization emits one scale per first-dimension slice.
    #[test]
    fn test_fp8_per_channel_quantization() -> Result<()> {
        let config = FP8Config {
            format: FP8Format::E4M3,
            scaling: ScalingStrategy::PerChannel,
            ..Default::default()
        };

        let mut quantizer = FP8Quantizer::new(config)?;
        let tensor = Tensor::randn(&[4, 8])?;

        let fp8_tensor = quantizer.quantize(&tensor)?;

        match &fp8_tensor.scales {
            ScaleFactors::PerChannel(scales) => {
                // 4 channels (first dimension) -> 4 scales.
                assert_eq!(scales.len(), 4);
            },
            _ => panic!("Expected PerChannel scales"),
        }

        Ok(())
    }

    // The named use cases map to the documented formats.
    #[test]
    fn test_select_fp8_format() -> Result<()> {
        let tensor = Tensor::randn(&[10, 10])?;

        let format_forward = select_fp8_format(&tensor, "forward");
        assert_eq!(format_forward, FP8Format::E4M3);

        let format_backward = select_fp8_format(&tensor, "gradients");
        assert_eq!(format_backward, FP8Format::E5M2);

        Ok(())
    }

    // With delayed scaling enabled, repeated quantization allocates and
    // retains per-group stats.
    #[test]
    fn test_delayed_scaling() -> Result<()> {
        let config = FP8Config {
            format: FP8Format::E4M3,
            delayed_scaling: DelayedScalingConfig {
                enabled: true,
                interval: 2,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut quantizer = FP8Quantizer::new(config)?;

        for _ in 0..5 {
            let tensor = Tensor::randn(&[10, 10])?;
            let _fp8_tensor = quantizer.quantize(&tensor)?;
        }

        assert!(quantizer.stats.is_some());

        Ok(())
    }
}