Skip to main content

trustformers_training/
qat.rs

1//! Quantization-Aware Training (QAT) for TrustformeRS
2//!
3//! This module provides quantization-aware training functionality that simulates
4//! quantization during training to improve model accuracy after quantization.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9use trustformers_core::errors::Result;
10use trustformers_core::{Layer, QuantizationScheme, Tensor};
11
/// Mixed-bit quantization strategy.
///
/// Consulted by `QATConfig::get_layer_bits` only for layers that have no
/// explicit entry in `QATConfig::layer_configs`; such an entry always takes
/// precedence over the strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MixedBitStrategy {
    /// Uniform bit width across all layers.
    Uniform { bits: u8 },
    /// Manual specification of bits per layer name or layer-type fragment.
    ///
    /// Looked up by exact layer name first, then by keys that occur as a
    /// substring of the layer name (e.g. `"linear"`, `"conv2d"`); unmatched
    /// layers use `QATConfig::default_bits`.
    Manual { layer_bits: HashMap<String, u8> },
    /// Automatic sensitivity-based assignment.
    ///
    /// Intended semantics: layers whose sensitivity exceeds
    /// `sensitivity_threshold` get `high_precision_bits`, the others
    /// `low_precision_bits`.
    // NOTE(review): sensitivity scores are stored on `LayerQuantConfig`
    // entries, and those entries are returned before the strategy is
    // consulted, so the per-layer split has to be materialized via
    // `QATConfig::auto_configure_sensitivity`.
    SensitivityBased {
        sensitivity_threshold: f32,
        high_precision_bits: u8,
        low_precision_bits: u8,
    },
    /// Resource-constrained mixed-bit assignment.
    ///
    /// Layers whose name contains any of `critical_layers` get
    /// `critical_bits`; all others get this variant's `default_bits`.
    // NOTE(review): `total_bit_budget` is not read by `get_layer_bits`, and
    // `optimize_bit_allocation` uses `QATConfig::bit_allocation_budget`
    // instead.
    ResourceConstrained {
        total_bit_budget: u64,
        critical_layers: Vec<String>,
        critical_bits: u8,
        default_bits: u8,
    },
    /// Progressive quantization (start high, reduce over time).
    ///
    /// `reduction_schedule` holds `(step, bits)` pairs and should be listed
    /// in ascending step order. `initial_bits` applies before the first
    /// scheduled step. `final_bits` is only used when the schedule is empty.
    Progressive {
        initial_bits: u8,
        final_bits: u8,
        reduction_schedule: Vec<(usize, u8)>, // (step, bits)
    },
}
39
40impl Default for MixedBitStrategy {
41    fn default() -> Self {
42        Self::Uniform { bits: 8 }
43    }
44}
45
/// Layer-specific quantization configuration.
///
/// Stored in `QATConfig::layer_configs`, keyed by layer name. When present,
/// its `bits` override the mixed-bit strategy in `QATConfig::get_layer_bits`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerQuantConfig {
    /// Number of bits for this layer.
    pub bits: u8,
    /// Use symmetric quantization (signed grid, no zero point).
    pub symmetric: bool,
    /// Per-channel quantization.
    // NOTE(review): not read by any code visible in this module chunk.
    pub per_channel: bool,
    /// Layer sensitivity score (0.0 = low, 1.0 = high).
    pub sensitivity: f32,
    /// Whether this layer is critical for model performance
    /// (set when the sensitivity exceeds the configured threshold).
    pub is_critical: bool,
}
60
61impl Default for LayerQuantConfig {
62    fn default() -> Self {
63        Self {
64            bits: 8,
65            symmetric: true,
66            per_channel: false,
67            sensitivity: 0.5,
68            is_critical: false,
69        }
70    }
71}
72
/// QAT configuration with mixed-bit support.
///
/// Per-layer bit widths are resolved by `QATConfig::get_layer_bits`. An entry
/// in `layer_configs` wins; otherwise `mixed_bit_strategy` decides, and
/// several strategies fall back to `default_bits`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QATConfig {
    /// Quantization scheme to simulate.
    // NOTE(review): not read by any code visible in this module chunk.
    pub qscheme: QuantizationScheme,
    /// Mixed-bit quantization strategy.
    pub mixed_bit_strategy: MixedBitStrategy,
    /// Default number of bits: the fallback for `get_layer_bits` and the
    /// weight bit width used by `QATLinear`.
    pub default_bits: u8,
    /// Use symmetric quantization (global default).
    pub symmetric: bool,
    /// Per-channel quantization (global default).
    // NOTE(review): not read by any code visible in this module chunk.
    pub per_channel: bool,
    /// QAT layers bypass fake quantization while their forward-call counter
    /// is below this value.
    pub start_step: usize,
    /// Stop updating observer statistics once the step counter reaches this
    /// value; `None` never freezes.
    pub freeze_step: Option<usize>,
    /// Use learned step size.
    // NOTE(review): not read by any code visible in this module chunk.
    pub learnable_step_size: bool,
    /// Observer momentum for running statistics:
    /// `running = momentum * running + (1 - momentum) * current`.
    pub observer_momentum: f32,
    /// Layer-specific configurations keyed by layer name.
    pub layer_configs: HashMap<String, LayerQuantConfig>,
    /// Enable activation quantization (checked by `MixedBitQATTrainer::init_layer`).
    pub quantize_activations: bool,
    /// Activation quantization bits.
    pub activation_bits: u8,
    /// Enable mixed-bit optimization. In `MixedBitQATTrainer::init_layer`
    /// this switches the source of activation bit widths.
    pub enable_mixed_bit_optimization: bool,
    /// Total bit budget (sum over layers of parameter count × bits) used by
    /// `optimize_bit_allocation`; `None` disables the optimization.
    pub bit_allocation_budget: Option<u64>,
}
105
106impl Default for QATConfig {
107    fn default() -> Self {
108        Self {
109            qscheme: QuantizationScheme::Int8,
110            mixed_bit_strategy: MixedBitStrategy::default(),
111            default_bits: 8,
112            symmetric: true,
113            per_channel: false,
114            start_step: 1000,
115            freeze_step: None,
116            learnable_step_size: false,
117            observer_momentum: 0.99,
118            layer_configs: HashMap::new(),
119            quantize_activations: false,
120            activation_bits: 8,
121            enable_mixed_bit_optimization: false,
122            bit_allocation_budget: None,
123        }
124    }
125}
126
127impl QATConfig {
128    /// Get bits for a specific layer
129    pub fn get_layer_bits(&self, layer_name: &str, current_step: usize) -> u8 {
130        // First check layer-specific configuration
131        if let Some(layer_config) = self.layer_configs.get(layer_name) {
132            return layer_config.bits;
133        }
134
135        // Then check mixed-bit strategy
136        match &self.mixed_bit_strategy {
137            MixedBitStrategy::Uniform { bits } => *bits,
138            MixedBitStrategy::Manual { layer_bits } => {
139                // Try exact match first, then partial match
140                layer_bits
141                    .get(layer_name)
142                    .or_else(|| {
143                        // Try to match by layer type (e.g., "linear", "conv2d")
144                        layer_bits
145                            .iter()
146                            .find(|(key, _)| layer_name.contains(key.as_str()))
147                            .map(|(_, bits)| bits)
148                    })
149                    .copied()
150                    .unwrap_or(self.default_bits)
151            },
152            MixedBitStrategy::SensitivityBased {
153                sensitivity_threshold,
154                high_precision_bits,
155                low_precision_bits,
156            } => {
157                if let Some(layer_config) = self.layer_configs.get(layer_name) {
158                    if layer_config.sensitivity > *sensitivity_threshold {
159                        *high_precision_bits
160                    } else {
161                        *low_precision_bits
162                    }
163                } else {
164                    self.default_bits
165                }
166            },
167            MixedBitStrategy::ResourceConstrained {
168                critical_layers,
169                critical_bits,
170                default_bits,
171                ..
172            } => {
173                if critical_layers.iter().any(|layer| layer_name.contains(layer)) {
174                    *critical_bits
175                } else {
176                    *default_bits
177                }
178            },
179            MixedBitStrategy::Progressive {
180                initial_bits,
181                final_bits,
182                reduction_schedule,
183            } => {
184                // Find the appropriate bits based on current step
185                for (step, bits) in reduction_schedule.iter().rev() {
186                    if current_step >= *step {
187                        return *bits;
188                    }
189                }
190                // If before all scheduled reductions, use initial bits
191                if current_step < reduction_schedule.first().map(|(s, _)| *s).unwrap_or(0) {
192                    *initial_bits
193                } else {
194                    *final_bits
195                }
196            },
197        }
198    }
199
200    /// Set layer-specific configuration
201    pub fn set_layer_config(&mut self, layer_name: String, config: LayerQuantConfig) {
202        self.layer_configs.insert(layer_name, config);
203    }
204
205    /// Automatically configure layers based on sensitivity analysis
206    pub fn auto_configure_sensitivity(&mut self, layer_sensitivities: HashMap<String, f32>) {
207        let sensitivity_threshold = match &self.mixed_bit_strategy {
208            MixedBitStrategy::SensitivityBased {
209                sensitivity_threshold,
210                ..
211            } => *sensitivity_threshold,
212            _ => 0.7, // default threshold
213        };
214
215        for (layer_name, sensitivity) in layer_sensitivities {
216            let is_critical = sensitivity > sensitivity_threshold;
217            let config = LayerQuantConfig {
218                bits: if is_critical { 8 } else { 4 },
219                sensitivity,
220                is_critical,
221                ..LayerQuantConfig::default()
222            };
223            self.layer_configs.insert(layer_name, config);
224        }
225
226        // Update strategy to sensitivity-based if not already set
227        if matches!(self.mixed_bit_strategy, MixedBitStrategy::Uniform { .. }) {
228            self.mixed_bit_strategy = MixedBitStrategy::SensitivityBased {
229                sensitivity_threshold,
230                high_precision_bits: 8,
231                low_precision_bits: 4,
232            };
233        }
234    }
235
236    /// Optimize bit allocation under resource constraints
237    pub fn optimize_bit_allocation(&mut self, model_size_info: HashMap<String, u64>) -> Result<()> {
238        if let Some(budget) = self.bit_allocation_budget {
239            let _total_params: u64 = model_size_info.values().sum();
240
241            // Sort layers by importance (sensitivity * size)
242            let mut layer_importance: Vec<(String, f64)> = model_size_info
243                .iter()
244                .map(|(name, size)| {
245                    let sensitivity =
246                        self.layer_configs.get(name).map(|c| c.sensitivity as f64).unwrap_or(0.5);
247                    let importance = sensitivity * (*size as f64);
248                    (name.clone(), importance)
249                })
250                .collect();
251
252            layer_importance
253                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
254
255            // Allocate bits greedily based on importance
256            let mut remaining_budget = budget;
257
258            for (layer_name, _) in layer_importance {
259                let layer_size = model_size_info[&layer_name];
260                let min_bits = 4; // minimum quantization
261                let max_bits = 8; // maximum practical bits
262
263                // Calculate how many bits we can afford for this layer
264                let affordable_bits = std::cmp::min(
265                    max_bits,
266                    std::cmp::max(min_bits, (remaining_budget / layer_size) as u8),
267                );
268
269                // Update layer configuration
270                let mut config = self.layer_configs.get(&layer_name).cloned().unwrap_or_default();
271                config.bits = affordable_bits;
272                self.layer_configs.insert(layer_name.clone(), config);
273
274                // Update remaining budget
275                remaining_budget =
276                    remaining_budget.saturating_sub(layer_size * affordable_bits as u64);
277            }
278
279            println!(
280                "🎯 Optimized bit allocation under budget constraint: {} bits",
281                budget
282            );
283            println!("📊 Remaining budget: {} bits", remaining_budget);
284        }
285
286        Ok(())
287    }
288
289    /// Get total bit consumption estimate
290    pub fn estimate_bit_consumption(&self, model_size_info: &HashMap<String, u64>) -> u64 {
291        model_size_info
292            .iter()
293            .map(|(layer_name, size)| {
294                let bits = self.get_layer_bits(layer_name, 0); // Use step 0 for estimation
295                size * bits as u64
296            })
297            .sum()
298    }
299
300    /// Create a configuration for common mixed-bit scenarios
301    pub fn create_common_config(scenario: &str) -> Self {
302        match scenario {
303            "edge_deployment" => Self {
304                mixed_bit_strategy: MixedBitStrategy::ResourceConstrained {
305                    total_bit_budget: 1024 * 1024, // 1MB bit budget
306                    critical_layers: vec!["attention".to_string(), "output".to_string()],
307                    critical_bits: 8,
308                    default_bits: 4,
309                },
310                quantize_activations: true,
311                activation_bits: 8,
312                enable_mixed_bit_optimization: true,
313                ..Self::default()
314            },
315            "high_accuracy" => Self {
316                mixed_bit_strategy: MixedBitStrategy::SensitivityBased {
317                    sensitivity_threshold: 0.6,
318                    high_precision_bits: 8,
319                    low_precision_bits: 6,
320                },
321                quantize_activations: false, // Keep activations in full precision
322                enable_mixed_bit_optimization: true,
323                ..Self::default()
324            },
325            "aggressive_compression" => Self {
326                mixed_bit_strategy: MixedBitStrategy::SensitivityBased {
327                    sensitivity_threshold: 0.8,
328                    high_precision_bits: 6,
329                    low_precision_bits: 3,
330                },
331                quantize_activations: true,
332                activation_bits: 4,
333                enable_mixed_bit_optimization: true,
334                ..Self::default()
335            },
336            _ => Self::default(),
337        }
338    }
339}
340
/// Observer state and derived quantization parameters for one tensor.
///
/// `update_stats` tracks an exponential moving average of the observed
/// min/max. `compute_params` derives `scale`, and `zero_point` when
/// asymmetric, from those statistics.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Scale factor for quantization (`fake_quantize` reads only its first element).
    pub scale: Tensor,
    /// Zero point for asymmetric quantization; `None` when created as symmetric.
    pub zero_point: Option<Tensor>,
    /// Running min value. Starts at `+inf` with the constructor's shape and is
    /// replaced by a scalar tensor on the first `update_stats` call.
    pub running_min: Tensor,
    /// Running max value. Starts at `-inf` with the constructor's shape and is
    /// replaced by a scalar tensor on the first `update_stats` call.
    pub running_max: Tensor,
    /// Number of `update_stats` calls so far.
    pub num_observations: usize,
}
355
356impl QuantizationParams {
357    pub fn new(shape: &[usize], symmetric: bool) -> Self {
358        Self {
359            scale: Tensor::ones(shape).expect("Failed to create scale"),
360            zero_point: if symmetric {
361                None
362            } else {
363                Some(Tensor::zeros(shape).expect("Failed to create zero point"))
364            },
365            running_min: Tensor::full(f32::INFINITY, shape.to_vec()).expect("Failed to create min"),
366            running_max: Tensor::full(f32::NEG_INFINITY, shape.to_vec())
367                .expect("Failed to create max"),
368            num_observations: 0,
369        }
370    }
371
372    /// Update running statistics
373    pub fn update_stats(&mut self, tensor: &Tensor, momentum: f32) -> Result<()> {
374        let (current_min_val, current_max_val) = tensor.min_max()?;
375        let current_min = Tensor::scalar(current_min_val)?;
376        let current_max = Tensor::scalar(current_max_val)?;
377
378        if self.num_observations == 0 {
379            self.running_min = current_min;
380            self.running_max = current_max;
381        } else {
382            // Exponential moving average
383            self.running_min = self
384                .running_min
385                .mul_scalar(momentum)?
386                .add(&current_min.mul_scalar(1.0 - momentum)?)?;
387            self.running_max = self
388                .running_max
389                .mul_scalar(momentum)?
390                .add(&current_max.mul_scalar(1.0 - momentum)?)?;
391        }
392
393        self.num_observations += 1;
394        Ok(())
395    }
396
397    /// Compute scale and zero point from statistics
398    pub fn compute_params(&mut self, bits: u8, symmetric: bool) -> Result<()> {
399        let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 } as f32;
400        let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 } as f32;
401
402        if symmetric {
403            let abs_running_max = self.running_max.abs()?;
404            let abs_running_min = self.running_min.abs()?;
405            let (_, max_abs_max) = abs_running_max.min_max()?;
406            let (_, max_abs_min) = abs_running_min.min_max()?;
407            // For symmetric quantization, we need the maximum of the absolute values
408            let abs_max = Tensor::scalar(max_abs_max.max(max_abs_min))?;
409            self.scale = abs_max.div_scalar(q_max)?;
410        } else {
411            let range = self.running_max.sub(&self.running_min)?;
412            self.scale = range.div_scalar(q_max - q_min)?;
413
414            if let Some(zp) = &mut self.zero_point {
415                *zp = self.running_min.div(&self.scale)?.neg()?.add_scalar(q_min)?;
416                *zp = zp.clamp(q_min, q_max)?;
417            }
418        }
419
420        Ok(())
421    }
422}
423
/// QAT wrapper around a linear layer.
///
/// Counts forward calls. Once `config.start_step` is reached, it observes and
/// fake-quantizes weights with `config.default_bits`.
// NOTE(review): the fake-quantized weights are not fed into `linear`; every
// forward call returns `linear.forward(input)` unchanged.
pub struct QATLinear {
    /// Original linear layer; forward calls are delegated to it.
    linear: Arc<dyn Layer<Input = Tensor, Output = Tensor>>,
    /// QAT configuration.
    config: QATConfig,
    /// Weight quantization parameters, shared so callers can inspect them via
    /// `get_quant_params`.
    quant_params: Arc<Mutex<QuantizationParams>>,
    /// Number of forward calls so far.
    step: Arc<Mutex<usize>>,
    /// When false, forward bypasses QAT entirely.
    enabled: bool,
}
437
438impl QATLinear {
439    pub fn new(linear: Arc<dyn Layer<Input = Tensor, Output = Tensor>>, config: QATConfig) -> Self {
440        // Initialize quantization parameters based on weight shape
441        let weight_shape = vec![1]; // Simplified - would get from linear layer
442        let quant_params = QuantizationParams::new(&weight_shape, config.symmetric);
443
444        Self {
445            linear,
446            config,
447            quant_params: Arc::new(Mutex::new(quant_params)),
448            step: Arc::new(Mutex::new(0)),
449            enabled: true,
450        }
451    }
452
453    /// Enable or disable QAT
454    pub fn set_enabled(&mut self, enabled: bool) {
455        self.enabled = enabled;
456    }
457
458    /// Get current quantization parameters
459    pub fn get_quant_params(&self) -> Arc<Mutex<QuantizationParams>> {
460        Arc::clone(&self.quant_params)
461    }
462
463    /// Extract weight tensor from the wrapped linear layer
464    fn get_layer_weights(&self) -> Result<Tensor> {
465        // Since we're working with a trait object, we simulate weight extraction
466        // In a production implementation, this would use a WeightAccessor trait
467        // or downcast to concrete layer types to extract actual weights
468
469        // Use typical transformer layer dimensions for weight simulation
470        let weight_shape = vec![768, 768]; // Standard hidden_size for many models
471
472        // Initialize with Xavier/Glorot uniform initialization for realistic weights
473        let fan_in = weight_shape[0] as f32;
474        let fan_out = weight_shape[1] as f32;
475        let limit = (6.0 / (fan_in + fan_out)).sqrt();
476
477        // Generate weight data with proper initialization distribution
478        let total_elements = weight_shape.iter().product::<usize>();
479        let weight_data: Vec<f32> = (0..total_elements)
480            .map(|_| {
481                let uniform_val = fastrand::f32(); // [0.0, 1.0)
482                (uniform_val - 0.5) * 2.0 * limit // Scale to [-limit, limit]
483            })
484            .collect();
485
486        Tensor::from_vec(weight_data, &weight_shape)
487    }
488}
489
490impl Layer for QATLinear {
491    type Input = Tensor;
492    type Output = Tensor;
493    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
494        let mut step = self.step.lock().expect("lock should not be poisoned");
495        *step += 1;
496        let current_step = *step;
497        drop(step);
498
499        // Check if QAT should be active
500        if !self.enabled || current_step < self.config.start_step {
501            // Regular forward pass without quantization
502            return self.linear.forward(input);
503        }
504
505        // Get weight tensor from the linear layer
506        let weight = self.get_layer_weights()?;
507
508        // Update statistics if not frozen
509        if self.config.freeze_step.is_none()
510            || current_step
511                < self.config.freeze_step.expect("freeze_step checked as Some in condition")
512        {
513            let mut params = self.quant_params.lock().expect("lock should not be poisoned");
514            params.update_stats(&weight, self.config.observer_momentum)?;
515            params.compute_params(self.config.default_bits, self.config.symmetric)?;
516        }
517
518        // Simulate quantization on weights
519        let params = self.quant_params.lock().expect("lock should not be poisoned");
520        let _quantized_weight = fake_quantize(
521            &weight,
522            &params.scale,
523            params.zero_point.as_ref(),
524            self.config.default_bits,
525            self.config.symmetric,
526        )?;
527        drop(params);
528
529        // Forward with quantized weights
530        // In practice, this would use the quantized weights in the linear operation
531        self.linear.forward(input)
532    }
533}
534
/// QAT wrapper around a convolution layer.
///
/// Counts forward calls and, once `config.start_step` is reached, can
/// fake-quantize its input activations before delegating to `conv`.
pub struct QATConv2d {
    /// Original convolution layer; forward calls are delegated to it.
    conv: Arc<dyn Layer<Input = Tensor, Output = Tensor>>,
    /// QAT configuration.
    config: QATConfig,
    /// Weight quantization parameters.
    // Created by `new` but not read by `forward` yet, hence the allowance.
    #[allow(dead_code)]
    weight_params: Arc<Mutex<QuantizationParams>>,
    /// Activation quantization parameters; present only when constructed with
    /// `quantize_activations = true`.
    activation_params: Option<Arc<Mutex<QuantizationParams>>>,
    /// Number of forward calls so far.
    step: Arc<Mutex<usize>>,
}
549
550impl QATConv2d {
551    pub fn new(
552        conv: Arc<dyn Layer<Input = Tensor, Output = Tensor>>,
553        config: QATConfig,
554        quantize_activations: bool,
555    ) -> Self {
556        let weight_shape = vec![1]; // Simplified
557        let weight_params = QuantizationParams::new(&weight_shape, config.symmetric);
558
559        let activation_params = if quantize_activations {
560            Some(Arc::new(Mutex::new(QuantizationParams::new(
561                &[1],
562                config.symmetric,
563            ))))
564        } else {
565            None
566        };
567
568        Self {
569            conv,
570            config,
571            weight_params: Arc::new(Mutex::new(weight_params)),
572            activation_params,
573            step: Arc::new(Mutex::new(0)),
574        }
575    }
576}
577
578impl Layer for QATConv2d {
579    type Input = Tensor;
580    type Output = Tensor;
581    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
582        let mut step = self.step.lock().expect("lock should not be poisoned");
583        *step += 1;
584        let current_step = *step;
585        drop(step);
586
587        if current_step < self.config.start_step {
588            return self.conv.forward(input);
589        }
590
591        // Quantize input activations if configured
592        let quantized_input = if let Some(act_params) = &self.activation_params {
593            if self.config.freeze_step.is_none()
594                || current_step
595                    < self.config.freeze_step.expect("freeze_step checked as Some in condition")
596            {
597                let mut params = act_params.lock().expect("lock should not be poisoned");
598                params.update_stats(&input, self.config.observer_momentum)?;
599                params.compute_params(self.config.default_bits, self.config.symmetric)?;
600            }
601
602            let params = act_params.lock().expect("lock should not be poisoned");
603            fake_quantize(
604                &input,
605                &params.scale,
606                params.zero_point.as_ref(),
607                self.config.default_bits,
608                self.config.symmetric,
609            )?
610        } else {
611            input.clone()
612        };
613
614        // Apply convolution with quantized weights (simplified)
615        self.conv.forward(quantized_input)
616    }
617}
618
/// Activation quantizer for mixed-bit quantization.
///
/// It must be calibrated (see `calibrate`) before `quantize` applies fake
/// quantization; until then activations pass through unchanged.
#[derive(Debug, Clone)]
pub struct ActivationQuantizer {
    /// Observer state and derived scale/zero point.
    pub params: QuantizationParams,
    /// Bit width of the simulated integer grid.
    pub bits: u8,
    /// Symmetric (signed grid) vs asymmetric (unsigned grid plus zero point).
    pub symmetric: bool,
    /// Set by `calibrate`. Gates whether `quantize` quantizes and whether
    /// `update` recomputes the parameters.
    pub calibrated: bool,
}
628
629impl ActivationQuantizer {
630    pub fn new(shape: &[usize], bits: u8, symmetric: bool) -> Self {
631        Self {
632            params: QuantizationParams::new(shape, symmetric),
633            bits,
634            symmetric,
635            calibrated: false,
636        }
637    }
638
639    /// Calibrate the quantizer with calibration data
640    pub fn calibrate(&mut self, calibration_data: &[Tensor], momentum: f32) -> Result<()> {
641        for tensor in calibration_data {
642            self.params.update_stats(tensor, momentum)?;
643        }
644        self.params.compute_params(self.bits, self.symmetric)?;
645        self.calibrated = true;
646        Ok(())
647    }
648
649    /// Quantize activation tensor
650    pub fn quantize(&self, tensor: &Tensor) -> Result<Tensor> {
651        if !self.calibrated {
652            // If not calibrated, just pass through with warning
653            println!("⚠️ Warning: Activation quantizer not calibrated, using full precision");
654            return Ok(tensor.clone());
655        }
656
657        fake_quantize(
658            tensor,
659            &self.params.scale,
660            self.params.zero_point.as_ref(),
661            self.bits,
662            self.symmetric,
663        )
664    }
665
666    /// Update parameters during training
667    pub fn update(&mut self, tensor: &Tensor, momentum: f32) -> Result<()> {
668        self.params.update_stats(tensor, momentum)?;
669        if self.calibrated {
670            self.params.compute_params(self.bits, self.symmetric)?;
671        }
672        Ok(())
673    }
674}
675
676/// Mixed-bit fake quantization with layer-aware bit selection
677pub fn fake_quantize_mixed_bit(
678    tensor: &Tensor,
679    scale: &Tensor,
680    zero_point: Option<&Tensor>,
681    config: &QATConfig,
682    layer_name: &str,
683    current_step: usize,
684) -> Result<Tensor> {
685    let bits = config.get_layer_bits(layer_name, current_step);
686    fake_quantize(tensor, scale, zero_point, bits, config.symmetric)
687}
688
689/// Enhanced fake quantization with better numerical stability
690pub fn fake_quantize(
691    tensor: &Tensor,
692    scale: &Tensor,
693    zero_point: Option<&Tensor>,
694    bits: u8,
695    symmetric: bool,
696) -> Result<Tensor> {
697    let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 } as f32;
698    let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 } as f32;
699
700    // Get scalar values for scale and zero_point
701    let scale_val = scale.get_float(0)?;
702    let zero_point_val = if let Some(zp) = zero_point { zp.get_float(0)? } else { 0.0 };
703
704    // Manual broadcasting: apply operations element-wise
705    let tensor_data = tensor.data()?;
706    let result_data: Vec<f32> = tensor_data
707        .iter()
708        .map(|&x| {
709            // Scale
710            let scaled = x / scale_val;
711
712            // Add zero point if asymmetric
713            let shifted = if zero_point.is_some() { scaled + zero_point_val } else { scaled };
714
715            // Round and clamp
716            let quantized = shifted.round().clamp(q_min, q_max);
717
718            // Dequantize back
719            if zero_point.is_some() {
720                (quantized - zero_point_val) * scale_val
721            } else {
722                quantized * scale_val
723            }
724        })
725        .collect();
726
727    // Straight-through estimator: forward uses dequantized values,
728    // backward passes gradients through unchanged
729    Tensor::from_vec(result_data, &tensor.shape())
730}
731
/// QAT model wrapper.
///
/// Holds the original model plus a registry of QAT replacement layers.
pub struct QATModel {
    /// Original model.
    // NOTE(review): stored but not read by the methods shown in this module chunk.
    model: Arc<dyn Layer<Input = Tensor, Output = Tensor>>,
    /// QAT layers keyed by name, registered via `add_qat_layer`.
    qat_layers: HashMap<String, Arc<Mutex<dyn Layer<Input = Tensor, Output = Tensor>>>>,
    /// Global QAT configuration (copied into the result of `convert`).
    config: QATConfig,
}
741
742impl QATModel {
743    pub fn new(model: Arc<dyn Layer<Input = Tensor, Output = Tensor>>, config: QATConfig) -> Self {
744        Self {
745            model,
746            qat_layers: HashMap::new(),
747            config,
748        }
749    }
750
751    /// Replace a layer with QAT version
752    pub fn add_qat_layer(
753        &mut self,
754        name: String,
755        layer: Arc<Mutex<dyn Layer<Input = Tensor, Output = Tensor>>>,
756    ) {
757        self.qat_layers.insert(name, layer);
758    }
759
760    /// Prepare model for QAT
761    pub fn prepare(&mut self) -> Result<()> {
762        // In practice, this would traverse the model and replace
763        // linear/conv layers with QAT versions
764        Ok(())
765    }
766
767    /// Convert to quantized model
768    pub fn convert(&self) -> Result<QuantizedModel> {
769        // Extract learned quantization parameters and create
770        // a fully quantized model
771        let quantized_layers = HashMap::new();
772
773        Ok(QuantizedModel {
774            layers: quantized_layers,
775            config: self.config.clone(),
776        })
777    }
778
779    /// Get quantization statistics
780    pub fn get_statistics(&self) -> HashMap<String, QuantStats> {
781        let mut stats = HashMap::new();
782
783        // Collect statistics from all QAT layers
784        for name in self.qat_layers.keys() {
785            stats.insert(
786                name.clone(),
787                QuantStats {
788                    min_val: 0.0,
789                    max_val: 0.0,
790                    mean_val: 0.0,
791                    scale: 1.0,
792                },
793            );
794        }
795
796        stats
797    }
798}
799
/// Quantized model after QAT, as produced by `QATModel::convert`.
// The fields are not read anywhere in the visible part of this module yet,
// hence the `dead_code` allowances.
#[allow(dead_code)]
pub struct QuantizedModel {
    /// Quantized layers keyed by name (currently always empty from `convert`).
    #[allow(dead_code)]
    layers: HashMap<String, QuantizedLayer>,
    /// Configuration the model was trained with.
    config: QATConfig,
}
807
/// Quantized layer representation.
// The fields are not read anywhere in the visible part of this module yet,
// hence the `dead_code` allowances.
#[allow(dead_code)]
pub struct QuantizedLayer {
    /// Quantized weight values (presumably one byte per element, or packed
    /// sub-byte values; TODO confirm encoding).
    #[allow(dead_code)]
    weights: Vec<u8>,
    /// Scale factor(s), presumably one per tensor or per channel; confirm.
    scale: Vec<f32>,
    /// Integer zero point(s), presumably paired with `scale`; confirm.
    zero_point: Vec<i32>,
}
816
/// Per-layer quantization statistics returned by `QATModel::get_statistics`.
#[derive(Debug, Clone)]
pub struct QuantStats {
    /// Minimum observed value.
    pub min_val: f32,
    /// Maximum observed value.
    pub max_val: f32,
    /// Mean observed value.
    pub mean_val: f32,
    /// Quantization scale.
    pub scale: f32,
}
825
/// Mixed-bit QAT training manager.
///
/// Tracks per-layer weight observers, optional activation quantizers and the
/// data used for sensitivity analysis and bit allocation.
pub struct MixedBitQATTrainer {
    /// QAT configuration with mixed-bit settings.
    pub config: QATConfig,
    /// Weight quantization parameters per layer (filled by `init_layer`).
    pub layer_params: HashMap<String, QuantizationParams>,
    /// Activation quantizers per layer (filled by `init_layer` when
    /// `config.quantize_activations` is set).
    pub activation_quantizers: HashMap<String, ActivationQuantizer>,
    /// Learning rate for quantization parameters.
    // NOTE(review): how this is used is not visible in this module chunk.
    pub quant_lr: f32,
    /// Weight decay for quantization parameters.
    // NOTE(review): how this is used is not visible in this module chunk.
    pub quant_weight_decay: f32,
    /// Current training step (passed to `get_layer_bits`).
    pub current_step: usize,
    /// Sensitivity analysis results keyed by layer name.
    pub layer_sensitivities: HashMap<String, f32>,
    /// Model size information (parameter counts) for bit allocation.
    pub model_size_info: HashMap<String, u64>,
}
845
846impl MixedBitQATTrainer {
847    pub fn new(config: QATConfig, quant_lr: f32, quant_weight_decay: f32) -> Self {
848        Self {
849            config,
850            layer_params: HashMap::new(),
851            activation_quantizers: HashMap::new(),
852            quant_lr,
853            quant_weight_decay,
854            current_step: 0,
855            layer_sensitivities: HashMap::new(),
856            model_size_info: HashMap::new(),
857        }
858    }
859
860    /// Initialize quantization parameters for a layer
861    pub fn init_layer(&mut self, layer_name: String, param_shape: &[usize]) -> Result<()> {
862        // Get layer-specific configuration
863        let layer_config = self.config.layer_configs.get(&layer_name).cloned().unwrap_or_default();
864
865        // Initialize weight quantization parameters
866        let params = QuantizationParams::new(param_shape, layer_config.symmetric);
867        self.layer_params.insert(layer_name.clone(), params);
868
869        // Initialize activation quantizer if enabled
870        if self.config.quantize_activations {
871            let activation_bits = if self.config.enable_mixed_bit_optimization {
872                layer_config.bits
873            } else {
874                self.config.activation_bits
875            };
876
877            let act_quantizer =
878                ActivationQuantizer::new(param_shape, activation_bits, layer_config.symmetric);
879            self.activation_quantizers.insert(layer_name.clone(), act_quantizer);
880        }
881
882        println!(
883            "🔧 Initialized mixed-bit QAT for layer: {} ({}bits)",
884            layer_name,
885            self.config.get_layer_bits(&layer_name, self.current_step)
886        );
887
888        Ok(())
889    }
890
891    /// Perform sensitivity analysis on layers
892    pub fn analyze_sensitivity(
893        &mut self,
894        model_outputs: HashMap<String, Vec<Tensor>>,
895    ) -> Result<()> {
896        println!("🔍 Performing layer sensitivity analysis for mixed-bit optimization...");
897
898        for (layer_name, outputs) in model_outputs {
899            // Calculate sensitivity based on activation variance and gradient magnitudes
900            let mut total_variance = 0.0;
901            let mut total_magnitude = 0.0;
902
903            for output in &outputs {
904                // Calculate activation variance as a proxy for sensitivity
905                let data = output.data()?;
906                let mean = data.iter().sum::<f32>() / data.len() as f32;
907                let variance =
908                    data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
909
910                total_variance += variance;
911
912                // Calculate magnitude (as proxy for importance)
913                let magnitude = data.iter().map(|x| x.abs()).sum::<f32>() / data.len() as f32;
914                total_magnitude += magnitude;
915            }
916
917            // Normalize sensitivity score
918            let avg_variance = total_variance / outputs.len() as f32;
919            let avg_magnitude = total_magnitude / outputs.len() as f32;
920            let sensitivity = (avg_variance * avg_magnitude).sqrt().min(1.0);
921
922            self.layer_sensitivities.insert(layer_name.clone(), sensitivity);
923
924            println!("📊 Layer {} sensitivity: {:.3}", layer_name, sensitivity);
925        }
926
927        // Auto-configure based on sensitivity analysis
928        self.config.auto_configure_sensitivity(self.layer_sensitivities.clone());
929
930        Ok(())
931    }
932
933    /// Update model size information for bit allocation
934    pub fn update_model_info(&mut self, model_info: HashMap<String, u64>) {
935        self.model_size_info = model_info;
936
937        // Optimize bit allocation if enabled
938        if self.config.enable_mixed_bit_optimization {
939            if let Err(e) = self.config.optimize_bit_allocation(self.model_size_info.clone()) {
940                println!("⚠️ Warning: Failed to optimize bit allocation: {}", e);
941            }
942        }
943    }
944
945    /// Quantize layer weights with mixed-bit support
946    pub fn quantize_layer_weights(&mut self, layer_name: &str, weights: &Tensor) -> Result<Tensor> {
947        // Get or initialize parameters for this layer
948        if !self.layer_params.contains_key(layer_name) {
949            self.init_layer(layer_name.to_string(), &weights.shape())?;
950        }
951
952        let params = self
953            .layer_params
954            .get_mut(layer_name)
955            .expect("layer_params entry exists after initialization check");
956
957        // Update statistics if we're in the calibration phase
958        if self.current_step < self.config.start_step {
959            params.update_stats(weights, self.config.observer_momentum)?;
960            params.compute_params(
961                self.config.get_layer_bits(layer_name, self.current_step),
962                self.config.symmetric,
963            )?;
964        }
965
966        // Apply mixed-bit fake quantization
967        fake_quantize_mixed_bit(
968            weights,
969            &params.scale,
970            params.zero_point.as_ref(),
971            &self.config,
972            layer_name,
973            self.current_step,
974        )
975    }
976
977    /// Quantize layer activations
978    pub fn quantize_layer_activations(
979        &mut self,
980        layer_name: &str,
981        activations: &Tensor,
982    ) -> Result<Tensor> {
983        if !self.config.quantize_activations {
984            return Ok(activations.clone());
985        }
986
987        if let Some(quantizer) = self.activation_quantizers.get_mut(layer_name) {
988            // Update quantizer during training
989            quantizer.update(activations, self.config.observer_momentum)?;
990            quantizer.quantize(activations)
991        } else {
992            // Initialize if not present
993            let layer_config =
994                self.config.layer_configs.get(layer_name).cloned().unwrap_or_default();
995
996            let mut quantizer = ActivationQuantizer::new(
997                &activations.shape(),
998                layer_config.bits,
999                layer_config.symmetric,
1000            );
1001
1002            quantizer.update(activations, self.config.observer_momentum)?;
1003            let result = quantizer.quantize(activations)?;
1004
1005            self.activation_quantizers.insert(layer_name.to_string(), quantizer);
1006            Ok(result)
1007        }
1008    }
1009
1010    /// Step the trainer (increment step counter and update progressive quantization)
1011    pub fn step(&mut self) {
1012        self.current_step += 1;
1013
1014        // Handle progressive quantization
1015        if let MixedBitStrategy::Progressive { .. } = &self.config.mixed_bit_strategy {
1016            // Bit widths will be automatically updated via get_layer_bits
1017            if self.current_step % 1000 == 0 {
1018                println!(
1019                    "📈 Progressive quantization step {}: updating bit allocations",
1020                    self.current_step
1021                );
1022            }
1023        }
1024
1025        // Freeze quantization parameters if specified
1026        if let Some(freeze_step) = self.config.freeze_step {
1027            if self.current_step == freeze_step {
1028                println!(
1029                    "🔒 Freezing quantization parameters at step {}",
1030                    freeze_step
1031                );
1032            }
1033        }
1034    }
1035
1036    /// Get current quantization statistics
1037    pub fn get_quantization_stats(&self) -> HashMap<String, (u8, f32)> {
1038        let mut stats = HashMap::new();
1039
1040        for layer_name in self.layer_params.keys() {
1041            let bits = self.config.get_layer_bits(layer_name, self.current_step);
1042            let sensitivity = self.layer_sensitivities.get(layer_name).copied().unwrap_or(0.0);
1043            stats.insert(layer_name.clone(), (bits, sensitivity));
1044        }
1045
1046        stats
1047    }
1048
1049    /// Estimate total memory/compute savings from mixed-bit quantization
1050    pub fn estimate_savings(&self) -> (f64, f64) {
1051        if self.model_size_info.is_empty() {
1052            return (0.0, 0.0);
1053        }
1054
1055        let total_params: u64 = self.model_size_info.values().sum();
1056        let _baseline_bits = 32.0; // fp32 baseline
1057
1058        let mut total_quantized_bits = 0u64;
1059        for (layer_name, param_count) in &self.model_size_info {
1060            let bits = self.config.get_layer_bits(layer_name, self.current_step) as u64;
1061            total_quantized_bits += param_count * bits;
1062        }
1063
1064        let baseline_total_bits = total_params * 32;
1065
1066        let memory_savings = 1.0 - (total_quantized_bits as f64) / (baseline_total_bits as f64);
1067        let compute_savings = memory_savings * 0.8; // Approximation: compute scales with memory
1068
1069        (memory_savings, compute_savings)
1070    }
1071
1072    /// Create a summary report of the mixed-bit configuration
1073    pub fn summary_report(&self) -> String {
1074        let mut report = String::from("📊 Mixed-Bit QAT Summary Report\n");
1075        report.push_str("====================================\n\n");
1076
1077        // Strategy summary
1078        report.push_str(&format!(
1079            "🎯 Strategy: {:?}\n",
1080            self.config.mixed_bit_strategy
1081        ));
1082        report.push_str(&format!(
1083            "📋 Total layers configured: {}\n",
1084            self.layer_params.len()
1085        ));
1086        report.push_str(&format!(
1087            "📈 Current training step: {}\n",
1088            self.current_step
1089        ));
1090
1091        if self.config.quantize_activations {
1092            report.push_str(&format!(
1093                "⚡ Activation quantization: {} bits\n",
1094                self.config.activation_bits
1095            ));
1096        }
1097
1098        report.push('\n');
1099
1100        // Per-layer breakdown
1101        report.push_str("🔍 Per-Layer Configuration:\n");
1102        for layer_name in self.layer_params.keys() {
1103            let bits = self.config.get_layer_bits(layer_name, self.current_step);
1104            let sensitivity = self.layer_sensitivities.get(layer_name).copied().unwrap_or(0.0);
1105            let size = self.model_size_info.get(layer_name).copied().unwrap_or(0);
1106
1107            report.push_str(&format!(
1108                "  {} | {} bits | sensitivity: {:.3} | params: {}\n",
1109                layer_name, bits, sensitivity, size
1110            ));
1111        }
1112
1113        // Savings estimate
1114        let (memory_savings, compute_savings) = self.estimate_savings();
1115        report.push('\n');
1116        report.push_str(&format!(
1117            "💾 Estimated memory savings: {:.1}%\n",
1118            memory_savings * 100.0
1119        ));
1120        report.push_str(&format!(
1121            "⚡ Estimated compute savings: {:.1}%\n",
1122            compute_savings * 100.0
1123        ));
1124
1125        // Bit consumption
1126        if !self.model_size_info.is_empty() {
1127            let total_bits = self.config.estimate_bit_consumption(&self.model_size_info);
1128            report.push_str(&format!("📊 Total bit consumption: {} bits\n", total_bits));
1129
1130            if let Some(budget) = self.config.bit_allocation_budget {
1131                let usage_pct = (total_bits as f64) / (budget as f64) * 100.0;
1132                report.push_str(&format!(
1133                    "💰 Budget usage: {:.1}% ({}/{})\n",
1134                    usage_pct, total_bits, budget
1135                ));
1136            }
1137        }
1138
1139        report
1140    }
1141}
1142
/// Plain (single-precision-setting) QAT helper, kept for backward compatibility with code that
/// predates [`MixedBitQATTrainer`].
pub struct QATTrainer {
    /// Step size used when applying gradients to quantization parameters.
    pub quant_lr: f32,
    /// Weight decay for quantization parameters.
    /// NOTE(review): `update_quant_params` does not currently apply this value.
    pub quant_weight_decay: f32,
}
1150
1151impl QATTrainer {
1152    pub fn new(quant_lr: f32, quant_weight_decay: f32) -> Self {
1153        Self {
1154            quant_lr,
1155            quant_weight_decay,
1156        }
1157    }
1158
1159    /// Update quantization parameters with gradients
1160    pub fn update_quant_params(
1161        &self,
1162        params: &mut QuantizationParams,
1163        grads: &QuantizationGradients,
1164    ) -> Result<()> {
1165        // Update scale with gradient descent
1166        if let Some(scale_grad) = &grads.scale_grad {
1167            params.scale = params.scale.sub(&scale_grad.mul_scalar(self.quant_lr)?)?;
1168        }
1169
1170        // Update zero point if present
1171        if let (Some(zp), Some(zp_grad)) = (&mut params.zero_point, &grads.zero_point_grad) {
1172            *zp = zp.sub(&zp_grad.mul_scalar(self.quant_lr)?)?;
1173        }
1174
1175        Ok(())
1176    }
1177}
1178
1179/// Gradients for quantization parameters
1180pub struct QuantizationGradients {
1181    pub scale_grad: Option<Tensor>,
1182    pub zero_point_grad: Option<Tensor>,
1183}
1184
1185/// Calibration dataset for QAT
1186pub struct CalibrationDataset {
1187    samples: Vec<Tensor>,
1188    labels: Vec<Tensor>,
1189}
1190
1191impl CalibrationDataset {
1192    pub fn new(samples: Vec<Tensor>, labels: Vec<Tensor>) -> Self {
1193        Self { samples, labels }
1194    }
1195
1196    /// Run calibration to initialize quantization parameters
1197    pub fn calibrate(&self, model: &mut QATModel) -> Result<()> {
1198        // Disable gradient computation during calibration
1199        for (sample, _label) in self.samples.iter().zip(&self.labels) {
1200            let _ = model.model.forward(sample.clone())?;
1201        }
1202
1203        Ok(())
1204    }
1205}
1206
1207/// QAT-specific loss function that includes quantization error
1208pub fn qat_loss(
1209    predictions: &Tensor,
1210    targets: &Tensor,
1211    quant_error: f32,
1212    alpha: f32,
1213) -> Result<Tensor> {
1214    // Regular loss (e.g., cross-entropy)
1215    let task_loss = compute_task_loss(predictions, targets)?;
1216
1217    // Add quantization error penalty
1218    let total_loss = task_loss.add_scalar(alpha * quant_error)?;
1219
1220    Ok(total_loss)
1221}
1222
1223fn compute_task_loss(predictions: &Tensor, targets: &Tensor) -> Result<Tensor> {
1224    // Placeholder for actual loss computation
1225    predictions.sub(targets)?.pow(2.0)?.mean()
1226}
1227
#[cfg(test)]
mod tests {
    use super::*;

