// torsh_quantization/analysis/config.rs
use crate::QScheme;
use std::collections::HashMap;

6#[derive(Debug, Clone)]
8pub struct AnalysisConfig {
9 pub sensitivity_threshold: f32,
11 pub fp32_threshold: f32,
13 pub aggressive_threshold: f32,
15 pub max_accuracy_drop_percent: f32,
17 pub efficiency_weights: EfficiencyWeights,
19 pub normalization_factors: NormalizationFactors,
21}
22
/// Relative weights of the accuracy, size and speed components.
///
/// NOTE(review): the code that combines these weights is not visible in this
/// file; the field docs below follow the field names only.
#[derive(Debug, Clone)]
pub struct EfficiencyWeights {
    /// Weight given to accuracy.
    pub accuracy: f32,
    /// Weight given to model size.
    pub size: f32,
    /// Weight given to speed.
    pub speed: f32,
}

/// Scaling bounds for size-reduction and speed-improvement figures.
///
/// NOTE(review): the code that consumes these factors is not visible in this
/// file; presumably they are the maxima that measured ratios are divided by.
#[derive(Debug, Clone)]
pub struct NormalizationFactors {
    /// Upper bound assumed for the size-reduction ratio.
    pub max_size_reduction: f32,
    /// Upper bound assumed for the speed-improvement ratio.
    pub max_speed_improvement: f32,
}

