Skip to main content

torsh_quantization/analysis/
config.rs

1//! Configuration and core types for quantization analysis
2
3use crate::QScheme;
4use std::collections::HashMap;
5
/// Configuration parameters for sensitivity analysis
#[derive(Debug, Clone)]
pub struct AnalysisConfig {
    /// Threshold for considering a layer as high sensitivity (default: 0.05)
    pub sensitivity_threshold: f32,
    /// Threshold for keeping layers in FP32 (default: 0.05)
    pub fp32_threshold: f32,
    /// Threshold for aggressive quantization candidates (default: 0.01)
    pub aggressive_threshold: f32,
    /// Maximum acceptable accuracy drop percentage (default: 5.0%)
    pub max_accuracy_drop_percent: f32,
    /// Weights for efficiency score calculation
    pub efficiency_weights: EfficiencyWeights,
    /// Normalization factors for efficiency score
    pub normalization_factors: NormalizationFactors,
}

/// Weights for efficiency score calculation
#[derive(Debug, Clone)]
pub struct EfficiencyWeights {
    /// Weight for accuracy in efficiency score (default: 0.5)
    pub accuracy: f32,
    /// Weight for size reduction in efficiency score (default: 0.3)
    pub size: f32,
    /// Weight for speed improvement in efficiency score (default: 0.2)
    pub speed: f32,
}

/// Normalization factors for efficiency score
#[derive(Debug, Clone)]
pub struct NormalizationFactors {
    /// Maximum expected size reduction factor (default: 8.0)
    pub max_size_reduction: f32,
    /// Maximum expected speed improvement factor (default: 10.0)
    pub max_speed_improvement: f32,
}

impl Default for AnalysisConfig {
    fn default() -> Self {
        Self {
            sensitivity_threshold: 0.05,
            fp32_threshold: 0.05,
            aggressive_threshold: 0.01,
            max_accuracy_drop_percent: 5.0,
            efficiency_weights: EfficiencyWeights::default(),
            normalization_factors: NormalizationFactors::default(),
        }
    }
}

impl Default for EfficiencyWeights {
    fn default() -> Self {
        Self {
            accuracy: 0.5,
            size: 0.3,
            speed: 0.2,
        }
    }
}

impl Default for NormalizationFactors {
    fn default() -> Self {
        Self {
            max_size_reduction: 8.0,
            max_speed_improvement: 10.0,
        }
    }
}

impl AnalysisConfig {
    /// Create a new analysis configuration with custom sensitivity thresholds;
    /// all other fields keep their default values.
    pub fn with_sensitivity_thresholds(
        sensitivity_threshold: f32,
        fp32_threshold: f32,
        aggressive_threshold: f32,
    ) -> Self {
        Self {
            sensitivity_threshold,
            fp32_threshold,
            aggressive_threshold,
            ..Default::default()
        }
    }

    /// Create a new analysis configuration with custom efficiency weights;
    /// all other fields keep their default values.
    pub fn with_efficiency_weights(accuracy: f32, size: f32, speed: f32) -> Self {
        Self {
            efficiency_weights: EfficiencyWeights {
                accuracy,
                size,
                speed,
            },
            ..Default::default()
        }
    }

    /// Create a conservative analysis configuration.
    ///
    /// Uses *lower* thresholds than the defaults (0.02/0.02/0.005 vs
    /// 0.05/0.05/0.01) and a tighter accuracy budget (2% vs 5%), so more
    /// layers are classified as sensitive and kept at higher precision.
    /// (The previous doc said "higher thresholds", which contradicted the
    /// values below.)
    pub fn conservative() -> Self {
        Self {
            sensitivity_threshold: 0.02,
            fp32_threshold: 0.02,
            aggressive_threshold: 0.005,
            max_accuracy_drop_percent: 2.0,
            ..Default::default()
        }
    }

