use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use trustformers_core::errors::Result;
use trustformers_core::{Layer, QuantizationScheme, Tensor};

/// Strategy for assigning bit-widths to layers during mixed-bit QAT.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MixedBitStrategy {
    /// Quantize every layer to the same bit-width.
    Uniform { bits: u8 },
    /// Explicit per-layer bit-widths, keyed by layer name (substring matches allowed).
    Manual { layer_bits: HashMap<String, u8> },
    /// Pick high or low precision per layer from a measured sensitivity score.
    SensitivityBased {
        sensitivity_threshold: f32,
        high_precision_bits: u8,
        low_precision_bits: u8,
    },
    /// Allocate bits under a total budget, keeping critical layers at higher precision.
    ResourceConstrained {
        total_bit_budget: u64,
        critical_layers: Vec<String>,
        critical_bits: u8,
        default_bits: u8,
    },
    /// Reduce bit-width over training according to a (step, bits) schedule.
    Progressive {
        initial_bits: u8,
        final_bits: u8,
        reduction_schedule: Vec<(usize, u8)>,
    },
}

impl Default for MixedBitStrategy {
    fn default() -> Self {
        Self::Uniform { bits: 8 }
    }
}

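// A minimal, illustrative construction of a progressive schedule: start at 8 bits,
// drop to 6 bits at step 2_000, and settle at 4 bits from step 8_000 on (the step
// counts here are hypothetical, not values taken from this crate):
//
//     let strategy = MixedBitStrategy::Progressive {
//         initial_bits: 8,
//         final_bits: 4,
//         reduction_schedule: vec![(2_000, 6), (8_000, 4)],
//     };
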
/// Per-layer quantization settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerQuantConfig {
    /// Bit-width for this layer's weights.
    pub bits: u8,
    /// Use symmetric quantization (no zero point).
    pub symmetric: bool,
    /// Quantize per channel rather than per tensor.
    pub per_channel: bool,
    /// Sensitivity score (higher means more quantization-sensitive).
    pub sensitivity: f32,
    /// Whether the layer must keep higher precision.
    pub is_critical: bool,
}

impl Default for LayerQuantConfig {
    fn default() -> Self {
        Self {
            bits: 8,
            symmetric: true,
            per_channel: false,
            sensitivity: 0.5,
            is_critical: false,
        }
    }
}

/// Top-level configuration for quantization-aware training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QATConfig {
    pub qscheme: QuantizationScheme,
    pub mixed_bit_strategy: MixedBitStrategy,
    /// Fallback bit-width when neither a per-layer override nor the strategy applies.
    pub default_bits: u8,
    pub symmetric: bool,
    pub per_channel: bool,
    /// Training step at which fake quantization begins.
    pub start_step: usize,
    /// Optional step after which observer statistics stop updating.
    pub freeze_step: Option<usize>,
    pub learnable_step_size: bool,
    /// Momentum of the exponential moving average used by the observers.
    pub observer_momentum: f32,
    /// Explicit per-layer overrides; these take precedence over the strategy.
    pub layer_configs: HashMap<String, LayerQuantConfig>,
    pub quantize_activations: bool,
    pub activation_bits: u8,
    pub enable_mixed_bit_optimization: bool,
    /// Optional total bit budget consumed by `optimize_bit_allocation`.
    pub bit_allocation_budget: Option<u64>,
}

impl Default for QATConfig {
    fn default() -> Self {
        Self {
            qscheme: QuantizationScheme::Int8,
            mixed_bit_strategy: MixedBitStrategy::default(),
            default_bits: 8,
            symmetric: true,
            per_channel: false,
            start_step: 1000,
            freeze_step: None,
            learnable_step_size: false,
            observer_momentum: 0.99,
            layer_configs: HashMap::new(),
            quantize_activations: false,
            activation_bits: 8,
            enable_mixed_bit_optimization: false,
            bit_allocation_budget: None,
        }
    }
}

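// Illustrative use of struct-update syntax to tweak the defaults (field values
// here are placeholders, not recommendations):
//
//     let config = QATConfig {
//         quantize_activations: true,
//         start_step: 0,
//         ..QATConfig::default()
//     };
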
impl QATConfig {
    /// Resolve the bit-width for `layer_name` at `current_step`. Explicit entries
    /// in `layer_configs` take precedence over the mixed-bit strategy.
    pub fn get_layer_bits(&self, layer_name: &str, current_step: usize) -> u8 {
        if let Some(layer_config) = self.layer_configs.get(layer_name) {
            return layer_config.bits;
        }

        match &self.mixed_bit_strategy {
            MixedBitStrategy::Uniform { bits } => *bits,
            MixedBitStrategy::Manual { layer_bits } => {
                // Exact match first, then substring match, then the global default.
                layer_bits
                    .get(layer_name)
                    .or_else(|| {
                        layer_bits
                            .iter()
                            .find(|(key, _)| layer_name.contains(key.as_str()))
                            .map(|(_, bits)| bits)
                    })
                    .copied()
                    .unwrap_or(self.default_bits)
            },
            MixedBitStrategy::SensitivityBased {
                sensitivity_threshold,
                high_precision_bits,
                low_precision_bits,
            } => {
                if let Some(layer_config) = self.layer_configs.get(layer_name) {
                    if layer_config.sensitivity > *sensitivity_threshold {
                        *high_precision_bits
                    } else {
                        *low_precision_bits
                    }
                } else {
                    self.default_bits
                }
            },
            MixedBitStrategy::ResourceConstrained {
                critical_layers,
                critical_bits,
                default_bits,
                ..
            } => {
                if critical_layers.iter().any(|layer| layer_name.contains(layer)) {
                    *critical_bits
                } else {
                    *default_bits
                }
            },
            MixedBitStrategy::Progressive {
                initial_bits,
                final_bits,
                reduction_schedule,
            } => {
                // Walk the schedule from the latest milestone backwards and use the
                // first one that has been reached.
                for (step, bits) in reduction_schedule.iter().rev() {
                    if current_step >= *step {
                        return *bits;
                    }
                }
                if current_step < reduction_schedule.first().map(|(s, _)| *s).unwrap_or(0) {
                    *initial_bits
                } else {
                    *final_bits
                }
            },
        }
    }

    /// Insert or replace the explicit configuration for a layer.
    pub fn set_layer_config(&mut self, layer_name: String, config: LayerQuantConfig) {
        self.layer_configs.insert(layer_name, config);
    }

    /// Populate `layer_configs` from measured sensitivities and, if the strategy
    /// is still uniform, switch it to a sensitivity-based one.
    pub fn auto_configure_sensitivity(&mut self, layer_sensitivities: HashMap<String, f32>) {
        let sensitivity_threshold = match &self.mixed_bit_strategy {
            MixedBitStrategy::SensitivityBased {
                sensitivity_threshold,
                ..
            } => *sensitivity_threshold,
            _ => 0.7,
        };

        for (layer_name, sensitivity) in layer_sensitivities {
            let is_critical = sensitivity > sensitivity_threshold;
            let config = LayerQuantConfig {
                bits: if is_critical { 8 } else { 4 },
                sensitivity,
                is_critical,
                ..LayerQuantConfig::default()
            };
            self.layer_configs.insert(layer_name, config);
        }

        if matches!(self.mixed_bit_strategy, MixedBitStrategy::Uniform { .. }) {
            self.mixed_bit_strategy = MixedBitStrategy::SensitivityBased {
                sensitivity_threshold,
                high_precision_bits: 8,
                low_precision_bits: 4,
            };
        }
    }

    /// Greedily assign bit-widths under `bit_allocation_budget`, visiting layers
    /// in descending importance (sensitivity weighted by parameter count).
    pub fn optimize_bit_allocation(&mut self, model_size_info: HashMap<String, u64>) -> Result<()> {
        if let Some(budget) = self.bit_allocation_budget {
            let _total_params: u64 = model_size_info.values().sum();

            let mut layer_importance: Vec<(String, f64)> = model_size_info
                .iter()
                .map(|(name, size)| {
                    let sensitivity =
                        self.layer_configs.get(name).map(|c| c.sensitivity as f64).unwrap_or(0.5);
                    let importance = sensitivity * (*size as f64);
                    (name.clone(), importance)
                })
                .collect();

            layer_importance
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            let mut remaining_budget = budget;

            for (layer_name, _) in layer_importance {
                let layer_size = model_size_info[&layer_name];
                let min_bits = 4;
                let max_bits = 8;
                // Widest bit-width this layer can afford from the remaining budget,
                // clamped to [min_bits, max_bits].
                let affordable_bits = std::cmp::min(
                    max_bits,
                    std::cmp::max(min_bits, (remaining_budget / layer_size) as u8),
                );

                let mut config = self.layer_configs.get(&layer_name).cloned().unwrap_or_default();
                config.bits = affordable_bits;
                self.layer_configs.insert(layer_name.clone(), config);

                remaining_budget =
                    remaining_budget.saturating_sub(layer_size * affordable_bits as u64);
            }

            println!(
                "🎯 Optimized bit allocation under budget constraint: {} bits",
                budget
            );
            println!("📊 Remaining budget: {} bits", remaining_budget);
        }

        Ok(())
    }

    /// Total bits consumed by weights under the current per-layer assignment.
    pub fn estimate_bit_consumption(&self, model_size_info: &HashMap<String, u64>) -> u64 {
        model_size_info
            .iter()
            .map(|(layer_name, size)| {
                // Step 0 suffices here because only `Progressive` is step-dependent.
                let bits = self.get_layer_bits(layer_name, 0);
                size * bits as u64
            })
            .sum()
    }

    /// Preset configurations for common deployment scenarios. Unknown scenario
    /// names fall back to the defaults.
    pub fn create_common_config(scenario: &str) -> Self {
        match scenario {
            "edge_deployment" => Self {
                mixed_bit_strategy: MixedBitStrategy::ResourceConstrained {
                    total_bit_budget: 1024 * 1024,
                    critical_layers: vec!["attention".to_string(), "output".to_string()],
                    critical_bits: 8,
                    default_bits: 4,
                },
                quantize_activations: true,
                activation_bits: 8,
                enable_mixed_bit_optimization: true,
                ..Self::default()
            },
            "high_accuracy" => Self {
                mixed_bit_strategy: MixedBitStrategy::SensitivityBased {
                    sensitivity_threshold: 0.6,
                    high_precision_bits: 8,
                    low_precision_bits: 6,
                },
                quantize_activations: false,
                enable_mixed_bit_optimization: true,
                ..Self::default()
            },
            "aggressive_compression" => Self {
                mixed_bit_strategy: MixedBitStrategy::SensitivityBased {
                    sensitivity_threshold: 0.8,
                    high_precision_bits: 6,
                    low_precision_bits: 3,
                },
                quantize_activations: true,
                activation_bits: 4,
                enable_mixed_bit_optimization: true,
                ..Self::default()
            },
            _ => Self::default(),
        }
    }
}

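// Illustrative: start from a preset and adjust it; the scenario strings are the
// ones matched in `create_common_config` above.
//
//     let mut config = QATConfig::create_common_config("edge_deployment");
//     config.start_step = 0;
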
/// Observer state and derived quantization parameters for one tensor.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    pub scale: Tensor,
    /// Zero point; `None` for symmetric quantization.
    pub zero_point: Option<Tensor>,
    pub running_min: Tensor,
    pub running_max: Tensor,
    pub num_observations: usize,
}

impl QuantizationParams {
    pub fn new(shape: &[usize], symmetric: bool) -> Self {
        Self {
            scale: Tensor::ones(shape).expect("Failed to create scale"),
            zero_point: if symmetric {
                None
            } else {
                Some(Tensor::zeros(shape).expect("Failed to create zero point"))
            },
            running_min: Tensor::full(f32::INFINITY, shape.to_vec()).expect("Failed to create min"),
            running_max: Tensor::full(f32::NEG_INFINITY, shape.to_vec())
                .expect("Failed to create max"),
            num_observations: 0,
        }
    }

    /// Fold the tensor's min/max into the running range. The first observation
    /// seeds the range; later ones use an exponential moving average.
    pub fn update_stats(&mut self, tensor: &Tensor, momentum: f32) -> Result<()> {
        let (current_min_val, current_max_val) = tensor.min_max()?;
        let current_min = Tensor::scalar(current_min_val)?;
        let current_max = Tensor::scalar(current_max_val)?;

        if self.num_observations == 0 {
            self.running_min = current_min;
            self.running_max = current_max;
        } else {
            // EMA update: new = momentum * old + (1 - momentum) * current.
            self.running_min = self
                .running_min
                .mul_scalar(momentum)?
                .add(&current_min.mul_scalar(1.0 - momentum)?)?;
            self.running_max = self
                .running_max
                .mul_scalar(momentum)?
                .add(&current_max.mul_scalar(1.0 - momentum)?)?;
        }

        self.num_observations += 1;
        Ok(())
    }

    /// Derive the scale (and zero point, if asymmetric) from the running range for
    /// the given bit-width.
    pub fn compute_params(&mut self, bits: u8, symmetric: bool) -> Result<()> {
        // Integer range: symmetric b-bit covers [-2^(b-1), 2^(b-1) - 1],
        // asymmetric covers [0, 2^b - 1].
        let q_min = (if symmetric { -(1 << (bits - 1)) } else { 0 }) as f32;
        let q_max = (if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 }) as f32;

        if symmetric {
            let abs_running_max = self.running_max.abs()?;
            let abs_running_min = self.running_min.abs()?;
            let (_, max_abs_max) = abs_running_max.min_max()?;
            let (_, max_abs_min) = abs_running_min.min_max()?;
            let abs_max = Tensor::scalar(max_abs_max.max(max_abs_min))?;

            self.scale = abs_max.div_scalar(q_max)?;
        } else {
            let range = self.running_max.sub(&self.running_min)?;
            self.scale = range.div_scalar(q_max - q_min)?;

            if let Some(zp) = &mut self.zero_point {
                // zero_point = q_min - running_min / scale, clamped to the integer range.
                *zp = self.running_min.div(&self.scale)?.neg()?.add_scalar(q_min)?;
                *zp = zp.clamp(q_min, q_max)?;
            }
        }

        Ok(())
    }
}

/// Linear-layer wrapper that tracks observer statistics and applies fake
/// quantization during training.
pub struct QATLinear {
    linear: Arc<dyn Layer<Input = Tensor, Output = Tensor>>,
    config: QATConfig,
    quant_params: Arc<Mutex<QuantizationParams>>,
    step: Arc<Mutex<usize>>,
    enabled: bool,
}

impl QATLinear {
    pub fn new(linear: Arc<dyn Layer<Input = Tensor, Output = Tensor>>, config: QATConfig) -> Self {
        // Per-tensor observer state (shape [1]).
        let weight_shape = vec![1];
        let quant_params = QuantizationParams::new(&weight_shape, config.symmetric);

        Self {
            linear,
            config,
            quant_params: Arc::new(Mutex::new(quant_params)),
            step: Arc::new(Mutex::new(0)),
            enabled: true,
        }
    }

    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    pub fn get_quant_params(&self) -> Arc<Mutex<QuantizationParams>> {
        Arc::clone(&self.quant_params)
    }

    /// Placeholder weight access: in lieu of reading the wrapped layer's weights,
    /// this synthesizes a fixed-size Xavier-uniform matrix for observer updates.
    fn get_layer_weights(&self) -> Result<Tensor> {
        let weight_shape = vec![768, 768];
        let fan_in = weight_shape[0] as f32;
        let fan_out = weight_shape[1] as f32;
        let limit = (6.0 / (fan_in + fan_out)).sqrt();

        let total_elements = weight_shape.iter().product::<usize>();
        let weight_data: Vec<f32> = (0..total_elements)
            .map(|_| {
                // Map a uniform sample in [0, 1) to [-limit, limit].
                let uniform_val = fastrand::f32();
                (uniform_val - 0.5) * 2.0 * limit
            })
            .collect();

        Tensor::from_vec(weight_data, &weight_shape)
    }
}

impl Layer for QATLinear {
    type Input = Tensor;
    type Output = Tensor;

    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let mut step = self.step.lock().expect("lock should not be poisoned");
        *step += 1;
        let current_step = *step;
        drop(step);

        if !self.enabled || current_step < self.config.start_step {
            return self.linear.forward(input);
        }

        let weight = self.get_layer_weights()?;

        // Keep updating observer statistics until the optional freeze step.
        if self.config.freeze_step.is_none()
            || current_step
                < self.config.freeze_step.expect("freeze_step checked as Some in condition")
        {
            let mut params = self.quant_params.lock().expect("lock should not be poisoned");
            params.update_stats(&weight, self.config.observer_momentum)?;
            params.compute_params(self.config.default_bits, self.config.symmetric)?;
        }

        let params = self.quant_params.lock().expect("lock should not be poisoned");
        let _quantized_weight = fake_quantize(
            &weight,
            &params.scale,
            params.zero_point.as_ref(),
            self.config.default_bits,
            self.config.symmetric,
        )?;
        drop(params);

        // The quantized weight is computed only to exercise the observers; the
        // wrapped layer still runs with its own weights.
        self.linear.forward(input)
    }
}

/// Convolution wrapper that optionally fake-quantizes its input activations.
pub struct QATConv2d {
    conv: Arc<dyn Layer<Input = Tensor, Output = Tensor>>,
    config: QATConfig,
    #[allow(dead_code)]
    weight_params: Arc<Mutex<QuantizationParams>>,
    activation_params: Option<Arc<Mutex<QuantizationParams>>>,
    step: Arc<Mutex<usize>>,
}

impl QATConv2d {
    pub fn new(
        conv: Arc<dyn Layer<Input = Tensor, Output = Tensor>>,
        config: QATConfig,
        quantize_activations: bool,
    ) -> Self {
        let weight_shape = vec![1];
        let weight_params = QuantizationParams::new(&weight_shape, config.symmetric);

        let activation_params = if quantize_activations {
            Some(Arc::new(Mutex::new(QuantizationParams::new(
                &[1],
                config.symmetric,
            ))))
        } else {
            None
        };

        Self {
            conv,
            config,
            weight_params: Arc::new(Mutex::new(weight_params)),
            activation_params,
            step: Arc::new(Mutex::new(0)),
        }
    }
}

impl Layer for QATConv2d {
    type Input = Tensor;
    type Output = Tensor;

    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let mut step = self.step.lock().expect("lock should not be poisoned");
        *step += 1;
        let current_step = *step;
        drop(step);

        if current_step < self.config.start_step {
            return self.conv.forward(input);
        }

        let quantized_input = if let Some(act_params) = &self.activation_params {
            if self.config.freeze_step.is_none()
                || current_step
                    < self.config.freeze_step.expect("freeze_step checked as Some in condition")
            {
                let mut params = act_params.lock().expect("lock should not be poisoned");
                params.update_stats(&input, self.config.observer_momentum)?;
                params.compute_params(self.config.default_bits, self.config.symmetric)?;
            }

            let params = act_params.lock().expect("lock should not be poisoned");
            fake_quantize(
                &input,
                &params.scale,
                params.zero_point.as_ref(),
                self.config.default_bits,
                self.config.symmetric,
            )?
        } else {
            input.clone()
        };

        self.conv.forward(quantized_input)
    }
}

/// Standalone fake-quantizer for activations with an explicit calibration pass.
#[derive(Debug, Clone)]
pub struct ActivationQuantizer {
    pub params: QuantizationParams,
    pub bits: u8,
    pub symmetric: bool,
    pub calibrated: bool,
}

impl ActivationQuantizer {
    pub fn new(shape: &[usize], bits: u8, symmetric: bool) -> Self {
        Self {
            params: QuantizationParams::new(shape, symmetric),
            bits,
            symmetric,
            calibrated: false,
        }
    }

    /// Run the observer over a calibration set, then derive the parameters.
    pub fn calibrate(&mut self, calibration_data: &[Tensor], momentum: f32) -> Result<()> {
        for tensor in calibration_data {
            self.params.update_stats(tensor, momentum)?;
        }
        self.params.compute_params(self.bits, self.symmetric)?;
        self.calibrated = true;
        Ok(())
    }

    /// Fake-quantize a tensor; passes through at full precision if not calibrated.
    pub fn quantize(&self, tensor: &Tensor) -> Result<Tensor> {
        if !self.calibrated {
            println!("⚠️ Warning: Activation quantizer not calibrated, using full precision");
            return Ok(tensor.clone());
        }

        fake_quantize(
            tensor,
            &self.params.scale,
            self.params.zero_point.as_ref(),
            self.bits,
            self.symmetric,
        )
    }

    /// Update observer statistics and, once calibrated, refresh the parameters.
    pub fn update(&mut self, tensor: &Tensor, momentum: f32) -> Result<()> {
        self.params.update_stats(tensor, momentum)?;
        if self.calibrated {
            self.params.compute_params(self.bits, self.symmetric)?;
        }
        Ok(())
    }
}

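// Sketch of the intended calibration flow (tensor variables are placeholders):
//
//     let mut quantizer = ActivationQuantizer::new(&[1], 8, true);
//     quantizer.calibrate(&calibration_tensors, 0.99)?;
//     let quantized = quantizer.quantize(&activation)?;
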
/// Fake-quantize `tensor` using the bit-width resolved from `config` for the layer.
pub fn fake_quantize_mixed_bit(
    tensor: &Tensor,
    scale: &Tensor,
    zero_point: Option<&Tensor>,
    config: &QATConfig,
    layer_name: &str,
    current_step: usize,
) -> Result<Tensor> {
    let bits = config.get_layer_bits(layer_name, current_step);
    fake_quantize(tensor, scale, zero_point, bits, config.symmetric)
}

/// Simulate quantization in floating point: scale, shift by the zero point,
/// round, clamp to the integer range, then dequantize.
pub fn fake_quantize(
    tensor: &Tensor,
    scale: &Tensor,
    zero_point: Option<&Tensor>,
    bits: u8,
    symmetric: bool,
) -> Result<Tensor> {
    let q_min = (if symmetric { -(1 << (bits - 1)) } else { 0 }) as f32;
    let q_max = (if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 }) as f32;

    // Per-tensor parameters: only the first element of scale/zero_point is used.
    let scale_val = scale.get_float(0)?;
    let zero_point_val = if let Some(zp) = zero_point { zp.get_float(0)? } else { 0.0 };

    let tensor_data = tensor.data()?;
    let result_data: Vec<f32> = tensor_data
        .iter()
        .map(|&x| {
            let scaled = x / scale_val;

            let shifted = if zero_point.is_some() { scaled + zero_point_val } else { scaled };

            let quantized = shifted.round().clamp(q_min, q_max);

            // Dequantize back into the floating-point domain.
            if zero_point.is_some() {
                (quantized - zero_point_val) * scale_val
            } else {
                quantized * scale_val
            }
        })
        .collect();

    Tensor::from_vec(result_data, &tensor.shape())
}

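// Worked example for `fake_quantize` (symmetric, 8 bits, scale = 0.5):
// x = 1.3 → 1.3 / 0.5 ≈ 2.6 → round to 3 → clamp to [-128, 127] → 3 * 0.5 = 1.5.
// Out-of-range inputs saturate: x = 100.0 → 200 → clamp to 127 → 63.5.
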
/// Wrapper that tracks QAT-enabled layers within a model.
pub struct QATModel {
    model: Arc<dyn Layer<Input = Tensor, Output = Tensor>>,
    qat_layers: HashMap<String, Arc<Mutex<dyn Layer<Input = Tensor, Output = Tensor>>>>,
    config: QATConfig,
}

impl QATModel {
    pub fn new(model: Arc<dyn Layer<Input = Tensor, Output = Tensor>>, config: QATConfig) -> Self {
        Self {
            model,
            qat_layers: HashMap::new(),
            config,
        }
    }

    pub fn add_qat_layer(
        &mut self,
        name: String,
        layer: Arc<Mutex<dyn Layer<Input = Tensor, Output = Tensor>>>,
    ) {
        self.qat_layers.insert(name, layer);
    }

    /// Prepare the model for QAT. Currently a no-op placeholder.
    pub fn prepare(&mut self) -> Result<()> {
        Ok(())
    }

    /// Convert to a quantized model. Currently returns an empty placeholder.
    pub fn convert(&self) -> Result<QuantizedModel> {
        let quantized_layers = HashMap::new();

        Ok(QuantizedModel {
            layers: quantized_layers,
            config: self.config.clone(),
        })
    }

    /// Collect per-layer statistics. Currently returns placeholder values.
    pub fn get_statistics(&self) -> HashMap<String, QuantStats> {
        let mut stats = HashMap::new();

        for name in self.qat_layers.keys() {
            stats.insert(
                name.clone(),
                QuantStats {
                    min_val: 0.0,
                    max_val: 0.0,
                    mean_val: 0.0,
                    scale: 1.0,
                },
            );
        }

        stats
    }
}

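// Sketch of the intended wrapping flow (`backbone` is any Layer implementation;
// identifiers are placeholders):
//
//     let mut qat_model = QATModel::new(backbone, QATConfig::default());
//     qat_model.prepare()?;
//     // ... training loop with fake quantization ...
//     let quantized = qat_model.convert()?;
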
#[allow(dead_code)]
pub struct QuantizedModel {
    #[allow(dead_code)]
    layers: HashMap<String, QuantizedLayer>,
    config: QATConfig,
}

#[allow(dead_code)]
pub struct QuantizedLayer {
    #[allow(dead_code)]
    weights: Vec<u8>,
    scale: Vec<f32>,
    zero_point: Vec<i32>,
}

#[derive(Debug, Clone)]
pub struct QuantStats {
    pub min_val: f32,
    pub max_val: f32,
    pub mean_val: f32,
    pub scale: f32,
}

/// Orchestrates mixed-bit QAT: per-layer observers, activation quantizers,
/// sensitivity analysis, and bit-budget accounting.
pub struct MixedBitQATTrainer {
    pub config: QATConfig,
    pub layer_params: HashMap<String, QuantizationParams>,
    pub activation_quantizers: HashMap<String, ActivationQuantizer>,
    pub quant_lr: f32,
    pub quant_weight_decay: f32,
    pub current_step: usize,
    pub layer_sensitivities: HashMap<String, f32>,
    pub model_size_info: HashMap<String, u64>,
}

impl MixedBitQATTrainer {
    pub fn new(config: QATConfig, quant_lr: f32, quant_weight_decay: f32) -> Self {
        Self {
            config,
            layer_params: HashMap::new(),
            activation_quantizers: HashMap::new(),
            quant_lr,
            quant_weight_decay,
            current_step: 0,
            layer_sensitivities: HashMap::new(),
            model_size_info: HashMap::new(),
        }
    }

    /// Register a layer: create its weight observer and, when activation
    /// quantization is enabled, a matching activation quantizer.
    pub fn init_layer(&mut self, layer_name: String, param_shape: &[usize]) -> Result<()> {
        let layer_config = self.config.layer_configs.get(&layer_name).cloned().unwrap_or_default();

        let params = QuantizationParams::new(param_shape, layer_config.symmetric);
        self.layer_params.insert(layer_name.clone(), params);

        if self.config.quantize_activations {
            let activation_bits = if self.config.enable_mixed_bit_optimization {
                layer_config.bits
            } else {
                self.config.activation_bits
            };

            let act_quantizer =
                ActivationQuantizer::new(param_shape, activation_bits, layer_config.symmetric);
            self.activation_quantizers.insert(layer_name.clone(), act_quantizer);
        }

        println!(
            "🔧 Initialized mixed-bit QAT for layer: {} ({} bits)",
            layer_name,
            self.config.get_layer_bits(&layer_name, self.current_step)
        );

        Ok(())
    }

    /// Estimate per-layer sensitivity from recorded outputs and feed the scores
    /// into the config. The score is a heuristic: sqrt(avg variance * avg
    /// magnitude), capped at 1.0.
    pub fn analyze_sensitivity(
        &mut self,
        model_outputs: HashMap<String, Vec<Tensor>>,
    ) -> Result<()> {
        println!("🔍 Performing layer sensitivity analysis for mixed-bit optimization...");

        for (layer_name, outputs) in model_outputs {
            let mut total_variance = 0.0;
            let mut total_magnitude = 0.0;

            for output in &outputs {
                let data = output.data()?;
                let mean = data.iter().sum::<f32>() / data.len() as f32;
                let variance =
                    data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;

                total_variance += variance;

                let magnitude = data.iter().map(|x| x.abs()).sum::<f32>() / data.len() as f32;
                total_magnitude += magnitude;
            }

            let avg_variance = total_variance / outputs.len() as f32;
            let avg_magnitude = total_magnitude / outputs.len() as f32;
            let sensitivity = (avg_variance * avg_magnitude).sqrt().min(1.0);

            self.layer_sensitivities.insert(layer_name.clone(), sensitivity);

            println!("📊 Layer {} sensitivity: {:.3}", layer_name, sensitivity);
        }

        self.config.auto_configure_sensitivity(self.layer_sensitivities.clone());

        Ok(())
    }

    /// Record layer parameter counts and, if enabled, re-run budgeted bit allocation.
    pub fn update_model_info(&mut self, model_info: HashMap<String, u64>) {
        self.model_size_info = model_info;

        if self.config.enable_mixed_bit_optimization {
            if let Err(e) = self.config.optimize_bit_allocation(self.model_size_info.clone()) {
                println!("⚠️ Warning: Failed to optimize bit allocation: {}", e);
            }
        }
    }

    /// Fake-quantize a layer's weights, initializing observer state on first use.
    pub fn quantize_layer_weights(&mut self, layer_name: &str, weights: &Tensor) -> Result<Tensor> {
        if !self.layer_params.contains_key(layer_name) {
            self.init_layer(layer_name.to_string(), &weights.shape())?;
        }

        let params = self
            .layer_params
            .get_mut(layer_name)
            .expect("layer_params entry exists after initialization check");

        // Observer statistics are refreshed during the warm-up phase before `start_step`.
        if self.current_step < self.config.start_step {
            params.update_stats(weights, self.config.observer_momentum)?;
            params.compute_params(
                self.config.get_layer_bits(layer_name, self.current_step),
                self.config.symmetric,
            )?;
        }

        fake_quantize_mixed_bit(
            weights,
            &params.scale,
            params.zero_point.as_ref(),
            &self.config,
            layer_name,
            self.current_step,
        )
    }

    /// Fake-quantize activations, creating a quantizer lazily for unseen layers.
    pub fn quantize_layer_activations(
        &mut self,
        layer_name: &str,
        activations: &Tensor,
    ) -> Result<Tensor> {
        if !self.config.quantize_activations {
            return Ok(activations.clone());
        }

        if let Some(quantizer) = self.activation_quantizers.get_mut(layer_name) {
            quantizer.update(activations, self.config.observer_momentum)?;
            quantizer.quantize(activations)
        } else {
            let layer_config =
                self.config.layer_configs.get(layer_name).cloned().unwrap_or_default();

            let mut quantizer = ActivationQuantizer::new(
                &activations.shape(),
                layer_config.bits,
                layer_config.symmetric,
            );

            quantizer.update(activations, self.config.observer_momentum)?;
            let result = quantizer.quantize(activations)?;

            self.activation_quantizers.insert(layer_name.to_string(), quantizer);
            Ok(result)
        }
    }

    /// Advance the step counter and emit progress/freeze notifications.
    pub fn step(&mut self) {
        self.current_step += 1;

        if let MixedBitStrategy::Progressive { .. } = &self.config.mixed_bit_strategy {
            if self.current_step % 1000 == 0 {
                println!(
                    "📈 Progressive quantization step {}: updating bit allocations",
                    self.current_step
                );
            }
        }

        if let Some(freeze_step) = self.config.freeze_step {
            if self.current_step == freeze_step {
                println!(
                    "🔒 Freezing quantization parameters at step {}",
                    freeze_step
                );
            }
        }
    }

    /// Map each registered layer to its current (bits, sensitivity) pair.
    pub fn get_quantization_stats(&self) -> HashMap<String, (u8, f32)> {
        let mut stats = HashMap::new();

        for layer_name in self.layer_params.keys() {
            let bits = self.config.get_layer_bits(layer_name, self.current_step);
            let sensitivity = self.layer_sensitivities.get(layer_name).copied().unwrap_or(0.0);
            stats.insert(layer_name.clone(), (bits, sensitivity));
        }

        stats
    }

    /// Estimate (memory, compute) savings versus a 32-bit float baseline.
    pub fn estimate_savings(&self) -> (f64, f64) {
        if self.model_size_info.is_empty() {
            return (0.0, 0.0);
        }

        let total_params: u64 = self.model_size_info.values().sum();

        let mut total_quantized_bits = 0u64;
        for (layer_name, param_count) in &self.model_size_info {
            let bits = self.config.get_layer_bits(layer_name, self.current_step) as u64;
            total_quantized_bits += param_count * bits;
        }

        let baseline_total_bits = total_params * 32;

        let memory_savings = 1.0 - (total_quantized_bits as f64) / (baseline_total_bits as f64);
        // Heuristic: compute savings track memory savings, discounted by a 0.8 factor.
        let compute_savings = memory_savings * 0.8;

        (memory_savings, compute_savings)
    }

    /// Render a human-readable summary of the current mixed-bit configuration.
    pub fn summary_report(&self) -> String {
        let mut report = String::from("📊 Mixed-Bit QAT Summary Report\n");
        report.push_str("====================================\n\n");

        report.push_str(&format!(
            "🎯 Strategy: {:?}\n",
            self.config.mixed_bit_strategy
        ));
        report.push_str(&format!(
            "📋 Total layers configured: {}\n",
            self.layer_params.len()
        ));
        report.push_str(&format!(
            "📈 Current training step: {}\n",
            self.current_step
        ));

        if self.config.quantize_activations {
            report.push_str(&format!(
                "⚡ Activation quantization: {} bits\n",
                self.config.activation_bits
            ));
        }

        report.push('\n');

        report.push_str("🔍 Per-Layer Configuration:\n");
        for layer_name in self.layer_params.keys() {
            let bits = self.config.get_layer_bits(layer_name, self.current_step);
            let sensitivity = self.layer_sensitivities.get(layer_name).copied().unwrap_or(0.0);
            let size = self.model_size_info.get(layer_name).copied().unwrap_or(0);

            report.push_str(&format!(
                "  {} | {} bits | sensitivity: {:.3} | params: {}\n",
                layer_name, bits, sensitivity, size
            ));
        }

        let (memory_savings, compute_savings) = self.estimate_savings();
        report.push('\n');
        report.push_str(&format!(
            "💾 Estimated memory savings: {:.1}%\n",
            memory_savings * 100.0
        ));
        report.push_str(&format!(
            "⚡ Estimated compute savings: {:.1}%\n",
            compute_savings * 100.0
        ));

        if !self.model_size_info.is_empty() {
            let total_bits = self.config.estimate_bit_consumption(&self.model_size_info);
            report.push_str(&format!("📊 Total bit consumption: {} bits\n", total_bits));

            if let Some(budget) = self.config.bit_allocation_budget {
                let usage_pct = (total_bits as f64) / (budget as f64) * 100.0;
                report.push_str(&format!(
                    "💰 Budget usage: {:.1}% ({}/{})\n",
                    usage_pct, total_bits, budget
                ));
            }
        }

        report
    }
}

/// Gradient-based updater for learnable quantization parameters.
pub struct QATTrainer {
    pub quant_lr: f32,
    pub quant_weight_decay: f32,
}

impl QATTrainer {
    pub fn new(quant_lr: f32, quant_weight_decay: f32) -> Self {
        Self {
            quant_lr,
            quant_weight_decay,
        }
    }

    /// Apply one gradient-descent step to the scale (and zero point, if present).
    pub fn update_quant_params(
        &self,
        params: &mut QuantizationParams,
        grads: &QuantizationGradients,
    ) -> Result<()> {
        if let Some(scale_grad) = &grads.scale_grad {
            params.scale = params.scale.sub(&scale_grad.mul_scalar(self.quant_lr)?)?;
        }

        if let (Some(zp), Some(zp_grad)) = (&mut params.zero_point, &grads.zero_point_grad) {
            *zp = zp.sub(&zp_grad.mul_scalar(self.quant_lr)?)?;
        }

        Ok(())
    }
}

/// Gradients for learnable quantization parameters.
pub struct QuantizationGradients {
    pub scale_grad: Option<Tensor>,
    pub zero_point_grad: Option<Tensor>,
}

/// Sample/label pairs used to drive observer updates before training.
pub struct CalibrationDataset {
    samples: Vec<Tensor>,
    labels: Vec<Tensor>,
}

impl CalibrationDataset {
    pub fn new(samples: Vec<Tensor>, labels: Vec<Tensor>) -> Self {
        Self { samples, labels }
    }

    /// Run every sample through the model so the QAT observers see realistic ranges.
    pub fn calibrate(&self, model: &mut QATModel) -> Result<()> {
        for (sample, _label) in self.samples.iter().zip(&self.labels) {
            let _ = model.model.forward(sample.clone())?;
        }

        Ok(())
    }
}

/// Combined QAT loss: the task loss plus an `alpha`-weighted quantization error term.
pub fn qat_loss(
    predictions: &Tensor,
    targets: &Tensor,
    quant_error: f32,
    alpha: f32,
) -> Result<Tensor> {
    let task_loss = compute_task_loss(predictions, targets)?;

    let total_loss = task_loss.add_scalar(alpha * quant_error)?;

    Ok(total_loss)
}

/// Mean squared error between predictions and targets.
fn compute_task_loss(predictions: &Tensor, targets: &Tensor) -> Result<Tensor> {
    predictions.sub(targets)?.pow(2.0)?.mean()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fake_quantize() {
        let tensor =
            Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).expect("tensor operation failed");
        let scale = Tensor::from_vec(vec![0.1], &[1]).expect("tensor operation failed");
        let zero_point =
            Some(Tensor::from_vec(vec![128.0], &[1]).expect("tensor operation failed"));

        let quantized = fake_quantize(&tensor, &scale, zero_point.as_ref(), 8, false)
            .expect("tensor operation failed");
        assert_eq!(quantized.shape(), tensor.shape());
    }

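    // Added check (a sketch using the same Tensor API as the test above): with a
    // symmetric 8-bit range and scale 0.5, 100.0 / 0.5 = 200 exceeds q_max = 127,
    // so the output saturates at 127 * 0.5 = 63.5.
    #[test]
    fn test_fake_quantize_symmetric_saturation() {
        let tensor = Tensor::from_vec(vec![100.0], &[1]).expect("tensor operation failed");
        let scale = Tensor::from_vec(vec![0.5], &[1]).expect("tensor operation failed");

        let quantized =
            fake_quantize(&tensor, &scale, None, 8, true).expect("tensor operation failed");
        let data = quantized.data().expect("tensor operation failed");
        let first = *data.iter().next().expect("tensor should not be empty");
        assert!((first - 63.5).abs() < 1e-6);
    }
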
    #[test]
    fn test_quantization_params() {
        let mut params = QuantizationParams::new(&[1], true);

        let tensor1 =
            Tensor::from_vec(vec![-1.0, 0.0, 1.0], &[3]).expect("tensor operation failed");
        params.update_stats(&tensor1, 0.9).expect("tensor operation failed");

        assert!(params.num_observations == 1);
        params.compute_params(8, true).expect("operation failed in test");
    }

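    // Added check for bit-width resolution under a Manual strategy: exact and
    // substring matches come from `layer_bits`, anything else falls back to
    // `default_bits` (layer names here are hypothetical).
    #[test]
    fn test_get_layer_bits_manual_strategy() {
        let mut layer_bits = HashMap::new();
        layer_bits.insert("attention".to_string(), 8u8);

        let config = QATConfig {
            mixed_bit_strategy: MixedBitStrategy::Manual { layer_bits },
            default_bits: 4,
            ..QATConfig::default()
        };

        // Substring match: the layer name contains "attention".
        assert_eq!(config.get_layer_bits("encoder.attention.query", 0), 8);
        // No match: falls back to default_bits.
        assert_eq!(config.get_layer_bits("embedding", 0), 4);
    }
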
    #[test]
    fn test_qat_config() {
        let config = QATConfig::default();
        assert_eq!(config.default_bits, 8);
        assert!(config.symmetric);
        assert_eq!(config.start_step, 1000);
    }
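
    // Added arithmetic checks (layer names and sizes are hypothetical): under a
    // uniform 4-bit strategy, 1_000 + 500 parameters consume (1_000 + 500) * 4 =
    // 6_000 bits.
    #[test]
    fn test_estimate_bit_consumption_uniform() {
        let config = QATConfig {
            mixed_bit_strategy: MixedBitStrategy::Uniform { bits: 4 },
            ..QATConfig::default()
        };

        let mut sizes = HashMap::new();
        sizes.insert("layer_a".to_string(), 1_000u64);
        sizes.insert("layer_b".to_string(), 500u64);

        assert_eq!(config.estimate_bit_consumption(&sizes), 6_000);
    }

    // The default uniform 8-bit strategy against the 32-bit baseline gives
    // 1 - 8/32 = 75% memory savings and 0.75 * 0.8 = 60% estimated compute savings.
    #[test]
    fn test_estimate_savings_uniform() {
        let mut trainer = MixedBitQATTrainer::new(QATConfig::default(), 1e-4, 0.0);
        trainer.model_size_info.insert("layer_a".to_string(), 1_000);

        let (memory_savings, compute_savings) = trainer.estimate_savings();
        assert!((memory_savings - 0.75).abs() < 1e-9);
        assert!((compute_savings - 0.60).abs() < 1e-9);
    }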
}