    /// `fake_quantize` must return a tensor with the same shape as its input.
    #[test]
    fn test_fake_quantize() {
        let input =
            Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).expect("tensor operation failed");
        let scale = Tensor::from_vec(vec![0.1], &[1]).expect("tensor operation failed");
        let zero_point = Tensor::from_vec(vec![128.0], &[1]).expect("tensor operation failed");

        let output = fake_quantize(&input, &scale, Some(&zero_point), 8, false)
            .expect("tensor operation failed");

        assert_eq!(output.shape(), input.shape());
    }

    /// One `update_stats` call records one observation; params can then be computed.
    #[test]
    fn test_quantization_params() {
        let mut params = QuantizationParams::new(&[1], true);
        let sample =
            Tensor::from_vec(vec![-1.0, 0.0, 1.0], &[3]).expect("tensor operation failed");

        params.update_stats(&sample, 0.9).expect("tensor operation failed");
        assert_eq!(params.num_observations, 1);

        params.compute_params(8, true).expect("operation failed in test");
    }

    /// Checks a few default values of `QATConfig`.
    #[test]
    fn test_qat_config() {
        let defaults = QATConfig::default();
        assert_eq!(defaults.default_bits, 8);
        assert!(defaults.symmetric);
        assert_eq!(defaults.start_step, 1000);
    }
}