// torsh_quantization/analysis/sensitivity.rs
//! Sensitivity analysis for quantization
use crate::analysis::config::{LayerSensitivityResult, SensitivityAnalysisResults};
use crate::{QScheme, QuantConfig, TorshResult};
use torsh_core::TorshError;
use torsh_tensor::Tensor;

8/// Quantization sensitivity analyzer
9pub struct SensitivityAnalyzer {
10    /// Test dataset for evaluation
11    #[allow(dead_code)]
12    test_data: Vec<(Tensor, Tensor)>, // (input, expected_output) pairs
13    /// Tolerance for accuracy comparison
14    tolerance: f32,
15}
16
17impl SensitivityAnalyzer {
18    /// Create a new sensitivity analyzer
19    pub fn new(test_data: Vec<(Tensor, Tensor)>) -> Self {
20        Self {
21            test_data,
22            tolerance: 1e-6,
23        }
24    }
25
26    /// Set tolerance for accuracy comparison
27    pub fn set_tolerance(&mut self, tolerance: f32) {
28        self.tolerance = tolerance;
29    }
30
31    /// Perform sensitivity analysis on a model's layers
32    pub fn analyze_layer_sensitivity(
33        &self,
34        layer_names: &[String],
35        evaluation_fn: impl Fn(&str, &QuantConfig) -> TorshResult<f32>,
36    ) -> TorshResult<SensitivityAnalysisResults> {
37        let mut layer_results = Vec::new();
38
39        // Get baseline accuracy (no quantization)
40        let baseline_accuracy = evaluation_fn("", &QuantConfig::default())?;
41
42        for layer_name in layer_names {
43            // Test different quantization schemes for this layer
44            let mut best_accuracy = 0.0;
45            let mut _best_scheme = QScheme::PerTensorAffine;
46
47            let schemes_to_test = vec![
48                QScheme::PerTensorAffine,
49                QScheme::PerTensorSymmetric,
50                QScheme::PerChannelAffine,
51                QScheme::Int4PerTensor,
52                QScheme::Binary,
53            ];
54
55            for &scheme in &schemes_to_test {
56                let config = QuantConfig::new().with_scheme(scheme);
57
58                match evaluation_fn(layer_name, &config) {
59                    Ok(accuracy) => {
60                        if accuracy > best_accuracy {
61                            best_accuracy = accuracy;
62                            _best_scheme = scheme;
63                        }
64                    }
65                    Err(_) => {
66                        // Skip this scheme if it fails
67                        continue;
68                    }
69                }
70            }
71
72            let result =
73                LayerSensitivityResult::new(layer_name.clone(), baseline_accuracy, best_accuracy);
74            layer_results.push(result);
75        }
76
77        Ok(SensitivityAnalysisResults::new(layer_results))
78    }
79
80    /// Perform heuristic sensitivity analysis based on layer types
81    pub fn heuristic_sensitivity_analysis(
82        &self,
83        layer_names: &[String],
84    ) -> TorshResult<SensitivityAnalysisResults> {
85        let mut layer_results = Vec::new();
86
87        for layer_name in layer_names {
88            let (sensitivity_score, _recommended_scheme) =
89                self.estimate_layer_sensitivity(layer_name);
90
91            let baseline_accuracy = 0.95; // Assumed baseline
92            let quantized_accuracy = baseline_accuracy - sensitivity_score;
93
94            let result = LayerSensitivityResult::new(
95                layer_name.clone(),
96                baseline_accuracy,
97                quantized_accuracy,
98            );
99            layer_results.push(result);
100        }
101
102        Ok(SensitivityAnalysisResults::new(layer_results))
103    }
104
105    /// Estimate layer sensitivity based on layer type and name patterns
106    fn estimate_layer_sensitivity(&self, layer_name: &str) -> (f32, QScheme) {
107        let layer_name_lower = layer_name.to_lowercase();
108
109        // Different layer types have different sensitivity levels
110        if layer_name_lower.contains("embedding") {
111            (0.08, QScheme::PerTensorAffine) // High sensitivity
112        } else if layer_name_lower.contains("attention") || layer_name_lower.contains("self_attn") {
113            (0.06, QScheme::PerChannelAffine) // Medium-high sensitivity
114        } else if layer_name_lower.contains("output") || layer_name_lower.contains("classifier") {
115            (0.05, QScheme::PerTensorAffine) // Medium sensitivity
116        } else if layer_name_lower.contains("layer_norm") || layer_name_lower.contains("batch_norm")
117        {
118            (0.02, QScheme::Int4PerTensor) // Low sensitivity
119        } else if layer_name_lower.contains("conv") && layer_name_lower.contains("1x1") {
120            (0.01, QScheme::Int4PerTensor) // Very low sensitivity
121        } else if layer_name_lower.contains("conv") {
122            (0.03, QScheme::PerChannelAffine) // Low-medium sensitivity
123        } else if layer_name_lower.contains("linear") || layer_name_lower.contains("dense") {
124            (0.025, QScheme::PerTensorAffine) // Low sensitivity
125        } else {
126            (0.03, QScheme::PerTensorAffine) // Default medium sensitivity
127        }
128    }
129
130    /// Compare accuracy between original and quantized tensors
131    pub fn compare_tensor_accuracy(
132        &self,
133        original: &Tensor,
134        quantized: &Tensor,
135    ) -> TorshResult<f32> {
136        if original.shape() != quantized.shape() {
137            return Err(TorshError::InvalidArgument(
138                "Tensors must have the same shape for accuracy comparison".to_string(),
139            ));
140        }
141
142        let original_data = original.data()?;
143        let quantized_data = quantized.data()?;
144
145        let mut correct_predictions = 0;
146        let total_predictions = original_data.len();
147
148        for (orig, quant) in original_data.iter().zip(quantized_data.iter()) {
149            if (orig - quant).abs() <= self.tolerance {
150                correct_predictions += 1;
151            }
152        }
153
154        Ok(correct_predictions as f32 / total_predictions as f32)
155    }
156
157    /// Calculate Mean Squared Error between tensors
158    pub fn calculate_mse(&self, original: &Tensor, quantized: &Tensor) -> TorshResult<f32> {
159        if original.shape() != quantized.shape() {
160            return Err(TorshError::InvalidArgument(
161                "Tensors must have the same shape for MSE calculation".to_string(),
162            ));
163        }
164
165        let original_data = original.data()?;
166        let quantized_data = quantized.data()?;
167
168        let mse = original_data
169            .iter()
170            .zip(quantized_data.iter())
171            .map(|(orig, quant)| (orig - quant).powi(2))
172            .sum::<f32>()
173            / original_data.len() as f32;
174
175        Ok(mse)
176    }
177
178    /// Calculate Signal-to-Noise Ratio
179    pub fn calculate_snr(&self, original: &Tensor, quantized: &Tensor) -> TorshResult<f32> {
180        let mse = self.calculate_mse(original, quantized)?;
181
182        if mse == 0.0 {
183            return Ok(f32::INFINITY); // Perfect reconstruction
184        }
185
186        let original_data = original.data()?;
187        let signal_power =
188            original_data.iter().map(|&x| x.powi(2)).sum::<f32>() / original_data.len() as f32;
189
190        let snr_db = 10.0 * (signal_power / mse).log10();
191        Ok(snr_db)
192    }
193}