43impl Default for AnalysisConfig {
44 fn default() -> Self {
45 Self {
46 sensitivity_threshold: 0.05,
47 fp32_threshold: 0.05,
48 aggressive_threshold: 0.01,
49 max_accuracy_drop_percent: 5.0,
50 efficiency_weights: EfficiencyWeights::default(),
51 normalization_factors: NormalizationFactors::default(),
52 }
53 }
54}
55
56impl Default for EfficiencyWeights {
57 fn default() -> Self {
58 Self {
59 accuracy: 0.5,
60 size: 0.3,
61 speed: 0.2,
62 }
63 }
64}
65
66impl Default for NormalizationFactors {
67 fn default() -> Self {
68 Self {
69 max_size_reduction: 8.0,
70 max_speed_improvement: 10.0,
71 }
72 }
73}
74
75impl AnalysisConfig {
76 pub fn with_sensitivity_thresholds(
78 sensitivity_threshold: f32,
79 fp32_threshold: f32,
80 aggressive_threshold: f32,
81 ) -> Self {
82 Self {
83 sensitivity_threshold,
84 fp32_threshold,
85 aggressive_threshold,
86 ..Default::default()
87 }
88 }
89
90 pub fn with_efficiency_weights(accuracy: f32, size: f32, speed: f32) -> Self {
92 Self {
93 efficiency_weights: EfficiencyWeights {
94 accuracy,
95 size,
96 speed,
97 },
98 ..Default::default()
99 }
100 }
101
102 pub fn conservative() -> Self {
104 Self {
105 sensitivity_threshold: 0.02,
106 fp32_threshold: 0.02,
107 aggressive_threshold: 0.005,
108 max_accuracy_drop_percent: 2.0,
109 ..Default::default()
110 }
111 }
112
113 pub fn aggressive() -> Self {
115 Self {
116 sensitivity_threshold: 0.1,
117 fp32_threshold: 0.1,
118 aggressive_threshold: 0.05,
119 max_accuracy_drop_percent: 10.0,
120 ..Default::default()
121 }
122 }
123}
124
125#[derive(Debug, Clone)]
127pub struct LayerSensitivityResult {
128 pub layer_name: String,
130 pub original_accuracy: f32,
132 pub quantized_accuracy: f32,
134 pub sensitivity_score: f32,
136 pub recommended_scheme: QScheme,
138 pub keep_fp32: bool,
140}
141
142impl LayerSensitivityResult {
143 pub fn new(layer_name: String, original_accuracy: f32, quantized_accuracy: f32) -> Self {
145 Self::new_with_config(
146 layer_name,
147 original_accuracy,
148 quantized_accuracy,
149 &AnalysisConfig::default(),
150 )
151 }
152
153 pub fn new_with_config(
155 layer_name: String,
156 original_accuracy: f32,
157 quantized_accuracy: f32,
158 config: &AnalysisConfig,
159 ) -> Self {
160 let sensitivity_score = original_accuracy - quantized_accuracy;
161 let keep_fp32 = sensitivity_score > config.fp32_threshold;
162 let recommended_scheme = Self::determine_recommended_scheme(sensitivity_score, config);
163
164 Self {
165 layer_name,
166 original_accuracy,
167 quantized_accuracy,
168 sensitivity_score,
169 recommended_scheme,
170 keep_fp32,
171 }
172 }
173
174 fn determine_recommended_scheme(sensitivity_score: f32, config: &AnalysisConfig) -> QScheme {
176 if sensitivity_score > config.fp32_threshold {
177 QScheme::PerTensorAffine
179 } else if sensitivity_score > config.aggressive_threshold {
180 QScheme::PerChannelAffine
182 } else if sensitivity_score > config.aggressive_threshold / 2.0 {
183 QScheme::Int4PerTensor
185 } else {
186 QScheme::Int4PerChannel
188 }
189 }
190
191 pub fn accuracy_drop_percentage(&self) -> f32 {
193 (self.sensitivity_score / self.original_accuracy) * 100.0
194 }
195
196 pub fn is_high_sensitivity(&self) -> bool {
198 self.is_high_sensitivity_with_config(&AnalysisConfig::default())
199 }
200
201 pub fn is_high_sensitivity_with_config(&self, config: &AnalysisConfig) -> bool {
203 self.sensitivity_score > config.sensitivity_threshold
204 || self.accuracy_drop_percentage() > config.max_accuracy_drop_percent
205 }
206}
207
208#[derive(Debug, Clone)]
210pub struct SensitivityAnalysisResults {
211 pub layer_results: Vec<LayerSensitivityResult>,
213 pub overall_sensitivity: f32,
215 pub most_sensitive_layers: Vec<String>,
217 pub least_sensitive_layers: Vec<String>,
219 pub recommended_config: HashMap<String, QScheme>,
221}
222
223impl SensitivityAnalysisResults {
224 pub fn new(layer_results: Vec<LayerSensitivityResult>) -> Self {
226 let overall_sensitivity = if layer_results.is_empty() {
227 0.0
228 } else {
229 layer_results
230 .iter()
231 .map(|r| r.sensitivity_score)
232 .sum::<f32>()
233 / layer_results.len() as f32
234 };
235
236 let mut sorted_results = layer_results.clone();
238 sorted_results.sort_by(|a, b| {
239 b.sensitivity_score
240 .partial_cmp(&a.sensitivity_score)
241 .expect("sensitivity scores should be comparable")
242 });
243
244 let num_layers = sorted_results.len();
245 let top_10_percent = (num_layers as f32 * 0.1).ceil() as usize;
246 let bottom_10_percent = (num_layers as f32 * 0.1).ceil() as usize;
247
248 let most_sensitive_layers = sorted_results
249 .iter()
250 .take(top_10_percent)
251 .map(|r| r.layer_name.clone())
252 .collect();
253
254 let least_sensitive_layers = sorted_results
255 .iter()
256 .rev()
257 .take(bottom_10_percent)
258 .map(|r| r.layer_name.clone())
259 .collect();
260
261 let mut recommended_config = HashMap::new();
263 for result in &layer_results {
264 recommended_config.insert(result.layer_name.clone(), result.recommended_scheme);
265 }
266
267 Self {
268 layer_results,
269 overall_sensitivity,
270 most_sensitive_layers,
271 least_sensitive_layers,
272 recommended_config,
273 }
274 }
275
276 pub fn get_fp32_layers(&self) -> Vec<&String> {
278 self.layer_results
279 .iter()
280 .filter(|r| r.keep_fp32)
281 .map(|r| &r.layer_name)
282 .collect()
283 }
284
285 pub fn average_sensitivity(&self) -> f32 {
287 self.overall_sensitivity
288 }
289
290 pub fn get_aggressive_quantization_candidates(&self) -> Vec<&String> {
292 self.get_aggressive_quantization_candidates_with_config(&AnalysisConfig::default())
293 }
294
295 pub fn get_aggressive_quantization_candidates_with_config(
297 &self,
298 config: &AnalysisConfig,
299 ) -> Vec<&String> {
300 self.layer_results
301 .iter()
302 .filter(|r| r.sensitivity_score < config.aggressive_threshold)
303 .map(|r| &r.layer_name)
304 .collect()
305 }
306
307 pub fn summary_report(&self) -> String {
309 format!(
310 "Sensitivity Analysis Summary:\n\
311 - Total layers analyzed: {}\n\
312 - Average sensitivity: {:.4}\n\
313 - Most sensitive layers ({}):\n{}\n\
314 - Least sensitive layers ({}):\n{}\n\
315 - Layers recommended for FP32: {}",
316 self.layer_results.len(),
317 self.overall_sensitivity,
318 self.most_sensitive_layers.len(),
319 self.most_sensitive_layers
320 .iter()
321 .map(|name| format!(" - {}", name))
322 .collect::<Vec<_>>()
323 .join("\n"),
324 self.least_sensitive_layers.len(),
325 self.least_sensitive_layers
326 .iter()
327 .map(|name| format!(" - {}", name))
328 .collect::<Vec<_>>()
329 .join("\n"),
330 self.get_fp32_layers().len()
331 )
332 }
333}
334
/// Comparison of model accuracy before and after quantization.
#[derive(Debug, Clone)]
pub struct AccuracyComparison {
    /// Accuracy before quantization.
    pub original_accuracy: f32,
    /// Accuracy after quantization.
    pub quantized_accuracy: f32,
    /// `original_accuracy - quantized_accuracy`.
    pub accuracy_drop: f32,
    /// `accuracy_drop` relative to `original_accuracy`, in percent.
    pub accuracy_drop_percentage: f32,
    /// Whether the percentage drop is within the acceptable threshold.
    pub is_acceptable: bool,
    /// Additional named metrics attached via [`AccuracyComparison::add_metric`].
    pub detailed_metrics: HashMap<String, f32>,
}

