Skip to main content

torsh_quantization/analysis/
statistical.rs

1//! Advanced statistical analysis for quantization
2
3use crate::QScheme;
4use std::collections::HashMap;
5
/// Configuration holder for the statistical comparison of baseline and
/// quantized measurements.
///
/// The analyzer itself carries no data; the samples to compare are passed to
/// its methods on every call.
#[derive(Debug, Clone, PartialEq)]
pub struct AdvancedStatisticalAnalyzer {
    /// Sample size for statistical tests.
    ///
    /// The calculations in this file do not consult this value; it is only
    /// copied into the generated report.
    pub sample_size: usize,
    /// Confidence level for statistical tests, as a fraction (default: 0.95).
    ///
    /// A result is flagged significant when its p-value is below
    /// `1.0 - confidence_level`.
    pub confidence_level: f32,
}
13
14impl Default for AdvancedStatisticalAnalyzer {
15    fn default() -> Self {
16        Self {
17            sample_size: 1000,
18            confidence_level: 0.95,
19        }
20    }
21}
22
23impl AdvancedStatisticalAnalyzer {
24    /// Create a new statistical analyzer
25    pub fn new(sample_size: usize, confidence_level: f32) -> Self {
26        Self {
27            sample_size,
28            confidence_level,
29        }
30    }
31
32    /// Perform statistical significance test
33    pub fn test_significance(
34        &self,
35        baseline: &[f32],
36        quantized: &[f32],
37    ) -> StatisticalSignificance {
38        let baseline_mean = Self::calculate_mean(baseline);
39        let quantized_mean = Self::calculate_mean(quantized);
40        let baseline_std = Self::calculate_std_dev(baseline, baseline_mean);
41        let quantized_std = Self::calculate_std_dev(quantized, quantized_mean);
42
43        // Simplified t-test calculation
44        let pooled_std = ((baseline_std.powi(2) + quantized_std.powi(2)) / 2.0).sqrt();
45        let t_statistic = (baseline_mean - quantized_mean)
46            / (pooled_std
47                * ((1.0 / baseline.len() as f32) + (1.0 / quantized.len() as f32)).sqrt());
48
49        let p_value = Self::calculate_p_value(t_statistic.abs());
50        let is_significant = p_value < (1.0 - self.confidence_level);
51
52        StatisticalSignificance {
53            t_statistic,
54            p_value,
55            is_significant,
56            confidence_level: self.confidence_level,
57            effect_size: (baseline_mean - quantized_mean).abs() / pooled_std,
58        }
59    }
60
61    /// Generate comprehensive statistical report
62    pub fn generate_comprehensive_report(
63        &self,
64        baseline_accuracy: &[f32],
65        quantized_accuracy: &[f32],
66        schemes: &[QScheme],
67    ) -> ComprehensiveStatisticalReport {
68        let mut scheme_analysis = HashMap::new();
69
70        for &scheme in schemes {
71            let significance = self.test_significance(baseline_accuracy, quantized_accuracy);
72            let risk_level = self.assess_risk_level(&significance);
73
74            scheme_analysis.insert(scheme, (significance, risk_level));
75        }
76
77        let overall_mean_baseline = Self::calculate_mean(baseline_accuracy);
78        let overall_mean_quantized = Self::calculate_mean(quantized_accuracy);
79        let overall_variance_baseline = Self::calculate_variance(baseline_accuracy);
80        let overall_variance_quantized = Self::calculate_variance(quantized_accuracy);
81
82        ComprehensiveStatisticalReport {
83            overall_mean_baseline,
84            overall_mean_quantized,
85            overall_variance_baseline,
86            overall_variance_quantized,
87            scheme_analysis,
88            sample_size: self.sample_size,
89            confidence_level: self.confidence_level,
90        }
91    }
92
93    /// Assess risk level based on statistical significance
94    pub fn assess_risk_level(&self, significance: &StatisticalSignificance) -> RiskLevel {
95        if !significance.is_significant {
96            RiskLevel::Low
97        } else if significance.effect_size < 0.2 {
98            RiskLevel::Low
99        } else if significance.effect_size < 0.5 {
100            RiskLevel::Medium
101        } else if significance.effect_size < 0.8 {
102            RiskLevel::High
103        } else {
104            RiskLevel::Critical
105        }
106    }
107
108    // Helper methods
109    fn calculate_mean(values: &[f32]) -> f32 {
110        values.iter().sum::<f32>() / values.len() as f32
111    }
112
113    fn calculate_std_dev(values: &[f32], mean: f32) -> f32 {
114        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
115        variance.sqrt()
116    }
117
118    fn calculate_variance(values: &[f32]) -> f32 {
119        let mean = Self::calculate_mean(values);
120        values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32
121    }
122
123    fn calculate_p_value(t_stat: f32) -> f32 {
124        // Simplified p-value calculation (approximation)
125        if t_stat < 1.96 {
126            0.05
127        } else if t_stat < 2.58 {
128            0.01
129        } else {
130            0.001
131        }
132    }
133}
134
/// Outcome of comparing a baseline sample against a quantized sample.
#[derive(Debug, Clone)]
pub struct StatisticalSignificance {
    /// Test statistic: the difference of the sample means divided by its
    /// estimated standard error. The sign follows `baseline - quantized`.
    pub t_statistic: f32,
    /// Approximate p-value associated with the absolute t-statistic.
    pub p_value: f32,
    /// `true` when `p_value` is below `1.0 - confidence_level`.
    pub is_significant: bool,
    /// Confidence level the test was evaluated at, as a fraction.
    pub confidence_level: f32,
    /// Effect size in the style of Cohen's d: the absolute difference of the
    /// means relative to the pooled standard deviation.
    pub effect_size: f32,
}
149
150/// Comprehensive statistical report
151#[derive(Debug, Clone)]
152pub struct ComprehensiveStatisticalReport {
153    /// Overall mean baseline accuracy
154    pub overall_mean_baseline: f32,
155    /// Overall mean quantized accuracy
156    pub overall_mean_quantized: f32,
157    /// Overall variance in baseline accuracy
158    pub overall_variance_baseline: f32,
159    /// Overall variance in quantized accuracy
160    pub overall_variance_quantized: f32,
161    /// Analysis for each quantization scheme
162    pub scheme_analysis: HashMap<QScheme, (StatisticalSignificance, RiskLevel)>,
163    /// Sample size used for analysis
164    pub sample_size: usize,
165    /// Confidence level used
166    pub confidence_level: f32,
167}
168
/// Graded risk that quantization degrades accuracy, as assigned by
/// `AdvancedStatisticalAnalyzer::assess_risk_level`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiskLevel {
    /// Minimal impact: the difference is not significant, or the effect size
    /// is below 0.2.
    Low,
    /// Moderate impact: significant, with an effect size below 0.5.
    Medium,
    /// Significant impact: significant, with an effect size below 0.8.
    High,
    /// Severe impact: significant, with an effect size of 0.8 or more.
    Critical,
}