// Source file: trustformers_core/quantization/qat.rs
1//! Quantization-Aware Training (QAT) Infrastructure for TrustformeRS
2//!
3//! This module provides comprehensive quantization-aware training capabilities,
4//! including fake quantization layers, QAT schedulers, and training utilities.
5
6use crate::errors::{Result, TrustformersError};
7use crate::quantization::{ActivationQuantScheme, QuantizationScheme};
8use crate::tensor::Tensor;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// QAT training configuration
///
/// Bundles every knob that controls quantization-aware training: which
/// schemes to use for weights/activations, when quantization is switched
/// on (warmup + schedule), and how calibration statistics are collected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QATConfig {
    /// Weight quantization scheme
    pub weight_scheme: QuantizationScheme,
    /// Activation quantization scheme
    pub activation_scheme: ActivationQuantScheme,
    /// Whether to use symmetric quantization (zero point fixed at 0)
    pub symmetric: bool,
    /// Number of warmup epochs before enabling quantization
    pub warmup_epochs: usize,
    /// QAT schedule for gradual quantization introduction
    pub schedule: QATSchedule,
    /// Whether to quantize first and last layers
    /// NOTE(review): not consulted anywhere in this module — presumably
    /// honored by model-construction code; confirm at call sites.
    pub quantize_first_last: bool,
    /// Observer configuration
    pub observer_config: ObserverConfig,
    /// Whether to use straight-through estimator
    pub use_ste: bool,
}
32
/// QAT training schedule
///
/// Controls when (and for which layers / at which bit widths) fake
/// quantization is switched on during training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QATSchedule {
    /// Immediate quantization from start (after the configured warmup)
    Immediate,
    /// Gradual introduction over epochs
    Gradual {
        /// First epoch at which quantization is enabled
        start_epoch: usize,
        /// Epoch by which quantization is fully phased in
        end_epoch: usize,
        /// Phase-in curve for weight quantization
        weight_schedule: GradualSchedule,
        /// Phase-in curve for activation quantization
        activation_schedule: GradualSchedule,
    },
    /// Custom layer-by-layer schedule
    LayerWise {
        /// Per-layer schedule keyed by layer name
        schedule: HashMap<String, LayerSchedule>,
    },
    /// Progressive bit reduction
    Progressive {
        /// Bit width at the start of training
        start_bits: u8,
        /// Final (smallest) bit width
        end_bits: u8,
        /// Epochs at which the bit width is reduced one step further
        reduction_epochs: Vec<usize>,
    },
}
56
/// Gradual quantization schedule
///
/// Shape of the 0.0 → 1.0 phase-in curve used by `QATSchedule::Gradual`;
/// evaluated by `QATTrainer::get_schedule_value`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GradualSchedule {
    /// Linear introduction
    Linear,
    /// Cosine schedule (slow start/finish, fast middle)
    Cosine,
    /// Exponential schedule; progress maps to `1 - base^progress`
    Exponential { base: f64 },
    /// Step-wise introduction at the given epoch thresholds
    Step { steps: Vec<usize> },
}
69
/// Layer-specific QAT schedule
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerSchedule {
    /// Epoch at which quantization starts for this layer
    pub start_epoch: usize,
    /// Whether weight quantization is enabled for this layer
    pub enable_weights: bool,
    /// Whether activation quantization is enabled for this layer
    pub enable_activations: bool,
    /// Optional per-layer bit-width override (`None` = use the default)
    pub bits: Option<u8>,
}
78
/// Observer configuration for calibration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObserverConfig {
    /// Moving average momentum for statistics
    /// NOTE(review): copied onto the observer but not applied by its
    /// `update()` in this module — confirm intended use.
    pub momentum: f64,
    /// Whether to use percentile clipping
    /// NOTE(review): not consulted by the observer code in this module.
    pub use_percentile: bool,
    /// Percentile value for clipping (e.g., 0.999)
    pub percentile: f64,
    /// Minimum number of observations before quantization parameters may
    /// be produced (see `MovingAverageObserver::is_ready`)
    pub min_observations: usize,
    /// Whether to freeze observer after warmup
    pub freeze_after_warmup: bool,
}
93
/// Fake quantization layer for QAT
///
/// Simulates quantization in the forward pass (quantize → dequantize) so
/// the network learns to tolerate quantization noise while the underlying
/// values stay in floating point.
#[derive(Debug)]
pub struct FakeQuantLayer {
    /// Current bit width (may shrink under a `Progressive` schedule)
    pub bits: u8,
    /// Whether quantization is enabled
    pub enabled: bool,
    /// Quantization scheme
    pub scheme: QuantizationScheme,
    /// Observer for collecting statistics
    pub observer: MovingAverageObserver,
    /// Cached quantization scale (computed lazily in `forward`)
    pub scale: Option<f32>,
    /// Cached quantization zero point (computed lazily in `forward`)
    pub zero_point: Option<i32>,
    /// Configuration
    pub config: QATConfig,
    /// Current epoch for schedule tracking
    pub current_epoch: usize,
}
113
/// Moving average observer for QAT
///
/// Collects the range of values flowing through a layer so quantization
/// parameters can be derived from it.
#[derive(Debug, Clone)]
pub struct MovingAverageObserver {
    /// Running minimum over all finite values observed
    pub min_val: f32,
    /// Running maximum over all finite values observed
    pub max_val: f32,
    /// Moving average momentum (copied from config;
    /// NOTE(review): not applied by `update()` in this module)
    pub momentum: f64,
    /// Number of observations (counted per element, not per tensor)
    pub num_observations: usize,
    /// Whether observer is frozen (frozen observers ignore updates)
    pub frozen: bool,
    /// Configuration
    pub config: ObserverConfig,
}
130
/// QAT trainer for managing the training process
///
/// Owns a set of named `FakeQuantLayer`s, drives their per-epoch
/// schedules, and aggregates training statistics.
#[derive(Debug)]
pub struct QATTrainer {
    /// QAT configuration (cloned into each added layer)
    pub config: QATConfig,
    /// Fake quantization layers keyed by layer name
    pub fake_quant_layers: HashMap<String, FakeQuantLayer>,
    /// Current training epoch
    pub current_epoch: usize,
    /// Training statistics
    pub stats: QATStats,
}
143
/// QAT training statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QATStats {
    /// Fraction of registered layers currently quantized (0.0 to 1.0)
    pub quantization_ratio: f64,
    /// Number of quantized (enabled) layers
    pub quantized_layers: usize,
    /// Total number of registered layers
    pub total_layers: usize,
    /// Average bit width across layers
    pub average_bits: f64,
    /// Estimated model size reduction ratio (0.0 = none)
    pub size_reduction: f64,
    /// Current training loss (set via `QATTrainer::update_metrics`)
    pub training_loss: f64,
    /// Validation accuracy (set via `QATTrainer::update_metrics`)
    pub validation_accuracy: f64,
}
162
163impl Default for QATConfig {
164    fn default() -> Self {
165        Self {
166            weight_scheme: QuantizationScheme::Dynamic,
167            activation_scheme: ActivationQuantScheme::Int8,
168            symmetric: false,
169            warmup_epochs: 5,
170            schedule: QATSchedule::Gradual {
171                start_epoch: 5,
172                end_epoch: 20,
173                weight_schedule: GradualSchedule::Linear,
174                activation_schedule: GradualSchedule::Linear,
175            },
176            quantize_first_last: false,
177            observer_config: ObserverConfig::default(),
178            use_ste: true,
179        }
180    }
181}
182
183impl Default for ObserverConfig {
184    fn default() -> Self {
185        Self {
186            momentum: 0.01,
187            use_percentile: true,
188            percentile: 0.999,
189            min_observations: 100,
190            freeze_after_warmup: true,
191        }
192    }
193}
194
195impl MovingAverageObserver {
196    /// Create new observer
197    pub fn new(config: ObserverConfig) -> Self {
198        Self {
199            min_val: f32::INFINITY,
200            max_val: f32::NEG_INFINITY,
201            momentum: config.momentum,
202            num_observations: 0,
203            frozen: false,
204            config,
205        }
206    }
207
208    /// Update observer with new tensor
209    pub fn update(&mut self, tensor: &Tensor) -> Result<()> {
210        if self.frozen {
211            return Ok(());
212        }
213
214        match tensor {
215            Tensor::F32(arr) => {
216                for &val in arr.iter() {
217                    if !val.is_finite() {
218                        continue;
219                    }
220
221                    if self.num_observations == 0 {
222                        self.min_val = val;
223                        self.max_val = val;
224                    } else {
225                        // Track actual min/max values
226                        if val < self.min_val {
227                            self.min_val = val;
228                        }
229                        if val > self.max_val {
230                            self.max_val = val;
231                        }
232                    }
233                    self.num_observations += 1;
234                }
235            },
236            _ => {
237                return Err(TrustformersError::quantization_error(
238                    "Unsupported tensor type for observer".into(),
239                ))
240            },
241        }
242
243        Ok(())
244    }
245
246    /// Get quantization parameters
247    pub fn get_quantization_params(&self, bits: u8, symmetric: bool) -> Result<(f32, i32)> {
248        if self.num_observations < self.config.min_observations {
249            return Err(TrustformersError::quantization_error(
250                "Insufficient observations for quantization".into(),
251            ));
252        }
253
254        let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 };
255        let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 };
256
257        let (scale, zero_point) = if symmetric {
258            let abs_max = self.max_val.abs().max(self.min_val.abs());
259            if abs_max == 0.0 {
260                return Ok((1.0, 0));
261            }
262            let scale = abs_max / (q_max - q_min) as f32;
263            (scale, 0)
264        } else {
265            if self.max_val == self.min_val {
266                return Ok((1.0, q_min));
267            }
268            let scale = (self.max_val - self.min_val) / (q_max - q_min) as f32;
269            let zero_point = q_min - (self.min_val / scale).round() as i32;
270            let zero_point = zero_point.clamp(q_min, q_max);
271            (scale, zero_point)
272        };
273
274        Ok((scale, zero_point))
275    }
276
277    /// Freeze observer
278    pub fn freeze(&mut self) {
279        self.frozen = true;
280    }
281
282    /// Check if observer is ready
283    pub fn is_ready(&self) -> bool {
284        self.num_observations >= self.config.min_observations
285    }
286}
287
288impl FakeQuantLayer {
289    /// Create new fake quantization layer
290    pub fn new(bits: u8, scheme: QuantizationScheme, config: QATConfig) -> Self {
291        Self {
292            bits,
293            enabled: false,
294            scheme,
295            observer: MovingAverageObserver::new(config.observer_config.clone()),
296            scale: None,
297            zero_point: None,
298            config,
299            current_epoch: 0,
300        }
301    }
302
303    /// Update layer for current epoch
304    pub fn update_epoch(&mut self, epoch: usize) {
305        self.current_epoch = epoch;
306
307        // Update quantization state based on schedule
308        match &self.config.schedule {
309            QATSchedule::Immediate => {
310                if epoch >= self.config.warmup_epochs {
311                    self.enabled = true;
312                }
313            },
314            QATSchedule::Gradual { start_epoch, .. } => {
315                if epoch >= *start_epoch {
316                    self.enabled = true;
317                }
318            },
319            QATSchedule::LayerWise { .. } => {
320                // Layer-specific logic would be implemented here
321                self.enabled = epoch >= self.config.warmup_epochs;
322            },
323            QATSchedule::Progressive {
324                start_bits,
325                end_bits,
326                reduction_epochs,
327            } => {
328                self.enabled = epoch >= self.config.warmup_epochs;
329
330                // Progressive bit reduction
331                for (i, &reduction_epoch) in reduction_epochs.iter().enumerate() {
332                    if epoch >= reduction_epoch {
333                        let bits_reduction = (start_bits - end_bits) / reduction_epochs.len() as u8;
334                        self.bits = (*start_bits - (i as u8 + 1) * bits_reduction).max(*end_bits);
335                    }
336                }
337            },
338        }
339
340        // Freeze observer after warmup if configured
341        if self.config.observer_config.freeze_after_warmup && epoch > self.config.warmup_epochs {
342            self.observer.freeze();
343        }
344    }
345
346    /// Apply fake quantization to tensor
347    pub fn forward(&mut self, tensor: &Tensor, training: bool) -> Result<Tensor> {
348        if training {
349            // Update observer during training
350            self.observer.update(tensor)?;
351        }
352
353        if !self.enabled || !self.observer.is_ready() {
354            return Ok(tensor.clone());
355        }
356
357        // Get quantization parameters
358        if self.scale.is_none() || self.zero_point.is_none() {
359            let (scale, zero_point) =
360                self.observer.get_quantization_params(self.bits, self.config.symmetric)?;
361            self.scale = Some(scale);
362            self.zero_point = Some(zero_point);
363        }
364
365        // Safe: scale and zero_point are set in the block above if None
366        let scale = self.scale.expect("scale should be set after observer initialization");
367        let zero_point =
368            self.zero_point.expect("zero_point should be set after observer initialization");
369
370        // Apply fake quantization with straight-through estimator
371        self.fake_quantize(tensor, scale, zero_point)
372    }
373
374    /// Fake quantization with straight-through estimator
375    fn fake_quantize(&self, tensor: &Tensor, scale: f32, zero_point: i32) -> Result<Tensor> {
376        match tensor {
377            Tensor::F32(arr) => {
378                let q_min = if self.config.symmetric { -(1 << (self.bits - 1)) } else { 0 };
379                let q_max = if self.config.symmetric {
380                    (1 << (self.bits - 1)) - 1
381                } else {
382                    (1 << self.bits) - 1
383                };
384
385                let fake_quantized_data: Vec<f32> = arr
386                    .iter()
387                    .map(|&val| {
388                        if self.config.use_ste {
389                            // Straight-through estimator: forward pass quantized, backward pass identity
390                            let q_val =
391                                ((val / scale).round() as i32 + zero_point).clamp(q_min, q_max);
392                            (q_val - zero_point) as f32 * scale
393                        } else {
394                            // Standard fake quantization
395                            let q_val =
396                                ((val / scale).round() as i32 + zero_point).clamp(q_min, q_max);
397                            (q_val - zero_point) as f32 * scale
398                        }
399                    })
400                    .collect();
401
402                Tensor::from_vec(fake_quantized_data, arr.shape())
403            },
404            _ => Err(TrustformersError::quantization_error(
405                "Unsupported tensor type for fake quantization".into(),
406            )),
407        }
408    }
409
410    /// Get current quantization parameters
411    pub fn get_params(&self) -> Option<(f32, i32)> {
412        if let (Some(scale), Some(zero_point)) = (self.scale, self.zero_point) {
413            Some((scale, zero_point))
414        } else {
415            None
416        }
417    }
418}
419
impl QATTrainer {
    /// Create new QAT trainer with no registered layers.
    pub fn new(config: QATConfig) -> Self {
        Self {
            config,
            fake_quant_layers: HashMap::new(),
            current_epoch: 0,
            stats: QATStats::default(),
        }
    }

