torsh_quantization/analysis/
sensitivity.rs1use crate::analysis::config::{LayerSensitivityResult, SensitivityAnalysisResults};
4use crate::{QScheme, QuantConfig, TorshResult};
5use torsh_core::TorshError;
6use torsh_tensor::Tensor;
7
/// Analyzes how sensitive individual model layers are to quantization by
/// comparing accuracy / error metrics between original and quantized tensors.
pub struct SensitivityAnalyzer {
    // Held for evaluation-driven analysis; none of the methods visible in
    // this module read it yet, hence the dead_code allowance.
    #[allow(dead_code)]
    test_data: Vec<(Tensor, Tensor)>,
    // Maximum absolute element-wise difference still counted as "correct"
    // by `compare_tensor_accuracy` (defaults to 1e-6 in `new`).
    tolerance: f32,
}
16
17impl SensitivityAnalyzer {
18 pub fn new(test_data: Vec<(Tensor, Tensor)>) -> Self {
20 Self {
21 test_data,
22 tolerance: 1e-6,
23 }
24 }
25
    /// Overrides the element-wise tolerance used by `compare_tensor_accuracy`.
    pub fn set_tolerance(&mut self, tolerance: f32) {
        self.tolerance = tolerance;
    }
30
31 pub fn analyze_layer_sensitivity(
33 &self,
34 layer_names: &[String],
35 evaluation_fn: impl Fn(&str, &QuantConfig) -> TorshResult<f32>,
36 ) -> TorshResult<SensitivityAnalysisResults> {
37 let mut layer_results = Vec::new();
38
39 let baseline_accuracy = evaluation_fn("", &QuantConfig::default())?;
41
42 for layer_name in layer_names {
43 let mut best_accuracy = 0.0;
45 let mut _best_scheme = QScheme::PerTensorAffine;
46
47 let schemes_to_test = vec![
48 QScheme::PerTensorAffine,
49 QScheme::PerTensorSymmetric,
50 QScheme::PerChannelAffine,
51 QScheme::Int4PerTensor,
52 QScheme::Binary,
53 ];
54
55 for &scheme in &schemes_to_test {
56 let config = QuantConfig::new().with_scheme(scheme);
57
58 match evaluation_fn(layer_name, &config) {
59 Ok(accuracy) => {
60 if accuracy > best_accuracy {
61 best_accuracy = accuracy;
62 _best_scheme = scheme;
63 }
64 }
65 Err(_) => {
66 continue;
68 }
69 }
70 }
71
72 let result =
73 LayerSensitivityResult::new(layer_name.clone(), baseline_accuracy, best_accuracy);
74 layer_results.push(result);
75 }
76
77 Ok(SensitivityAnalysisResults::new(layer_results))
78 }
79
80 pub fn heuristic_sensitivity_analysis(
82 &self,
83 layer_names: &[String],
84 ) -> TorshResult<SensitivityAnalysisResults> {
85 let mut layer_results = Vec::new();
86
87 for layer_name in layer_names {
88 let (sensitivity_score, _recommended_scheme) =
89 self.estimate_layer_sensitivity(layer_name);
90
91 let baseline_accuracy = 0.95; let quantized_accuracy = baseline_accuracy - sensitivity_score;
93
94 let result = LayerSensitivityResult::new(
95 layer_name.clone(),
96 baseline_accuracy,
97 quantized_accuracy,
98 );
99 layer_results.push(result);
100 }
101
102 Ok(SensitivityAnalysisResults::new(layer_results))
103 }
104
105 fn estimate_layer_sensitivity(&self, layer_name: &str) -> (f32, QScheme) {
107 let layer_name_lower = layer_name.to_lowercase();
108
109 if layer_name_lower.contains("embedding") {
111 (0.08, QScheme::PerTensorAffine) } else if layer_name_lower.contains("attention") || layer_name_lower.contains("self_attn") {
113 (0.06, QScheme::PerChannelAffine) } else if layer_name_lower.contains("output") || layer_name_lower.contains("classifier") {
115 (0.05, QScheme::PerTensorAffine) } else if layer_name_lower.contains("layer_norm") || layer_name_lower.contains("batch_norm")
117 {
118 (0.02, QScheme::Int4PerTensor) } else if layer_name_lower.contains("conv") && layer_name_lower.contains("1x1") {
120 (0.01, QScheme::Int4PerTensor) } else if layer_name_lower.contains("conv") {
122 (0.03, QScheme::PerChannelAffine) } else if layer_name_lower.contains("linear") || layer_name_lower.contains("dense") {
124 (0.025, QScheme::PerTensorAffine) } else {
126 (0.03, QScheme::PerTensorAffine) }
128 }
129
130 pub fn compare_tensor_accuracy(
132 &self,
133 original: &Tensor,
134 quantized: &Tensor,
135 ) -> TorshResult<f32> {
136 if original.shape() != quantized.shape() {
137 return Err(TorshError::InvalidArgument(
138 "Tensors must have the same shape for accuracy comparison".to_string(),
139 ));
140 }
141
142 let original_data = original.data()?;
143 let quantized_data = quantized.data()?;
144
145 let mut correct_predictions = 0;
146 let total_predictions = original_data.len();
147
148 for (orig, quant) in original_data.iter().zip(quantized_data.iter()) {
149 if (orig - quant).abs() <= self.tolerance {
150 correct_predictions += 1;
151 }
152 }
153
154 Ok(correct_predictions as f32 / total_predictions as f32)
155 }
156
157 pub fn calculate_mse(&self, original: &Tensor, quantized: &Tensor) -> TorshResult<f32> {
159 if original.shape() != quantized.shape() {
160 return Err(TorshError::InvalidArgument(
161 "Tensors must have the same shape for MSE calculation".to_string(),
162 ));
163 }
164
165 let original_data = original.data()?;
166 let quantized_data = quantized.data()?;
167
168 let mse = original_data
169 .iter()
170 .zip(quantized_data.iter())
171 .map(|(orig, quant)| (orig - quant).powi(2))
172 .sum::<f32>()
173 / original_data.len() as f32;
174
175 Ok(mse)
176 }
177
178 pub fn calculate_snr(&self, original: &Tensor, quantized: &Tensor) -> TorshResult<f32> {
180 let mse = self.calculate_mse(original, quantized)?;
181
182 if mse == 0.0 {
183 return Ok(f32::INFINITY); }
185
186 let original_data = original.data()?;
187 let signal_power =
188 original_data.iter().map(|&x| x.powi(2)).sum::<f32>() / original_data.len() as f32;
189
190 let snr_db = 10.0 * (signal_power / mse).log10();
191 Ok(snr_db)
192 }
193}