    /// Create an aggressive analysis configuration.
    ///
    /// Uses *higher* thresholds than the defaults (0.1/0.1/0.05) and permits
    /// a larger accuracy drop (10%), so fewer layers are flagged as sensitive
    /// and more become candidates for low-bit quantization.
    pub fn aggressive() -> Self {
        Self {
            sensitivity_threshold: 0.1,
            fp32_threshold: 0.1,
            aggressive_threshold: 0.05,
            max_accuracy_drop_percent: 10.0,
            ..Default::default()
        }
    }
}
124
/// Results of sensitivity analysis for a single layer
#[derive(Debug, Clone)]
pub struct LayerSensitivityResult {
    /// Layer name or identifier
    pub layer_name: String,
    /// Accuracy before quantization (presumably a 0–1 fraction, matching
    /// `AccuracyComparison` which multiplies by 100 for display — confirm)
    pub original_accuracy: f32,
    /// Accuracy after quantizing this layer (same scale as `original_accuracy`)
    pub quantized_accuracy: f32,
    /// Sensitivity score: the raw accuracy drop (`original - quantized`)
    pub sensitivity_score: f32,
    /// Recommended quantization scheme for this layer, derived from the
    /// sensitivity score and the analysis thresholds
    pub recommended_scheme: QScheme,
    /// Whether this layer should be kept in full precision (FP32)
    pub keep_fp32: bool,
}
141
142impl LayerSensitivityResult {
143    /// Create a new sensitivity result
144    pub fn new(layer_name: String, original_accuracy: f32, quantized_accuracy: f32) -> Self {
145        Self::new_with_config(
146            layer_name,
147            original_accuracy,
148            quantized_accuracy,
149            &AnalysisConfig::default(),
150        )
151    }
152
153    /// Create a new sensitivity result with custom analysis configuration
154    pub fn new_with_config(
155        layer_name: String,
156        original_accuracy: f32,
157        quantized_accuracy: f32,
158        config: &AnalysisConfig,
159    ) -> Self {
160        let sensitivity_score = original_accuracy - quantized_accuracy;
161        let keep_fp32 = sensitivity_score > config.fp32_threshold;
162        let recommended_scheme = Self::determine_recommended_scheme(sensitivity_score, config);
163
164        Self {
165            layer_name,
166            original_accuracy,
167            quantized_accuracy,
168            sensitivity_score,
169            recommended_scheme,
170            keep_fp32,
171        }
172    }
173
174    /// Determine the recommended quantization scheme based on sensitivity and configuration
175    fn determine_recommended_scheme(sensitivity_score: f32, config: &AnalysisConfig) -> QScheme {
176        if sensitivity_score > config.fp32_threshold {
177            // High sensitivity - use conservative quantization
178            QScheme::PerTensorAffine
179        } else if sensitivity_score > config.aggressive_threshold {
180            // Medium sensitivity - use per-channel for better accuracy
181            QScheme::PerChannelAffine
182        } else if sensitivity_score > config.aggressive_threshold / 2.0 {
183            // Low sensitivity - can use INT4
184            QScheme::Int4PerTensor
185        } else {
186            // Very low sensitivity - can use aggressive quantization
187            QScheme::Int4PerChannel
188        }
189    }
190
191    /// Get the accuracy drop percentage
192    pub fn accuracy_drop_percentage(&self) -> f32 {
193        (self.sensitivity_score / self.original_accuracy) * 100.0
194    }
195
196    /// Check if this layer is highly sensitive to quantization
197    pub fn is_high_sensitivity(&self) -> bool {
198        self.is_high_sensitivity_with_config(&AnalysisConfig::default())
199    }
200
201    /// Check if this layer is highly sensitive to quantization with custom config
202    pub fn is_high_sensitivity_with_config(&self, config: &AnalysisConfig) -> bool {
203        self.sensitivity_score > config.sensitivity_threshold
204            || self.accuracy_drop_percentage() > config.max_accuracy_drop_percent
205    }
206}
207
/// Comprehensive sensitivity analysis results
#[derive(Debug, Clone)]
pub struct SensitivityAnalysisResults {
    /// Results for individual layers, in the order they were supplied
    pub layer_results: Vec<LayerSensitivityResult>,
    /// Overall model sensitivity: mean of the per-layer sensitivity scores
    /// (0.0 when no layers were analyzed)
    pub overall_sensitivity: f32,
    /// Most sensitive layers (top ~10% by sensitivity score, rounded up)
    pub most_sensitive_layers: Vec<String>,
    /// Least sensitive layers (bottom ~10%; suitable for aggressive quantization)
    pub least_sensitive_layers: Vec<String>,
    /// Recommended mixed precision configuration: layer name -> scheme
    pub recommended_config: HashMap<String, QScheme>,
}
222
223impl SensitivityAnalysisResults {
224    /// Create a new sensitivity analysis results
225    pub fn new(layer_results: Vec<LayerSensitivityResult>) -> Self {
226        let overall_sensitivity = if layer_results.is_empty() {
227            0.0
228        } else {
229            layer_results
230                .iter()
231                .map(|r| r.sensitivity_score)
232                .sum::<f32>()
233                / layer_results.len() as f32
234        };
235
236        // Sort layers by sensitivity
237        let mut sorted_results = layer_results.clone();
238        sorted_results.sort_by(|a, b| {
239            b.sensitivity_score
240                .partial_cmp(&a.sensitivity_score)
241                .expect("sensitivity scores should be comparable")
242        });
243
244        let num_layers = sorted_results.len();
245        let top_10_percent = (num_layers as f32 * 0.1).ceil() as usize;
246        let bottom_10_percent = (num_layers as f32 * 0.1).ceil() as usize;
247
248        let most_sensitive_layers = sorted_results
249            .iter()
250            .take(top_10_percent)
251            .map(|r| r.layer_name.clone())
252            .collect();
253
254        let least_sensitive_layers = sorted_results
255            .iter()
256            .rev()
257            .take(bottom_10_percent)
258            .map(|r| r.layer_name.clone())
259            .collect();
260
261        // Generate recommended configuration
262        let mut recommended_config = HashMap::new();
263        for result in &layer_results {
264            recommended_config.insert(result.layer_name.clone(), result.recommended_scheme);
265        }
266
267        Self {
268            layer_results,
269            overall_sensitivity,
270            most_sensitive_layers,
271            least_sensitive_layers,
272            recommended_config,
273        }
274    }
275
276    /// Get layers that should be kept in FP32
277    pub fn get_fp32_layers(&self) -> Vec<&String> {
278        self.layer_results
279            .iter()
280            .filter(|r| r.keep_fp32)
281            .map(|r| &r.layer_name)
282            .collect()
283    }
284
285    /// Get the average sensitivity score
286    pub fn average_sensitivity(&self) -> f32 {
287        self.overall_sensitivity
288    }
289
290    /// Get layers suitable for aggressive quantization (INT4 or lower)
291    pub fn get_aggressive_quantization_candidates(&self) -> Vec<&String> {
292        self.get_aggressive_quantization_candidates_with_config(&AnalysisConfig::default())
293    }
294
295    /// Get layers suitable for aggressive quantization with custom config
296    pub fn get_aggressive_quantization_candidates_with_config(
297        &self,
298        config: &AnalysisConfig,
299    ) -> Vec<&String> {
300        self.layer_results
301            .iter()
302            .filter(|r| r.sensitivity_score < config.aggressive_threshold)
303            .map(|r| &r.layer_name)
304            .collect()
305    }
306
307    /// Generate a summary report
308    pub fn summary_report(&self) -> String {
309        format!(
310            "Sensitivity Analysis Summary:\n\
311             - Total layers analyzed: {}\n\
312             - Average sensitivity: {:.4}\n\
313             - Most sensitive layers ({}):\n{}\n\
314             - Least sensitive layers ({}):\n{}\n\
315             - Layers recommended for FP32: {}",
316            self.layer_results.len(),
317            self.overall_sensitivity,
318            self.most_sensitive_layers.len(),
319            self.most_sensitive_layers
320                .iter()
321                .map(|name| format!("  - {}", name))
322                .collect::<Vec<_>>()
323                .join("\n"),
324            self.least_sensitive_layers.len(),
325            self.least_sensitive_layers
326                .iter()
327                .map(|name| format!("  - {}", name))
328                .collect::<Vec<_>>()
329                .join("\n"),
330            self.get_fp32_layers().len()
331        )
332    }
333}
334
/// Accuracy comparison between quantized and original models
#[derive(Debug, Clone)]
pub struct AccuracyComparison {
    /// Original model accuracy (a 0–1 fraction; the report scales by 100)
    pub original_accuracy: f32,
    /// Quantized model accuracy (same scale as `original_accuracy`)
    pub quantized_accuracy: f32,
    /// Accuracy drop (original - quantized)
    pub accuracy_drop: f32,
    /// Accuracy drop as a percentage of the original accuracy
    pub accuracy_drop_percentage: f32,
    /// Whether the accuracy drop is within the acceptable threshold
    pub is_acceptable: bool,
    /// Additional named metrics for detailed comparison
    pub detailed_metrics: HashMap<String, f32>,
}