    /// Register a fake quantization layer under `name`; the layer inherits
    /// the trainer's configuration. Statistics are refreshed immediately.
    pub fn add_layer(&mut self, name: String, bits: u8, scheme: QuantizationScheme) {
        let layer = FakeQuantLayer::new(bits, scheme, self.config.clone());
        self.fake_quant_layers.insert(name, layer);
        self.update_stats();
    }

    /// Advance the trainer and every registered layer to `epoch`, then
    /// refresh statistics.
    pub fn update_epoch(&mut self, epoch: usize) {
        self.current_epoch = epoch;

        for layer in self.fake_quant_layers.values_mut() {
            layer.update_epoch(epoch);
        }

        self.update_stats();
    }

    /// Apply fake quantization to `tensor` using the named layer.
    ///
    /// Unknown layer names pass the tensor through unchanged (cloned).
    pub fn quantize_layer(
        &mut self,
        layer_name: &str,
        tensor: &Tensor,
        training: bool,
    ) -> Result<Tensor> {
        if let Some(layer) = self.fake_quant_layers.get_mut(layer_name) {
            layer.forward(tensor, training)
        } else {
            Ok(tensor.clone())
        }
    }

    /// Evaluate `schedule` at the trainer's current epoch, returning a
    /// phase-in factor in [0.0, 1.0] (0 before `start_epoch`, 1 at or
    /// after `end_epoch`).
    ///
    /// NOTE(review): for `Exponential { base }` the curve reaches
    /// `1.0 - base` (not 1.0) just before `end_epoch`, so the value jumps
    /// to 1.0 at the boundary — confirm this discontinuity is intended.
    pub fn get_schedule_value(
        &self,
        schedule: &GradualSchedule,
        start_epoch: usize,
        end_epoch: usize,
    ) -> f64 {
        if self.current_epoch < start_epoch {
            return 0.0;
        }
        if self.current_epoch >= end_epoch {
            return 1.0;
        }

        let progress = (self.current_epoch - start_epoch) as f64 / (end_epoch - start_epoch) as f64;

        match schedule {
            GradualSchedule::Linear => progress,
            GradualSchedule::Cosine => 0.5 * (1.0 - (std::f64::consts::PI * progress).cos()),
            GradualSchedule::Exponential { base } => 1.0 - base.powf(progress),
            GradualSchedule::Step { steps } => {
                // Fraction of step thresholds already passed.
                let current_step =
                    steps.iter().position(|&step| self.current_epoch < step).unwrap_or(steps.len());
                current_step as f64 / steps.len() as f64
            },
        }
    }

