torsh_quantization/analysis/
statistical.rs1use crate::QScheme;
4use std::collections::HashMap;
5
/// Configuration for comparing baseline and quantized accuracy samples.
///
/// The analyzer itself holds no data; samples are passed to its methods.
#[derive(Debug, Clone, PartialEq)]
pub struct AdvancedStatisticalAnalyzer {
    /// Sample size recorded in generated reports. The statistical computations use the
    /// lengths of the slices they are given, not this value.
    pub sample_size: usize,
    /// Confidence level as a fraction (e.g. `0.95`). A result is treated as significant
    /// when its p-value is below `1.0 - confidence_level`.
    pub confidence_level: f32,
}
13
14impl Default for AdvancedStatisticalAnalyzer {
15 fn default() -> Self {
16 Self {
17 sample_size: 1000,
18 confidence_level: 0.95,
19 }
20 }
21}
22
23impl AdvancedStatisticalAnalyzer {
24 pub fn new(sample_size: usize, confidence_level: f32) -> Self {
26 Self {
27 sample_size,
28 confidence_level,
29 }
30 }
31
32 pub fn test_significance(
34 &self,
35 baseline: &[f32],
36 quantized: &[f32],
37 ) -> StatisticalSignificance {
38 let baseline_mean = Self::calculate_mean(baseline);
39 let quantized_mean = Self::calculate_mean(quantized);
40 let baseline_std = Self::calculate_std_dev(baseline, baseline_mean);
41 let quantized_std = Self::calculate_std_dev(quantized, quantized_mean);
42
43 let pooled_std = ((baseline_std.powi(2) + quantized_std.powi(2)) / 2.0).sqrt();
45 let t_statistic = (baseline_mean - quantized_mean)
46 / (pooled_std
47 * ((1.0 / baseline.len() as f32) + (1.0 / quantized.len() as f32)).sqrt());
48
49 let p_value = Self::calculate_p_value(t_statistic.abs());
50 let is_significant = p_value < (1.0 - self.confidence_level);
51
52 StatisticalSignificance {
53 t_statistic,
54 p_value,
55 is_significant,
56 confidence_level: self.confidence_level,
57 effect_size: (baseline_mean - quantized_mean).abs() / pooled_std,
58 }
59 }
60
61 pub fn generate_comprehensive_report(
63 &self,
64 baseline_accuracy: &[f32],
65 quantized_accuracy: &[f32],
66 schemes: &[QScheme],
67 ) -> ComprehensiveStatisticalReport {
68 let mut scheme_analysis = HashMap::new();
69
70 for &scheme in schemes {
71 let significance = self.test_significance(baseline_accuracy, quantized_accuracy);
72 let risk_level = self.assess_risk_level(&significance);
73
74 scheme_analysis.insert(scheme, (significance, risk_level));
75 }
76
77 let overall_mean_baseline = Self::calculate_mean(baseline_accuracy);
78 let overall_mean_quantized = Self::calculate_mean(quantized_accuracy);
79 let overall_variance_baseline = Self::calculate_variance(baseline_accuracy);
80 let overall_variance_quantized = Self::calculate_variance(quantized_accuracy);
81
82 ComprehensiveStatisticalReport {
83 overall_mean_baseline,
84 overall_mean_quantized,
85 overall_variance_baseline,
86 overall_variance_quantized,
87 scheme_analysis,
88 sample_size: self.sample_size,
89 confidence_level: self.confidence_level,
90 }
91 }
92
93 pub fn assess_risk_level(&self, significance: &StatisticalSignificance) -> RiskLevel {
95 if !significance.is_significant {
96 RiskLevel::Low
97 } else if significance.effect_size < 0.2 {
98 RiskLevel::Low
99 } else if significance.effect_size < 0.5 {
100 RiskLevel::Medium
101 } else if significance.effect_size < 0.8 {
102 RiskLevel::High
103 } else {
104 RiskLevel::Critical
105 }
106 }
107
108 fn calculate_mean(values: &[f32]) -> f32 {
110 values.iter().sum::<f32>() / values.len() as f32
111 }
112
113 fn calculate_std_dev(values: &[f32], mean: f32) -> f32 {
114 let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
115 variance.sqrt()
116 }
117
118 fn calculate_variance(values: &[f32]) -> f32 {
119 let mean = Self::calculate_mean(values);
120 values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32
121 }
122
123 fn calculate_p_value(t_stat: f32) -> f32 {
124 if t_stat < 1.96 {
126 0.05
127 } else if t_stat < 2.58 {
128 0.01
129 } else {
130 0.001
131 }
132 }
133}
134
#[derive(Debug, Clone)]
/// Outcome of comparing a baseline sample against a quantized sample.
pub struct StatisticalSignificance {
    /// Test statistic: the mean difference (baseline minus quantized) divided by the
    /// standard error of that difference.
    pub t_statistic: f32,
    /// Approximate p-value derived from the absolute value of `t_statistic`.
    pub p_value: f32,
    /// Whether `p_value` fell below `1.0 - confidence_level`.
    pub is_significant: bool,
    /// Confidence level the test was evaluated at (copied from the analyzer).
    pub confidence_level: f32,
    /// Absolute mean difference divided by the pooled standard deviation.
    pub effect_size: f32,
}
149
#[derive(Debug, Clone)]
/// Aggregate statistics for a baseline/quantized comparison, with per-scheme results.
pub struct ComprehensiveStatisticalReport {
    /// Arithmetic mean of the baseline accuracy samples.
    pub overall_mean_baseline: f32,
    /// Arithmetic mean of the quantized accuracy samples.
    pub overall_mean_quantized: f32,
    /// Population variance (divisor `n`) of the baseline accuracy samples.
    pub overall_variance_baseline: f32,
    /// Population variance (divisor `n`) of the quantized accuracy samples.
    pub overall_variance_quantized: f32,
    /// Significance result and derived risk level, keyed by quantization scheme.
    pub scheme_analysis: HashMap<QScheme, (StatisticalSignificance, RiskLevel)>,
    /// The analyzer's configured sample size (not the length of the analysed slices).
    pub sample_size: usize,
    /// The analyzer's configured confidence level.
    pub confidence_level: f32,
}
168
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Risk grade assigned to a significance result by `assess_risk_level`.
pub enum RiskLevel {
    /// Not statistically significant, or significant with an effect size below 0.2.
    Low,
    /// Significant with an effect size of at least 0.2 and below 0.5.
    Medium,
    /// Significant with an effect size of at least 0.5 and below 0.8.
    High,
    /// Significant with an effect size that is not below 0.8.
    Critical,
}