impl AccuracyComparison {
    /// Create a new accuracy comparison with the default acceptable drop of 5%.
    pub fn new(original_accuracy: f32, quantized_accuracy: f32) -> Self {
        Self::new_with_threshold(original_accuracy, quantized_accuracy, 5.0)
    }

    /// Create a new accuracy comparison with a custom acceptable drop
    /// threshold (in percent of the original accuracy).
    pub fn new_with_threshold(
        original_accuracy: f32,
        quantized_accuracy: f32,
        acceptable_drop_percentage: f32,
    ) -> Self {
        let accuracy_drop = original_accuracy - quantized_accuracy;
        // Guard against a zero original accuracy, which would otherwise
        // produce a NaN/inf percentage (mirrors the zero check in
        // `efficiency_score` below).
        let accuracy_drop_percentage = if original_accuracy == 0.0 {
            0.0
        } else {
            (accuracy_drop / original_accuracy) * 100.0
        };
        let is_acceptable = accuracy_drop_percentage <= acceptable_drop_percentage;

        Self {
            original_accuracy,
            quantized_accuracy,
            accuracy_drop,
            accuracy_drop_percentage,
            is_acceptable,
            detailed_metrics: HashMap::new(),
        }
    }

    /// Add a named metric for detailed comparison (overwrites an existing
    /// metric with the same name).
    pub fn add_metric(&mut self, name: String, value: f32) {
        self.detailed_metrics.insert(name, value);
    }

