1use crate::errors::{Result, TrustformersError};
7use crate::quantization::{ActivationQuantScheme, QuantizationScheme};
8use crate::tensor::Tensor;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// Top-level configuration for quantization-aware training (QAT).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QATConfig {
    /// Quantization scheme applied to weights.
    pub weight_scheme: QuantizationScheme,
    /// Quantization scheme applied to activations.
    pub activation_scheme: ActivationQuantScheme,
    /// Use symmetric quantization (zero point fixed at 0) when true.
    pub symmetric: bool,
    /// Number of full-precision epochs before quantization may engage.
    pub warmup_epochs: usize,
    /// When and how fake quantization is switched on over training.
    pub schedule: QATSchedule,
    /// Whether the first and last layers are quantized as well.
    /// NOTE(review): not read anywhere in this file — presumably consumed by
    /// model-construction code elsewhere; confirm before removing.
    pub quantize_first_last: bool,
    /// Settings for the min/max range observers.
    pub observer_config: ObserverConfig,
    /// Use the straight-through estimator for gradients through quantization.
    pub use_ste: bool,
}
32
/// Determines when fake quantization becomes active during training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QATSchedule {
    /// Quantize all layers as soon as warmup ends.
    Immediate,
    /// Ramp quantization in between two epochs.
    Gradual {
        /// Epoch at which quantization begins.
        start_epoch: usize,
        /// Epoch at which quantization is fully engaged.
        end_epoch: usize,
        /// Ramp curve used for weights.
        weight_schedule: GradualSchedule,
        /// Ramp curve used for activations.
        activation_schedule: GradualSchedule,
    },
    /// Enable quantization on a per-layer timetable, keyed by layer name.
    LayerWise {
        schedule: HashMap<String, LayerSchedule>,
    },
    /// Start at a wide bit width and shrink it at the listed epochs.
    Progressive {
        /// Initial bit width.
        start_bits: u8,
        /// Final (minimum) bit width.
        end_bits: u8,
        /// Epochs at which the bit width is reduced by one step.
        reduction_epochs: Vec<usize>,
    },
}
56
/// Shape of the ramp used by [`QATSchedule::Gradual`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GradualSchedule {
    /// Straight-line ramp from 0 to 1.
    Linear,
    /// Half-cosine ease-in/ease-out ramp.
    Cosine,
    /// Exponential curve whose steepness is controlled by `base`.
    Exponential { base: f64 },
    /// Discrete jumps at the listed epochs.
    Step { steps: Vec<usize> },
}
69
/// Per-layer entry for [`QATSchedule::LayerWise`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerSchedule {
    /// Epoch at which this layer starts quantizing.
    pub start_epoch: usize,
    /// Quantize the layer's weights.
    pub enable_weights: bool,
    /// Quantize the layer's activations.
    pub enable_activations: bool,
    /// Optional per-layer bit-width override.
    pub bits: Option<u8>,
}
78
/// Settings for the range observers that track tensor min/max statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObserverConfig {
    /// EMA momentum for range updates.
    /// NOTE(review): stored but not applied by `MovingAverageObserver::update`
    /// in this file — confirm intended semantics.
    pub momentum: f64,
    /// Clip the observed range to a percentile instead of the raw min/max.
    /// NOTE(review): not consulted in this file — presumably used elsewhere.
    pub use_percentile: bool,
    /// Percentile used when `use_percentile` is set (e.g. 0.999).
    pub percentile: f64,
    /// Minimum number of observed values before quantization params are valid.
    pub min_observations: usize,
    /// Freeze observer ranges once the warmup period has passed.
    pub freeze_after_warmup: bool,
}
93
/// A fake-quantization ("quantize-dequantize") layer used during QAT.
#[derive(Debug)]
pub struct FakeQuantLayer {
    /// Current quantization bit width.
    pub bits: u8,
    /// Whether fake quantization is currently active for this layer.
    pub enabled: bool,
    /// Quantization scheme assigned to this layer.
    pub scheme: QuantizationScheme,
    /// Range observer feeding scale/zero-point estimation.
    pub observer: MovingAverageObserver,
    /// Cached quantization scale (set on the first enabled forward pass).
    pub scale: Option<f32>,
    /// Cached zero point (set together with `scale`).
    pub zero_point: Option<i32>,
    /// QAT configuration shared with the trainer.
    pub config: QATConfig,
    /// Last epoch passed to `update_epoch`.
    pub current_epoch: usize,
}
113
/// Tracks the observed value range of a tensor stream.
///
/// NOTE(review): despite the name, `update` currently keeps a running min/max;
/// the `momentum` field is carried but not applied — confirm intended
/// semantics before relying on moving-average behavior.
#[derive(Debug, Clone)]
pub struct MovingAverageObserver {
    /// Smallest finite value seen so far.
    pub min_val: f32,
    /// Largest finite value seen so far.
    pub max_val: f32,
    /// EMA momentum copied from the config.
    pub momentum: f64,
    /// Count of finite values folded into the range.
    pub num_observations: usize,
    /// When true, further updates are ignored.
    pub frozen: bool,
    /// Observer configuration.
    pub config: ObserverConfig,
}
130
/// Orchestrates quantization-aware training across named layers.
#[derive(Debug)]
pub struct QATTrainer {
    /// Shared QAT configuration (cloned into each layer on registration).
    pub config: QATConfig,
    /// Fake-quantization layers keyed by layer name.
    pub fake_quant_layers: HashMap<String, FakeQuantLayer>,
    /// Most recent epoch passed to `update_epoch`.
    pub current_epoch: usize,
    /// Aggregated quantization statistics.
    pub stats: QATStats,
}
143
/// Aggregated statistics describing the current quantization state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QATStats {
    /// Fraction of registered layers currently quantized.
    pub quantization_ratio: f64,
    /// Number of layers with quantization enabled.
    pub quantized_layers: usize,
    /// Total number of registered layers.
    pub total_layers: usize,
    /// Mean bit width across all registered layers.
    pub average_bits: f64,
    /// Estimated fractional model-size reduction (vs. full precision).
    pub size_reduction: f64,
    /// Last training loss reported via `update_metrics`.
    pub training_loss: f64,
    /// Last validation accuracy reported via `update_metrics`.
    pub validation_accuracy: f64,
}
162
163impl Default for QATConfig {
164 fn default() -> Self {
165 Self {
166 weight_scheme: QuantizationScheme::Dynamic,
167 activation_scheme: ActivationQuantScheme::Int8,
168 symmetric: false,
169 warmup_epochs: 5,
170 schedule: QATSchedule::Gradual {
171 start_epoch: 5,
172 end_epoch: 20,
173 weight_schedule: GradualSchedule::Linear,
174 activation_schedule: GradualSchedule::Linear,
175 },
176 quantize_first_last: false,
177 observer_config: ObserverConfig::default(),
178 use_ste: true,
179 }
180 }
181}
182
183impl Default for ObserverConfig {
184 fn default() -> Self {
185 Self {
186 momentum: 0.01,
187 use_percentile: true,
188 percentile: 0.999,
189 min_observations: 100,
190 freeze_after_warmup: true,
191 }
192 }
193}
194
195impl MovingAverageObserver {
196 pub fn new(config: ObserverConfig) -> Self {
198 Self {
199 min_val: f32::INFINITY,
200 max_val: f32::NEG_INFINITY,
201 momentum: config.momentum,
202 num_observations: 0,
203 frozen: false,
204 config,
205 }
206 }
207
208 pub fn update(&mut self, tensor: &Tensor) -> Result<()> {
210 if self.frozen {
211 return Ok(());
212 }
213
214 match tensor {
215 Tensor::F32(arr) => {
216 for &val in arr.iter() {
217 if !val.is_finite() {
218 continue;
219 }
220
221 if self.num_observations == 0 {
222 self.min_val = val;
223 self.max_val = val;
224 } else {
225 if val < self.min_val {
227 self.min_val = val;
228 }
229 if val > self.max_val {
230 self.max_val = val;
231 }
232 }
233 self.num_observations += 1;
234 }
235 },
236 _ => {
237 return Err(TrustformersError::quantization_error(
238 "Unsupported tensor type for observer".into(),
239 ))
240 },
241 }
242
243 Ok(())
244 }
245
246 pub fn get_quantization_params(&self, bits: u8, symmetric: bool) -> Result<(f32, i32)> {
248 if self.num_observations < self.config.min_observations {
249 return Err(TrustformersError::quantization_error(
250 "Insufficient observations for quantization".into(),
251 ));
252 }
253
254 let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 };
255 let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 };
256
257 let (scale, zero_point) = if symmetric {
258 let abs_max = self.max_val.abs().max(self.min_val.abs());
259 if abs_max == 0.0 {
260 return Ok((1.0, 0));
261 }
262 let scale = abs_max / (q_max - q_min) as f32;
263 (scale, 0)
264 } else {
265 if self.max_val == self.min_val {
266 return Ok((1.0, q_min));
267 }
268 let scale = (self.max_val - self.min_val) / (q_max - q_min) as f32;
269 let zero_point = q_min - (self.min_val / scale).round() as i32;
270 let zero_point = zero_point.clamp(q_min, q_max);
271 (scale, zero_point)
272 };
273
274 Ok((scale, zero_point))
275 }
276
277 pub fn freeze(&mut self) {
279 self.frozen = true;
280 }
281
282 pub fn is_ready(&self) -> bool {
284 self.num_observations >= self.config.min_observations
285 }
286}
287
288impl FakeQuantLayer {
289 pub fn new(bits: u8, scheme: QuantizationScheme, config: QATConfig) -> Self {
291 Self {
292 bits,
293 enabled: false,
294 scheme,
295 observer: MovingAverageObserver::new(config.observer_config.clone()),
296 scale: None,
297 zero_point: None,
298 config,
299 current_epoch: 0,
300 }
301 }
302
303 pub fn update_epoch(&mut self, epoch: usize) {
305 self.current_epoch = epoch;
306
307 match &self.config.schedule {
309 QATSchedule::Immediate => {
310 if epoch >= self.config.warmup_epochs {
311 self.enabled = true;
312 }
313 },
314 QATSchedule::Gradual { start_epoch, .. } => {
315 if epoch >= *start_epoch {
316 self.enabled = true;
317 }
318 },
319 QATSchedule::LayerWise { .. } => {
320 self.enabled = epoch >= self.config.warmup_epochs;
322 },
323 QATSchedule::Progressive {
324 start_bits,
325 end_bits,
326 reduction_epochs,
327 } => {
328 self.enabled = epoch >= self.config.warmup_epochs;
329
330 for (i, &reduction_epoch) in reduction_epochs.iter().enumerate() {
332 if epoch >= reduction_epoch {
333 let bits_reduction = (start_bits - end_bits) / reduction_epochs.len() as u8;
334 self.bits = (*start_bits - (i as u8 + 1) * bits_reduction).max(*end_bits);
335 }
336 }
337 },
338 }
339
340 if self.config.observer_config.freeze_after_warmup && epoch > self.config.warmup_epochs {
342 self.observer.freeze();
343 }
344 }
345
346 pub fn forward(&mut self, tensor: &Tensor, training: bool) -> Result<Tensor> {
348 if training {
349 self.observer.update(tensor)?;
351 }
352
353 if !self.enabled || !self.observer.is_ready() {
354 return Ok(tensor.clone());
355 }
356
357 if self.scale.is_none() || self.zero_point.is_none() {
359 let (scale, zero_point) =
360 self.observer.get_quantization_params(self.bits, self.config.symmetric)?;
361 self.scale = Some(scale);
362 self.zero_point = Some(zero_point);
363 }
364
365 let scale = self.scale.expect("scale should be set after observer initialization");
367 let zero_point =
368 self.zero_point.expect("zero_point should be set after observer initialization");
369
370 self.fake_quantize(tensor, scale, zero_point)
372 }
373
374 fn fake_quantize(&self, tensor: &Tensor, scale: f32, zero_point: i32) -> Result<Tensor> {
376 match tensor {
377 Tensor::F32(arr) => {
378 let q_min = if self.config.symmetric { -(1 << (self.bits - 1)) } else { 0 };
379 let q_max = if self.config.symmetric {
380 (1 << (self.bits - 1)) - 1
381 } else {
382 (1 << self.bits) - 1
383 };
384
385 let fake_quantized_data: Vec<f32> = arr
386 .iter()
387 .map(|&val| {
388 if self.config.use_ste {
389 let q_val =
391 ((val / scale).round() as i32 + zero_point).clamp(q_min, q_max);
392 (q_val - zero_point) as f32 * scale
393 } else {
394 let q_val =
396 ((val / scale).round() as i32 + zero_point).clamp(q_min, q_max);
397 (q_val - zero_point) as f32 * scale
398 }
399 })
400 .collect();
401
402 Tensor::from_vec(fake_quantized_data, arr.shape())
403 },
404 _ => Err(TrustformersError::quantization_error(
405 "Unsupported tensor type for fake quantization".into(),
406 )),
407 }
408 }
409
410 pub fn get_params(&self) -> Option<(f32, i32)> {
412 if let (Some(scale), Some(zero_point)) = (self.scale, self.zero_point) {
413 Some((scale, zero_point))
414 } else {
415 None
416 }
417 }
418}
419
420impl QATTrainer {
421 pub fn new(config: QATConfig) -> Self {
423 Self {
424 config,
425 fake_quant_layers: HashMap::new(),
426 current_epoch: 0,
427 stats: QATStats::default(),
428 }
429 }
430
431 pub fn add_layer(&mut self, name: String, bits: u8, scheme: QuantizationScheme) {
433 let layer = FakeQuantLayer::new(bits, scheme, self.config.clone());
434 self.fake_quant_layers.insert(name, layer);
435 self.update_stats();
436 }
437
438 pub fn update_epoch(&mut self, epoch: usize) {
440 self.current_epoch = epoch;
441
442 for layer in self.fake_quant_layers.values_mut() {
443 layer.update_epoch(epoch);
444 }
445
446 self.update_stats();
447 }
448
449 pub fn quantize_layer(
451 &mut self,
452 layer_name: &str,
453 tensor: &Tensor,
454 training: bool,
455 ) -> Result<Tensor> {
456 if let Some(layer) = self.fake_quant_layers.get_mut(layer_name) {
457 layer.forward(tensor, training)
458 } else {
459 Ok(tensor.clone())
460 }
461 }
462
463 pub fn get_schedule_value(
465 &self,
466 schedule: &GradualSchedule,
467 start_epoch: usize,
468 end_epoch: usize,
469 ) -> f64 {
470 if self.current_epoch < start_epoch {
471 return 0.0;
472 }
473 if self.current_epoch >= end_epoch {
474 return 1.0;
475 }
476
477 let progress = (self.current_epoch - start_epoch) as f64 / (end_epoch - start_epoch) as f64;
478
479 match schedule {
480 GradualSchedule::Linear => progress,
481 GradualSchedule::Cosine => 0.5 * (1.0 - (std::f64::consts::PI * progress).cos()),
482 GradualSchedule::Exponential { base } => 1.0 - base.powf(progress),
483 GradualSchedule::Step { steps } => {
484 let current_step =
485 steps.iter().position(|&step| self.current_epoch < step).unwrap_or(steps.len());
486 current_step as f64 / steps.len() as f64
487 },
488 }
489 }
490
491 fn update_stats(&mut self) {
493 let total_layers = self.fake_quant_layers.len();
494 let quantized_layers =
495 self.fake_quant_layers.values().filter(|layer| layer.enabled).count();
496
497 let average_bits = if total_layers > 0 {
498 self.fake_quant_layers.values().map(|layer| layer.bits as f64).sum::<f64>()
499 / total_layers as f64
500 } else {
501 0.0
502 };
503
504 let quantization_ratio = if total_layers > 0 {
505 quantized_layers as f64 / total_layers as f64
506 } else {
507 0.0
508 };
509
510 let size_reduction = match average_bits as u8 {
512 8 => 0.75, 16 => 0.5, 4 => 0.875, _ => 0.0,
516 } * quantization_ratio;
517
518 self.stats = QATStats {
519 quantization_ratio,
520 quantized_layers,
521 total_layers,
522 average_bits,
523 size_reduction,
524 training_loss: self.stats.training_loss, validation_accuracy: self.stats.validation_accuracy,
526 };
527 }
528
529 pub fn update_metrics(&mut self, training_loss: f64, validation_accuracy: f64) {
531 self.stats.training_loss = training_loss;
532 self.stats.validation_accuracy = validation_accuracy;
533 }
534
535 pub fn get_stats(&self) -> &QATStats {
537 &self.stats
538 }
539
540 pub fn is_ready(&self) -> bool {
542 self.fake_quant_layers.values().all(|layer| layer.observer.is_ready())
543 }
544
545 pub fn export_quantized_config(&self) -> HashMap<String, (f32, i32, u8)> {
547 self.fake_quant_layers
548 .iter()
549 .filter_map(|(name, layer)| {
550 if let Some((scale, zero_point)) = layer.get_params() {
551 Some((name.clone(), (scale, zero_point, layer.bits)))
552 } else {
553 None
554 }
555 })
556 .collect()
557 }
558
559 pub fn save_state(&self, path: &str) -> Result<()> {
561 let state = QATState {
562 config: self.config.clone(),
563 current_epoch: self.current_epoch,
564 stats: self.stats.clone(),
565 layer_configs: self.export_quantized_config(),
566 };
567
568 let json_data = serde_json::to_string_pretty(&state).map_err(|e| {
569 TrustformersError::quantization_error(format!("Failed to serialize QAT state: {}", e))
570 })?;
571
572 std::fs::write(path, json_data).map_err(|e| {
573 TrustformersError::quantization_error(format!("Failed to write file: {}", e))
574 })?;
575
576 Ok(())
577 }
578
579 pub fn load_state(&mut self, path: &str) -> Result<()> {
581 let json_data = std::fs::read_to_string(path).map_err(|e| {
582 TrustformersError::quantization_error(format!("Failed to read file: {}", e))
583 })?;
584
585 let state: QATState = serde_json::from_str(&json_data).map_err(|e| {
586 TrustformersError::quantization_error(format!("Failed to deserialize QAT state: {}", e))
587 })?;
588
589 self.config = state.config;
590 self.current_epoch = state.current_epoch;
591 self.stats = state.stats;
592
593 for (name, (scale, zero_point, bits)) in state.layer_configs {
595 if let Some(layer) = self.fake_quant_layers.get_mut(&name) {
596 layer.scale = Some(scale);
597 layer.zero_point = Some(zero_point);
598 layer.bits = bits;
599 }
600 }
601
602 Ok(())
603 }
604}
605
606impl Default for QATStats {
607 fn default() -> Self {
608 Self {
609 quantization_ratio: 0.0,
610 quantized_layers: 0,
611 total_layers: 0,
612 average_bits: 32.0,
613 size_reduction: 0.0,
614 training_loss: 0.0,
615 validation_accuracy: 0.0,
616 }
617 }
618}
619
/// Serializable snapshot of a `QATTrainer`, written by `save_state` and read
/// back by `load_state`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QATState {
    /// Trainer configuration at snapshot time.
    pub config: QATConfig,
    /// Epoch at which the snapshot was taken.
    pub current_epoch: usize,
    /// Statistics at snapshot time.
    pub stats: QATStats,
    /// Per-layer (scale, zero_point, bits) for layers with computed params.
    pub layer_configs: HashMap<String, (f32, i32, u8)>,
}
628
629pub struct QATUtils;
631
632impl QATUtils {
633 pub fn create_progressive_schedule(
635 warmup_epochs: usize,
636 total_epochs: usize,
637 start_bits: u8,
638 end_bits: u8,
639 ) -> QATSchedule {
640 let reduction_steps = (start_bits - end_bits) as usize;
641 let epochs_per_step = (total_epochs - warmup_epochs) / reduction_steps.max(1);
642
643 let reduction_epochs: Vec<usize> = (1..=reduction_steps)
644 .map(|step| warmup_epochs + step * epochs_per_step)
645 .collect();
646
647 QATSchedule::Progressive {
648 start_bits,
649 end_bits,
650 reduction_epochs,
651 }
652 }
653
654 pub fn create_layerwise_schedule(
656 layer_names: &[String],
657 start_epoch: usize,
658 epochs_between_layers: usize,
659 ) -> QATSchedule {
660 let mut schedule = HashMap::new();
661
662 for (i, name) in layer_names.iter().enumerate() {
663 let layer_start_epoch = start_epoch + i * epochs_between_layers;
664 schedule.insert(
665 name.clone(),
666 LayerSchedule {
667 start_epoch: layer_start_epoch,
668 enable_weights: true,
669 enable_activations: true,
670 bits: Some(8),
671 },
672 );
673 }
674
675 QATSchedule::LayerWise { schedule }
676 }
677
678 pub fn estimate_size_reduction(
680 original_bits: u8,
681 quantized_bits: u8,
682 quantization_ratio: f64,
683 ) -> f64 {
684 let bit_reduction = 1.0 - (quantized_bits as f64 / original_bits as f64);
685 bit_reduction * quantization_ratio
686 }
687
688 pub fn calculate_quantization_noise(original: &Tensor, quantized: &Tensor) -> Result<f64> {
690 match (original, quantized) {
691 (Tensor::F32(orig_arr), Tensor::F32(quant_arr)) => {
692 if orig_arr.len() != quant_arr.len() {
693 return Err(TrustformersError::quantization_error(
694 "Tensor sizes don't match".into(),
695 ));
696 }
697
698 let mse: f64 = orig_arr
699 .iter()
700 .zip(quant_arr.iter())
701 .map(|(&orig, &quant)| (orig - quant).powi(2) as f64)
702 .sum::<f64>()
703 / orig_arr.len() as f64;
704
705 Ok(mse.sqrt()) },
707 _ => Err(TrustformersError::quantization_error(
708 "Unsupported tensor types for noise calculation".into(),
709 )),
710 }
711 }
712}
713
#[cfg(test)]
mod tests {
    use super::*;

