// torsh_quantization/analysis/visualization.rs
use crate::analysis::config::SensitivityAnalysisResults;
4use crate::analysis::size::SizeAnalyzer;
5use crate::QScheme;
6use crate::TorshResult;
7use std::collections::HashMap;
8use torsh_core::TorshError;
9use torsh_tensor::Tensor;
10
/// Stateless collection of text-rendering helpers for quantization
/// analysis results: sensitivity bar charts, scheme comparison tables,
/// error histograms, and raw data export.
pub struct VisualizationTool;
13
14impl VisualizationTool {
15 pub fn render_sensitivity_bar_chart(
17 results: &SensitivityAnalysisResults,
18 width: usize,
19 ) -> String {
20 let mut chart = String::new();
21 chart.push_str("Layer Sensitivity Analysis\n");
22 chart.push_str(&"=".repeat(width));
23 chart.push('\n');
24
25 let mut sorted_results = results.layer_results.clone();
27 sorted_results.sort_by(|a, b| {
28 b.sensitivity_score
29 .partial_cmp(&a.sensitivity_score)
30 .expect("sensitivity scores should be comparable")
31 });
32
33 let max_sensitivity = sorted_results
34 .first()
35 .map(|r| r.sensitivity_score)
36 .unwrap_or(1.0);
37
38 for result in &sorted_results {
39 let bar_length =
40 ((result.sensitivity_score / max_sensitivity) * (width - 20) as f32) as usize;
41 let bar = "█".repeat(bar_length);
42
43 chart.push_str(&format!(
44 "{:15} |{:<width$}| {:.3}\n",
45 Self::truncate_string(&result.layer_name, 15),
46 bar,
47 result.sensitivity_score,
48 width = width - 20
49 ));
50 }
51
52 chart.push('\n');
53 chart.push_str(&format!(
54 "Overall Sensitivity: {:.3}\n",
55 results.overall_sensitivity
56 ));
57 chart
58 }
59
60 pub fn render_quantization_comparison_table(
62 num_parameters: usize,
63 baseline_accuracy: f32,
64 sensitivity_results: &SensitivityAnalysisResults,
65 ) -> String {
66 let mut table = String::new();
67 table.push_str("Quantization Scheme Comparison\n");
68 table.push_str(&"=".repeat(80));
69 table.push('\n');
70
71 table.push_str(&format!(
72 "{:<20} | {:>10} | {:>10} | {:>15} | {:>15}\n",
73 "Scheme", "Size (MB)", "Reduction", "Speed Factor", "Avg Accuracy"
74 ));
75 table.push_str(&"-".repeat(80));
76 table.push('\n');
77
78 let schemes = vec![
79 ("FP32 (Baseline)", QScheme::MixedPrecision, 1.0),
80 ("INT8 PerTensor", QScheme::PerTensorAffine, 1.0),
81 ("INT8 PerChannel", QScheme::PerChannelAffine, 1.0),
82 ("INT4", QScheme::Int4PerTensor, 1.0),
83 ("Binary", QScheme::Binary, 1.0),
84 ("Ternary", QScheme::Ternary, 1.0),
85 ("Group-wise", QScheme::GroupWise, 1.0),
86 ];
87
88 for (name, scheme, accuracy_modifier) in schemes {
89 let size_mb = SizeAnalyzer::calculate_model_size(num_parameters, scheme);
90 let reduction_factor = SizeAnalyzer::size_reduction_factor(
91 QScheme::MixedPrecision,
92 scheme,
93 num_parameters,
94 );
95 let speed_factor = Self::estimate_speed_improvement(scheme);
96 let avg_accuracy =
97 baseline_accuracy * accuracy_modifier - sensitivity_results.overall_sensitivity;
98
99 table.push_str(&format!(
100 "{name:<20} | {size_mb:>8.1} | {reduction_factor:>8.1}x | {speed_factor:>13.1}x | {avg_accuracy:>13.3}\n"
101 ));
102 }
103
104 table
105 }
106
107 pub fn render_error_histogram(
109 original: &Tensor,
110 quantized: &Tensor,
111 bins: usize,
112 width: usize,
113 ) -> TorshResult<String> {
114 if original.shape() != quantized.shape() {
115 return Err(TorshError::InvalidArgument(
116 "Tensors must have the same shape".to_string(),
117 ));
118 }
119
120 let original_data = original.data()?;
121 let quantized_data = quantized.data()?;
122
123 let errors: Vec<f32> = original_data
125 .iter()
126 .zip(quantized_data.iter())
127 .map(|(orig, quant)| orig - quant)
128 .collect();
129
130 let min_error = errors.iter().fold(f32::INFINITY, |a, &b| a.min(b));
131 let max_error = errors.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
132
133 if min_error == max_error {
134 return Ok("All errors are identical (perfect quantization)\n".to_string());
135 }
136
137 let mut histogram = vec![0; bins];
139 let bin_width = (max_error - min_error) / bins as f32;
140
141 for &error in &errors {
142 let bin_index = ((error - min_error) / bin_width) as usize;
143 let bin_index = bin_index.min(bins - 1);
144 histogram[bin_index] += 1;
145 }
146
147 let mut chart = String::new();
149 chart.push_str("Quantization Error Distribution\n");
150 chart.push_str(&"=".repeat(width));
151 chart.push('\n');
152
153 let max_count = *histogram.iter().max().unwrap_or(&1);
154
155 for (i, &count) in histogram.iter().enumerate() {
156 let bin_start = min_error + i as f32 * bin_width;
157 let bin_end = bin_start + bin_width;
158 let bar_length = (count as f32 / max_count as f32 * (width - 25) as f32) as usize;
159 let bar = "█".repeat(bar_length);
160
161 chart.push_str(&format!(
162 "[{:7.3}, {:7.3}) |{:<width$}| {:>5}\n",
163 bin_start,
164 bin_end,
165 bar,
166 count,
167 width = width - 25
168 ));
169 }
170
171 chart.push_str(&format!(
172 "\nMean Error: {:.6}\n",
173 errors.iter().sum::<f32>() / errors.len() as f32
174 ));
175 chart.push_str(&format!(
176 "Std Error: {:.6}\n",
177 Self::calculate_std(&errors)
178 ));
179
180 Ok(chart)
181 }
182
183 pub fn export_sensitivity_data(
185 results: &SensitivityAnalysisResults,
186 ) -> HashMap<String, Vec<f32>> {
187 let mut data = HashMap::new();
188
189 let sensitivity_scores: Vec<f32> = results
190 .layer_results
191 .iter()
192 .map(|r| r.sensitivity_score)
193 .collect();
194 let accuracy_drops: Vec<f32> = results
195 .layer_results
196 .iter()
197 .map(|r| r.accuracy_drop_percentage())
198 .collect();
199
200 data.insert("sensitivity_scores".to_string(), sensitivity_scores);
201 data.insert("accuracy_drops".to_string(), accuracy_drops);
202
203 data
204 }
205
206 fn truncate_string(s: &str, max_len: usize) -> String {
208 if s.len() <= max_len {
209 s.to_string()
210 } else {
211 format!("{}...", &s[..max_len.saturating_sub(3)])
212 }
213 }
214
215 fn calculate_std(values: &[f32]) -> f32 {
216 let mean = values.iter().sum::<f32>() / values.len() as f32;
217 let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
218 variance.sqrt()
219 }
220
221 fn estimate_speed_improvement(scheme: QScheme) -> f32 {
222 match scheme {
223 QScheme::Binary => 8.0,
224 QScheme::Ternary => 6.0,
225 QScheme::Int4PerTensor | QScheme::Int4PerChannel => 4.0,
226 QScheme::PerTensorAffine | QScheme::PerChannelAffine => 2.0,
227 QScheme::PerTensorSymmetric | QScheme::PerChannelSymmetric => 2.0,
228 QScheme::MixedPrecision => 1.5,
229 QScheme::GroupWise => 2.5,
230 }
231 }
232}