impl AccuracyComparison {
    /// Builds a comparison with an acceptable drop of 5%.
    pub fn new(original_accuracy: f32, quantized_accuracy: f32) -> Self {
        Self::new_with_threshold(original_accuracy, quantized_accuracy, 5.0)
    }

    /// Builds a comparison that is acceptable when the relative drop does not
    /// exceed `acceptable_drop_percentage`.
    ///
    /// When `original_accuracy` is zero the relative drop is reported as
    /// `0.0`, mirroring the zero guard in [`AccuracyComparison::efficiency_score`]
    /// and avoiding NaN or infinity.
    pub fn new_with_threshold(
        original_accuracy: f32,
        quantized_accuracy: f32,
        acceptable_drop_percentage: f32,
    ) -> Self {
        let accuracy_drop = original_accuracy - quantized_accuracy;
        let accuracy_drop_percentage = if original_accuracy == 0.0 {
            0.0
        } else {
            (accuracy_drop / original_accuracy) * 100.0
        };
        let is_acceptable = accuracy_drop_percentage <= acceptable_drop_percentage;

        Self {
            original_accuracy,
            quantized_accuracy,
            accuracy_drop,
            accuracy_drop_percentage,
            is_acceptable,
            detailed_metrics: HashMap::new(),
        }
    }

    /// Records a named metric, replacing any existing metric with that name.
    pub fn add_metric(&mut self, name: String, value: f32) {
        self.detailed_metrics.insert(name, value);
    }

    /// Ratio of quantized to original accuracy; `0.0` when the original
    /// accuracy is zero.
    pub fn efficiency_score(&self) -> f32 {
        if self.original_accuracy == 0.0 {
            0.0
        } else {
            self.quantized_accuracy / self.original_accuracy
        }
    }

    /// Whether the drop is acceptable and the efficiency score exceeds `0.95`.
    pub fn is_quantization_recommended(&self) -> bool {
        self.is_acceptable && self.efficiency_score() > 0.95
    }

    /// Renders a multi-line, human-readable report.
    ///
    /// Detailed metrics, if any, are listed in ascending name order so the
    /// output is the same on every call.
    pub fn report(&self) -> String {
        let yes_no = |flag: bool| if flag { "Yes" } else { "No" };

        let mut report = format!(
            "Accuracy Comparison Report:\n\
             - Original Accuracy: {:.4} ({:.2}%)\n\
             - Quantized Accuracy: {:.4} ({:.2}%)\n\
             - Accuracy Drop: {:.4} ({:.2}%)\n\
             - Efficiency Score: {:.4}\n\
             - Acceptable: {}\n\
             - Quantization Recommended: {}",
            self.original_accuracy,
            self.original_accuracy * 100.0,
            self.quantized_accuracy,
            self.quantized_accuracy * 100.0,
            self.accuracy_drop,
            self.accuracy_drop_percentage,
            self.efficiency_score(),
            yes_no(self.is_acceptable),
            yes_no(self.is_quantization_recommended())
        );

        if !self.detailed_metrics.is_empty() {
            // HashMap iteration order is unspecified; sort by name first.
            let mut metrics: Vec<(&String, &f32)> = self.detailed_metrics.iter().collect();
            metrics.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

            report.push_str("\n\nDetailed Metrics:");
            for (name, value) in metrics {
                report.push_str(&format!("\n - {}: {:.4}", name, value));
            }
        }

        report
    }
}