    // Default config sanity check.
    #[test]
    fn test_qat_config_default() {
        let config = QATConfig::default();
        assert_eq!(config.warmup_epochs, 5);
        assert!(!config.quantize_first_last);
        assert!(config.use_ste);
    }

    // The observer should fold every finite element into its min/max range.
    #[test]
    fn test_moving_average_observer() {
        let config = ObserverConfig::default();
        let mut observer = MovingAverageObserver::new(config);

        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).expect("Tensor from_vec failed");
        observer.update(&tensor).expect("tensor operation failed");

        assert_eq!(observer.num_observations, 4);
        assert!(observer.min_val <= 1.0);
        assert!(observer.max_val >= 4.0);
    }

    // A layer should enable after the gradual-schedule start epoch and become
    // ready once its observer has seen enough values (100 passes x 4 elems
    // exceeds the default min_observations of 100).
    #[test]
    fn test_fake_quant_layer() {
        let mut config = QATConfig::default();
        // Keep the observer live so forward() keeps collecting statistics.
        config.observer_config.freeze_after_warmup = false;
        let mut layer = FakeQuantLayer::new(8, QuantizationScheme::DynamicINT8, config);

        assert!(!layer.enabled);

        layer.update_epoch(10);

        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).expect("Tensor from_vec failed");

        for _ in 0..100 {
            layer.forward(&tensor, true).expect("Forward pass failed");
        }

        assert!(layer.enabled);
        assert!(layer.observer.is_ready());
    }

    // Stats should count registered layers and reflect schedule-driven
    // enablement after an epoch update.
    #[test]
    fn test_qat_trainer() {
        let config = QATConfig::default();
        let mut trainer = QATTrainer::new(config);

        trainer.add_layer("conv1".to_string(), 8, QuantizationScheme::DynamicINT8);
        trainer.add_layer("conv2".to_string(), 8, QuantizationScheme::DynamicINT8);

        let stats = trainer.get_stats();
        assert_eq!(stats.total_layers, 2);
        assert_eq!(stats.quantized_layers, 0);
        // Default gradual schedule starts at epoch 5, so epoch 10 enables all.
        trainer.update_epoch(10);
        let stats = trainer.get_stats();
        assert_eq!(stats.quantized_layers, 2);
    }

    // Schedule values must stay within [0, 1].
    #[test]
    fn test_gradual_schedule() {
        let config = QATConfig::default();
        let trainer = QATTrainer::new(config);

        let schedule = GradualSchedule::Linear;
        let value = trainer.get_schedule_value(&schedule, 5, 15);
        assert!((0.0..=1.0).contains(&value));
    }

    // The progressive-schedule builder should carry the requested bit range.
    #[test]
    fn test_qat_utils_progressive_schedule() {
        let schedule = QATUtils::create_progressive_schedule(5, 25, 16, 8);

        match schedule {
            QATSchedule::Progressive {
                start_bits,
                end_bits,
                reduction_epochs,
            } => {
                assert_eq!(start_bits, 16);
                assert_eq!(end_bits, 8);
                assert!(!reduction_epochs.is_empty());
            },
            _ => panic!("Expected progressive schedule"),
        }
    }

    // Quantizing the full model from 32 to 8 bits removes 3/4 of its size.
    #[test]
    fn test_size_reduction_estimation() {
        let reduction = QATUtils::estimate_size_reduction(32, 8, 1.0);
        assert_eq!(reduction, 0.75);
    }

    // RMSE of a small (0.1) perturbation should be positive but below 1.
    #[test]
    fn test_quantization_noise_calculation() {
        let original =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).expect("Tensor from_vec failed");
        let quantized =
            Tensor::from_vec(vec![1.1, 1.9, 3.1, 3.9], &[4]).expect("Tensor from_vec failed");

        let noise = QATUtils::calculate_quantization_noise(&original, &quantized)
            .expect("operation failed in test");
        assert!(noise > 0.0);
        assert!(noise < 1.0);
    }

    // The layer-wise builder should emit one entry per provided layer name.
    #[test]
    fn test_layer_wise_schedule() {
        let layer_names = vec!["conv1".to_string(), "conv2".to_string(), "fc1".to_string()];
        let schedule = QATUtils::create_layerwise_schedule(&layer_names, 5, 2);

        match schedule {
            QATSchedule::LayerWise { schedule } => {
                assert_eq!(schedule.len(), 3);
                assert!(schedule.contains_key("conv1"));
                assert!(schedule.contains_key("conv2"));
                assert!(schedule.contains_key("fc1"));
            },
            _ => panic!("Expected layer-wise schedule"),
        }
    }
}