    /// Recompute aggregate statistics from the registered layers,
    /// preserving the externally supplied training metrics.
    fn update_stats(&mut self) {
        let total_layers = self.fake_quant_layers.len();
        let quantized_layers =
            self.fake_quant_layers.values().filter(|layer| layer.enabled).count();

        let average_bits = if total_layers > 0 {
            self.fake_quant_layers.values().map(|layer| layer.bits as f64).sum::<f64>()
                / total_layers as f64
        } else {
            0.0
        };

        let quantization_ratio = if total_layers > 0 {
            quantized_layers as f64 / total_layers as f64
        } else {
            0.0
        };

        // Estimate size reduction (simplified).
        // NOTE(review): only truncated averages of exactly 4, 8 or 16
        // bits are recognized; any other mix (e.g. average 6.5) reports 0.
        let size_reduction = match average_bits as u8 {
            8 => 0.75,  // 32-bit to 8-bit
            16 => 0.5,  // 32-bit to 16-bit
            4 => 0.875, // 32-bit to 4-bit
            _ => 0.0,
        } * quantization_ratio;

        self.stats = QATStats {
            quantization_ratio,
            quantized_layers,
            total_layers,
            average_bits,
            size_reduction,
            training_loss: self.stats.training_loss, // Preserve current values
            validation_accuracy: self.stats.validation_accuracy,
        };
    }