    /// Efficiency score: fraction of the original accuracy preserved
    /// (1.0 = no loss; 0.0 if the original accuracy is zero).
    pub fn efficiency_score(&self) -> f32 {
        if self.original_accuracy == 0.0 {
            0.0
        } else {
            self.quantized_accuracy / self.original_accuracy
        }
    }

    /// Quantization is recommended when the drop is acceptable AND at least
    /// 95% of the original accuracy is preserved.
    pub fn is_quantization_recommended(&self) -> bool {
        self.is_acceptable && self.efficiency_score() > 0.95
    }

    /// Generate a human-readable comparison report, including any detailed
    /// metrics that were added.
    pub fn report(&self) -> String {
        let mut report = format!(
            "Accuracy Comparison Report:\n\
             - Original Accuracy: {:.4} ({:.2}%)\n\
             - Quantized Accuracy: {:.4} ({:.2}%)\n\
             - Accuracy Drop: {:.4} ({:.2}%)\n\
             - Efficiency Score: {:.4}\n\
             - Acceptable: {}\n\
             - Quantization Recommended: {}",
            self.original_accuracy,
            self.original_accuracy * 100.0,
            self.quantized_accuracy,
            self.quantized_accuracy * 100.0,
            self.accuracy_drop,
            self.accuracy_drop_percentage,
            self.efficiency_score(),
            if self.is_acceptable { "Yes" } else { "No" },
            if self.is_quantization_recommended() {
                "Yes"
            } else {
                "No"
            }
        );

        if !self.detailed_metrics.is_empty() {
            report.push_str("\n\nDetailed Metrics:");
            for (name, value) in &self.detailed_metrics {
                report.push_str(&format!("\n  - {}: {:.4}", name, value));
            }
        }

        report
    }
}