Skip to main content

torsh_quantization/analysis/
visualization.rs

//! Visualization tools for quantization analysis
2
3use crate::analysis::config::SensitivityAnalysisResults;
4use crate::analysis::size::SizeAnalyzer;
5use crate::QScheme;
6use crate::TorshResult;
7use std::collections::HashMap;
8use torsh_core::TorshError;
9use torsh_tensor::Tensor;
10
/// Visualization tools for quantization analysis.
///
/// A stateless unit type: every operation (sensitivity bar charts, scheme
/// comparison tables, error histograms, data export) is an associated
/// function, so no instance is needed to use it.
#[derive(Debug, Clone, Copy, Default)]
pub struct VisualizationTool;
13
14impl VisualizationTool {
15    /// Generate a text-based bar chart for sensitivity scores
16    pub fn render_sensitivity_bar_chart(
17        results: &SensitivityAnalysisResults,
18        width: usize,
19    ) -> String {
20        let mut chart = String::new();
21        chart.push_str("Layer Sensitivity Analysis\n");
22        chart.push_str(&"=".repeat(width));
23        chart.push('\n');
24
25        // Sort layers by sensitivity for better visualization
26        let mut sorted_results = results.layer_results.clone();
27        sorted_results.sort_by(|a, b| {
28            b.sensitivity_score
29                .partial_cmp(&a.sensitivity_score)
30                .expect("sensitivity scores should be comparable")
31        });
32
33        let max_sensitivity = sorted_results
34            .first()
35            .map(|r| r.sensitivity_score)
36            .unwrap_or(1.0);
37
38        for result in &sorted_results {
39            let bar_length =
40                ((result.sensitivity_score / max_sensitivity) * (width - 20) as f32) as usize;
41            let bar = "█".repeat(bar_length);
42
43            chart.push_str(&format!(
44                "{:15} |{:<width$}| {:.3}\n",
45                Self::truncate_string(&result.layer_name, 15),
46                bar,
47                result.sensitivity_score,
48                width = width - 20
49            ));
50        }
51
52        chart.push('\n');
53        chart.push_str(&format!(
54            "Overall Sensitivity: {:.3}\n",
55            results.overall_sensitivity
56        ));
57        chart
58    }
59
60    /// Generate a text-based comparison table for quantization schemes
61    pub fn render_quantization_comparison_table(
62        num_parameters: usize,
63        baseline_accuracy: f32,
64        sensitivity_results: &SensitivityAnalysisResults,
65    ) -> String {
66        let mut table = String::new();
67        table.push_str("Quantization Scheme Comparison\n");
68        table.push_str(&"=".repeat(80));
69        table.push('\n');
70
71        table.push_str(&format!(
72            "{:<20} | {:>10} | {:>10} | {:>15} | {:>15}\n",
73            "Scheme", "Size (MB)", "Reduction", "Speed Factor", "Avg Accuracy"
74        ));
75        table.push_str(&"-".repeat(80));
76        table.push('\n');
77
78        let schemes = vec![
79            ("FP32 (Baseline)", QScheme::MixedPrecision, 1.0),
80            ("INT8 PerTensor", QScheme::PerTensorAffine, 1.0),
81            ("INT8 PerChannel", QScheme::PerChannelAffine, 1.0),
82            ("INT4", QScheme::Int4PerTensor, 1.0),
83            ("Binary", QScheme::Binary, 1.0),
84            ("Ternary", QScheme::Ternary, 1.0),
85            ("Group-wise", QScheme::GroupWise, 1.0),
86        ];
87
88        for (name, scheme, accuracy_modifier) in schemes {
89            let size_mb = SizeAnalyzer::calculate_model_size(num_parameters, scheme);
90            let reduction_factor = SizeAnalyzer::size_reduction_factor(
91                QScheme::MixedPrecision,
92                scheme,
93                num_parameters,
94            );
95            let speed_factor = Self::estimate_speed_improvement(scheme);
96            let avg_accuracy =
97                baseline_accuracy * accuracy_modifier - sensitivity_results.overall_sensitivity;
98
99            table.push_str(&format!(
100                "{name:<20} | {size_mb:>8.1} | {reduction_factor:>8.1}x | {speed_factor:>13.1}x | {avg_accuracy:>13.3}\n"
101            ));
102        }
103
104        table
105    }
106
107    /// Generate histogram of quantization errors
108    pub fn render_error_histogram(
109        original: &Tensor,
110        quantized: &Tensor,
111        bins: usize,
112        width: usize,
113    ) -> TorshResult<String> {
114        if original.shape() != quantized.shape() {
115            return Err(TorshError::InvalidArgument(
116                "Tensors must have the same shape".to_string(),
117            ));
118        }
119
120        let original_data = original.data()?;
121        let quantized_data = quantized.data()?;
122
123        // Calculate errors
124        let errors: Vec<f32> = original_data
125            .iter()
126            .zip(quantized_data.iter())
127            .map(|(orig, quant)| orig - quant)
128            .collect();
129
130        let min_error = errors.iter().fold(f32::INFINITY, |a, &b| a.min(b));
131        let max_error = errors.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
132
133        if min_error == max_error {
134            return Ok("All errors are identical (perfect quantization)\n".to_string());
135        }
136
137        // Create histogram bins
138        let mut histogram = vec![0; bins];
139        let bin_width = (max_error - min_error) / bins as f32;
140
141        for &error in &errors {
142            let bin_index = ((error - min_error) / bin_width) as usize;
143            let bin_index = bin_index.min(bins - 1);
144            histogram[bin_index] += 1;
145        }
146
147        // Render histogram
148        let mut chart = String::new();
149        chart.push_str("Quantization Error Distribution\n");
150        chart.push_str(&"=".repeat(width));
151        chart.push('\n');
152
153        let max_count = *histogram.iter().max().unwrap_or(&1);
154
155        for (i, &count) in histogram.iter().enumerate() {
156            let bin_start = min_error + i as f32 * bin_width;
157            let bin_end = bin_start + bin_width;
158            let bar_length = (count as f32 / max_count as f32 * (width - 25) as f32) as usize;
159            let bar = "█".repeat(bar_length);
160
161            chart.push_str(&format!(
162                "[{:7.3}, {:7.3}) |{:<width$}| {:>5}\n",
163                bin_start,
164                bin_end,
165                bar,
166                count,
167                width = width - 25
168            ));
169        }
170
171        chart.push_str(&format!(
172            "\nMean Error: {:.6}\n",
173            errors.iter().sum::<f32>() / errors.len() as f32
174        ));
175        chart.push_str(&format!(
176            "Std Error:  {:.6}\n",
177            Self::calculate_std(&errors)
178        ));
179
180        Ok(chart)
181    }
182
183    /// Export data for external visualization tools
184    pub fn export_sensitivity_data(
185        results: &SensitivityAnalysisResults,
186    ) -> HashMap<String, Vec<f32>> {
187        let mut data = HashMap::new();
188
189        let sensitivity_scores: Vec<f32> = results
190            .layer_results
191            .iter()
192            .map(|r| r.sensitivity_score)
193            .collect();
194        let accuracy_drops: Vec<f32> = results
195            .layer_results
196            .iter()
197            .map(|r| r.accuracy_drop_percentage())
198            .collect();
199
200        data.insert("sensitivity_scores".to_string(), sensitivity_scores);
201        data.insert("accuracy_drops".to_string(), accuracy_drops);
202
203        data
204    }
205
206    // Helper functions
207    fn truncate_string(s: &str, max_len: usize) -> String {
208        if s.len() <= max_len {
209            s.to_string()
210        } else {
211            format!("{}...", &s[..max_len.saturating_sub(3)])
212        }
213    }
214
215    fn calculate_std(values: &[f32]) -> f32 {
216        let mean = values.iter().sum::<f32>() / values.len() as f32;
217        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
218        variance.sqrt()
219    }
220
221    fn estimate_speed_improvement(scheme: QScheme) -> f32 {
222        match scheme {
223            QScheme::Binary => 8.0,
224            QScheme::Ternary => 6.0,
225            QScheme::Int4PerTensor | QScheme::Int4PerChannel => 4.0,
226            QScheme::PerTensorAffine | QScheme::PerChannelAffine => 2.0,
227            QScheme::PerTensorSymmetric | QScheme::PerChannelSymmetric => 2.0,
228            QScheme::MixedPrecision => 1.5,
229            QScheme::GroupWise => 2.5,
230        }
231    }
232}