    /// Record the latest training metrics (carried across `update_stats`).
    pub fn update_metrics(&mut self, training_loss: f64, validation_accuracy: f64) {
        self.stats.training_loss = training_loss;
        self.stats.validation_accuracy = validation_accuracy;
    }

    /// Get current statistics
    pub fn get_stats(&self) -> &QATStats {
        &self.stats
    }

    /// Check if QAT is ready (all observers have enough data)
    pub fn is_ready(&self) -> bool {
        self.fake_quant_layers.values().all(|layer| layer.observer.is_ready())
    }

    /// Export per-layer `(scale, zero_point, bits)` for layers whose
    /// quantization parameters have been computed; others are skipped.
    pub fn export_quantized_config(&self) -> HashMap<String, (f32, i32, u8)> {
        self.fake_quant_layers
            .iter()
            .filter_map(|(name, layer)| {
                if let Some((scale, zero_point)) = layer.get_params() {
                    Some((name.clone(), (scale, zero_point, layer.bits)))
                } else {
                    None
                }
            })
            .collect()
    }

    /// Serialize the trainer state (config, epoch, stats, layer params)
    /// to `path` as pretty-printed JSON.
    ///
    /// # Errors
    /// Fails on serialization or filesystem errors.
    pub fn save_state(&self, path: &str) -> Result<()> {
        let state = QATState {
            config: self.config.clone(),
            current_epoch: self.current_epoch,
            stats: self.stats.clone(),
            layer_configs: self.export_quantized_config(),
        };

        let json_data = serde_json::to_string_pretty(&state).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to serialize QAT state: {}", e))
        })?;

        std::fs::write(path, json_data).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to write file: {}", e))
        })?;

        Ok(())
    }

    /// Load trainer state previously written by `save_state`.
    ///
    /// NOTE(review): per-layer parameters are restored only for layer
    /// names already registered on this trainer; entries for unknown
    /// layers are silently ignored.
    ///
    /// # Errors
    /// Fails on filesystem or deserialization errors.
    pub fn load_state(&mut self, path: &str) -> Result<()> {
        let json_data = std::fs::read_to_string(path).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to read file: {}", e))
        })?;

        let state: QATState = serde_json::from_str(&json_data).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to deserialize QAT state: {}", e))
        })?;

        self.config = state.config;
        self.current_epoch = state.current_epoch;
        self.stats = state.stats;

        // Restore layer configurations
        for (name, (scale, zero_point, bits)) in state.layer_configs {
            if let Some(layer) = self.fake_quant_layers.get_mut(&name) {
                layer.scale = Some(scale);
                layer.zero_point = Some(zero_point);
                layer.bits = bits;
            }
        }

        Ok(())
    }
}
605
606impl Default for QATStats {
607    fn default() -> Self {
608        Self {
609            quantization_ratio: 0.0,
610            quantized_layers: 0,
611            total_layers: 0,
612            average_bits: 32.0,
613            size_reduction: 0.0,
614            training_loss: 0.0,
615            validation_accuracy: 0.0,
616        }
617    }
618}
619
/// Serializable QAT state
///
/// Snapshot of a `QATTrainer` written by `save_state` / read by
/// `load_state`. Layer entries are keyed by layer name; only layers with
/// computed parameters are included (see `export_quantized_config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QATState {
    pub config: QATConfig,
    pub current_epoch: usize,
    pub stats: QATStats,
    pub layer_configs: HashMap<String, (f32, i32, u8)>, // (scale, zero_point, bits)
}
628
/// QAT utilities
///
/// Stateless helper namespace for building schedules and computing
/// quantization metrics.
pub struct QATUtils;
631
632impl QATUtils {
633    /// Create a progressive QAT schedule
634    pub fn create_progressive_schedule(
635        warmup_epochs: usize,
636        total_epochs: usize,
637        start_bits: u8,
638        end_bits: u8,
639    ) -> QATSchedule {
640        let reduction_steps = (start_bits - end_bits) as usize;
641        let epochs_per_step = (total_epochs - warmup_epochs) / reduction_steps.max(1);
642
643        let reduction_epochs: Vec<usize> = (1..=reduction_steps)
644            .map(|step| warmup_epochs + step * epochs_per_step)
645            .collect();
646
647        QATSchedule::Progressive {
648            start_bits,
649            end_bits,
650            reduction_epochs,
651        }
652    }
653
654    /// Create layer-wise schedule
655    pub fn create_layerwise_schedule(
656        layer_names: &[String],
657        start_epoch: usize,
658        epochs_between_layers: usize,
659    ) -> QATSchedule {
660        let mut schedule = HashMap::new();
661
662        for (i, name) in layer_names.iter().enumerate() {
663            let layer_start_epoch = start_epoch + i * epochs_between_layers;
664            schedule.insert(
665                name.clone(),
666                LayerSchedule {
667                    start_epoch: layer_start_epoch,
668                    enable_weights: true,
669                    enable_activations: true,
670                    bits: Some(8),
671                },
672            );
673        }
674
675        QATSchedule::LayerWise { schedule }
676    }
677
678    /// Estimate model size reduction
679    pub fn estimate_size_reduction(
680        original_bits: u8,
681        quantized_bits: u8,
682        quantization_ratio: f64,
683    ) -> f64 {
684        let bit_reduction = 1.0 - (quantized_bits as f64 / original_bits as f64);
685        bit_reduction * quantization_ratio
686    }
687
688    /// Calculate quantization noise
689    pub fn calculate_quantization_noise(original: &Tensor, quantized: &Tensor) -> Result<f64> {
690        match (original, quantized) {
691            (Tensor::F32(orig_arr), Tensor::F32(quant_arr)) => {
692                if orig_arr.len() != quant_arr.len() {
693                    return Err(TrustformersError::quantization_error(
694                        "Tensor sizes don't match".into(),
695                    ));
696                }
697
698                let mse: f64 = orig_arr
699                    .iter()
700                    .zip(quant_arr.iter())
701                    .map(|(&orig, &quant)| (orig - quant).powi(2) as f64)
702                    .sum::<f64>()
703                    / orig_arr.len() as f64;
704
705                Ok(mse.sqrt()) // RMSE
706            },
707            _ => Err(TrustformersError::quantization_error(
708                "Unsupported tensor types for noise calculation".into(),
709            )),
710        }
711    }
712}
713
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults should match the documented values in `QATConfig::default`.
    #[test]
    fn test_qat_config_default() {
        let config = QATConfig::default();
        assert_eq!(config.warmup_epochs, 5);
        assert!(!config.quantize_first_last);
        assert!(config.use_ste);
    }

    /// Observer counts per-element observations and brackets the data range.
    #[test]
    fn test_moving_average_observer() {
        let config = ObserverConfig::default();
        let mut observer = MovingAverageObserver::new(config);

        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).expect("Tensor from_vec failed");
        observer.update(&tensor).expect("tensor operation failed");

        assert_eq!(observer.num_observations, 4);
        assert!(observer.min_val <= 1.0);
        assert!(observer.max_val >= 4.0);
    }

    /// Layer enables after warmup and its observer becomes ready once
    /// enough elements have been seen during training forwards.
    #[test]
    fn test_fake_quant_layer() {
        let mut config = QATConfig::default();
        config.observer_config.freeze_after_warmup = false; // Don't freeze observer prematurely
        let mut layer = FakeQuantLayer::new(8, QuantizationScheme::DynamicINT8, config);

        // Should not be enabled initially
        assert!(!layer.enabled);

        // Update to after warmup
        layer.update_epoch(10);

        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).expect("Tensor from_vec failed");

        // Simulate training to collect statistics
        for _ in 0..100 {
            layer.forward(&tensor, true).expect("Forward pass failed");
        }

        // Should be enabled and ready now
        assert!(layer.enabled);
        assert!(layer.observer.is_ready());
    }

    /// Trainer aggregates layer counts and flips them to "quantized"
    /// once the schedule enables the layers.
    #[test]
    fn test_qat_trainer() {
        let config = QATConfig::default();
        let mut trainer = QATTrainer::new(config);

        trainer.add_layer("conv1".to_string(), 8, QuantizationScheme::DynamicINT8);
        trainer.add_layer("conv2".to_string(), 8, QuantizationScheme::DynamicINT8);

        let stats = trainer.get_stats();
        assert_eq!(stats.total_layers, 2);
        assert_eq!(stats.quantized_layers, 0); // Not enabled yet

        trainer.update_epoch(10);
        let stats = trainer.get_stats();
        assert_eq!(stats.quantized_layers, 2); // Should be enabled now
    }

    /// Schedule values stay within the documented [0, 1] range.
    #[test]
    fn test_gradual_schedule() {
        let config = QATConfig::default();
        let trainer = QATTrainer::new(config);

        let schedule = GradualSchedule::Linear;
        let value = trainer.get_schedule_value(&schedule, 5, 15);
        // Should be between 0 and 1 for linear schedule
        assert!((0.0..=1.0).contains(&value));
    }

    /// Progressive schedule preserves the requested bit endpoints and
    /// produces at least one reduction epoch for a non-empty bit span.
    #[test]
    fn test_qat_utils_progressive_schedule() {
        let schedule = QATUtils::create_progressive_schedule(5, 25, 16, 8);

        match schedule {
            QATSchedule::Progressive {
                start_bits,
                end_bits,
                reduction_epochs,
            } => {
                assert_eq!(start_bits, 16);
                assert_eq!(end_bits, 8);
                assert!(!reduction_epochs.is_empty());
            },
            _ => panic!("Expected progressive schedule"),
        }
    }

    /// 32-bit → 8-bit at full coverage reduces size by 3/4.
    #[test]
    fn test_size_reduction_estimation() {
        let reduction = QATUtils::estimate_size_reduction(32, 8, 1.0);
        assert_eq!(reduction, 0.75); // 75% reduction from 32-bit to 8-bit
    }

    /// RMSE of close tensors is small but strictly positive.
    #[test]
    fn test_quantization_noise_calculation() {
        let original =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).expect("Tensor from_vec failed");
        let quantized =
            Tensor::from_vec(vec![1.1, 1.9, 3.1, 3.9], &[4]).expect("Tensor from_vec failed");

        let noise = QATUtils::calculate_quantization_noise(&original, &quantized)
            .expect("operation failed in test");
        assert!(noise > 0.0);
        assert!(noise < 1.0); // Should be small for close values
    }

    /// Layer-wise schedule creates one entry per supplied layer name.
    #[test]
    fn test_layer_wise_schedule() {
        let layer_names = vec!["conv1".to_string(), "conv2".to_string(), "fc1".to_string()];
        let schedule = QATUtils::create_layerwise_schedule(&layer_names, 5, 2);

        match schedule {
            QATSchedule::LayerWise { schedule } => {
                assert_eq!(schedule.len(), 3);
                assert!(schedule.contains_key("conv1"));
                assert!(schedule.contains_key("conv2"));
                assert!(schedule.contains_key("fc1"));
            },
            _ => panic!("Expected layer-wise schedule"),
        }
    }
}