Skip to main content

trustformers_debug/visualization/
mod.rs

//! Visualization module for TrustformeRS debugging tools.
//!
//! This module has been refactored into focused submodules to comply with the
//! 2000-line policy. The original `visualization.rs` (2843 lines) has been split into:
//!
//! - `types` - Basic visualization types, enums, and data structures
//! - `modern_plotting` - see that submodule's documentation
//!
//! Both submodules are glob re-exported from here for backward compatibility.
//! Further submodules (terminal, video, etc.) may be added as needed.
8
9pub mod modern_plotting;
10pub mod types;
11
12// Re-export main types for backward compatibility
13pub use modern_plotting::*;
14pub use types::*;
15
16use anyhow::Result;
17use std::path::Path;
18
/// Main debug visualizer (simplified version).
///
/// Holds only a [`VisualizationConfig`]. Its `output_directory` is where
/// `create_dashboard` and `save_to_file` write their files; no plot state is
/// stored between calls.
#[derive(Debug)]
pub struct DebugVisualizer {
    /// Configuration supplied at construction time.
    config: VisualizationConfig,
}
24
25impl DebugVisualizer {
26    pub fn new(config: VisualizationConfig) -> Self {
27        Self { config }
28    }
29
30    pub fn with_default_config() -> Self {
31        Self::new(VisualizationConfig::default())
32    }
33
34    /// Create a simple line plot
35    pub fn create_line_plot(&self, data: &PlotData) -> Result<String> {
36        // Placeholder implementation
37        Ok(format!("Line plot '{}' created successfully", data.title))
38    }
39
40    /// Create a heatmap visualization
41    pub fn create_heatmap(&self, data: &HeatmapData) -> Result<String> {
42        // Placeholder implementation
43        Ok(format!("Heatmap '{}' created successfully", data.title))
44    }
45
46    /// Create a histogram
47    pub fn create_histogram(&self, data: &HistogramData) -> Result<String> {
48        // Placeholder implementation
49        Ok(format!("Histogram '{}' created successfully", data.title))
50    }
51
52    /// Plot tensor distribution
53    pub fn plot_tensor_distribution(
54        &self,
55        name: &str,
56        values: &[f64],
57        bins: usize,
58    ) -> Result<String> {
59        let data = HistogramData {
60            values: values.to_vec(),
61            bins,
62            title: format!("{} Distribution", name),
63            x_label: "Value".to_string(),
64            y_label: "Frequency".to_string(),
65            density: false,
66        };
67        self.create_histogram(&data)
68    }
69
70    /// Plot training metrics
71    pub fn plot_training_metrics(
72        &mut self,
73        steps: &[f64],
74        losses: &[f64],
75        accuracies: Option<&[f64]>,
76    ) -> Result<String> {
77        let mut plot_data = PlotData {
78            x_values: steps.to_vec(),
79            y_values: losses.to_vec(),
80            labels: vec!["Loss".to_string()],
81            title: "Training Metrics".to_string(),
82            x_label: "Steps".to_string(),
83            y_label: "Value".to_string(),
84        };
85
86        if let Some(acc) = accuracies {
87            plot_data.y_values.extend_from_slice(acc);
88            plot_data.labels.push("Accuracy".to_string());
89        }
90
91        self.create_line_plot(&plot_data)
92    }
93
94    /// Plot gradient flow
95    pub fn plot_gradient_flow(
96        &self,
97        layer_name: &str,
98        steps: &[f64],
99        gradient_norms: &[f64],
100    ) -> Result<String> {
101        let data = PlotData {
102            x_values: steps.to_vec(),
103            y_values: gradient_norms.to_vec(),
104            labels: vec![format!("{} Gradient Flow", layer_name)],
105            title: format!("Gradient Flow - {}", layer_name),
106            x_label: "Steps".to_string(),
107            y_label: "Gradient Norm".to_string(),
108        };
109        self.create_line_plot(&data)
110    }
111
112    /// Plot tensor heatmap
113    pub fn plot_tensor_heatmap(&self, name: &str, values: &[Vec<f64>]) -> Result<String> {
114        let data = HeatmapData {
115            values: values.to_vec(),
116            x_labels: (0..values.first().map_or(0, |row| row.len()))
117                .map(|i| i.to_string())
118                .collect(),
119            y_labels: (0..values.len()).map(|i| i.to_string()).collect(),
120            title: format!("{} Heatmap", name),
121            color_bar_label: "Value".to_string(),
122        };
123        self.create_heatmap(&data)
124    }
125
126    /// Plot activation patterns
127    pub fn plot_activation_patterns(
128        &self,
129        layer_name: &str,
130        inputs: &[f64],
131        outputs: &[f64],
132    ) -> Result<String> {
133        let data = PlotData {
134            x_values: inputs.to_vec(),
135            y_values: outputs.to_vec(),
136            labels: vec![format!("{} Activation", layer_name)],
137            title: format!("Activation Pattern - {}", layer_name),
138            x_label: "Input".to_string(),
139            y_label: "Output".to_string(),
140        };
141        self.create_line_plot(&data)
142    }
143
144    /// Get plot names
145    pub fn get_plot_names(&self) -> Vec<String> {
146        // Return some default plot names for demonstration
147        vec![
148            "tensor_distribution".to_string(),
149            "training_metrics".to_string(),
150            "gradient_flow".to_string(),
151            "activation_patterns".to_string(),
152        ]
153    }
154
155    /// Create dashboard
156    pub fn create_dashboard(&mut self, plot_names: &[String]) -> Result<String> {
157        let dashboard_path = Path::new(&self.config.output_directory).join("dashboard.html");
158        std::fs::create_dir_all(&self.config.output_directory)?;
159
160        let mut html_content =
161            String::from("<html><head><title>Debug Dashboard</title></head><body>");
162        html_content.push_str("<h1>TrustformeRS Debug Dashboard</h1>");
163
164        for plot_name in plot_names {
165            html_content.push_str(&format!(
166                "<div><h2>{}</h2><p>Plot: {}</p></div>",
167                plot_name, plot_name
168            ));
169        }
170
171        html_content.push_str("</body></html>");
172        std::fs::write(&dashboard_path, html_content)?;
173
174        Ok(dashboard_path.to_string_lossy().to_string())
175    }
176
177    /// Export plot data
178    pub fn export_plot_data(&self, plot_name: &str, export_path: &Path) -> Result<()> {
179        std::fs::create_dir_all(export_path.parent().unwrap_or(Path::new(".")))?;
180        let data = format!("Plot data for: {}", plot_name);
181        std::fs::write(export_path, data)?;
182        Ok(())
183    }
184
185    /// Save visualization to file
186    pub fn save_to_file(&self, filename: &str) -> Result<()> {
187        // Placeholder implementation
188        let output_path = Path::new(&self.config.output_directory).join(filename);
189        std::fs::create_dir_all(&self.config.output_directory)?;
190        std::fs::write(output_path, "placeholder visualization content")?;
191        Ok(())
192    }
193}
194
/// Simple terminal-based visualizer.
///
/// Stateless: every method renders from its arguments alone, so the type is
/// trivially `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct TerminalVisualizer;
197
198impl TerminalVisualizer {
199    pub fn new() -> Self {
200        Self
201    }
202
203    /// Display simple text-based histogram in terminal
204    pub fn display_histogram(&self, data: &HistogramData) -> Result<()> {
205        println!("Terminal Histogram: {}", data.title);
206        println!("Data points: {}", data.values.len());
207        if data.values.is_empty() {
208            return Ok(());
209        }
210        let bins = if data.bins == 0 { 10 } else { data.bins };
211        let rendered = self.ascii_histogram(&data.values, bins);
212        if !data.x_label.is_empty() || !data.y_label.is_empty() {
213            println!("{} vs {}", data.y_label, data.x_label);
214        }
215        print!("{}", rendered);
216        Ok(())
217    }
218
219    /// Display simple text-based statistics
220    pub fn display_statistics(&self, label: &str, values: &[f64]) -> Result<()> {
221        if values.is_empty() {
222            println!("{}: No data", label);
223            return Ok(());
224        }
225
226        let mean = values.iter().sum::<f64>() / values.len() as f64;
227        let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
228        let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
229
230        println!(
231            "{}: mean={:.3}, min={:.3}, max={:.3}",
232            label, mean, min, max
233        );
234        Ok(())
235    }
236
237    /// ASCII histogram display
238    pub fn ascii_histogram(&self, values: &[f64], bins: usize) -> String {
239        if values.is_empty() {
240            return "No data for histogram".to_string();
241        }
242
243        let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
244        let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
245
246        if (max_val - min_val).abs() < f64::EPSILON {
247            return format!("All values are {:.3}", min_val);
248        }
249
250        let mut histogram = vec![0; bins];
251        let bin_width = (max_val - min_val) / bins as f64;
252
253        for &value in values {
254            let bin_index = ((value - min_val) / bin_width).floor() as usize;
255            let bin_index = bin_index.min(bins - 1);
256            histogram[bin_index] += 1;
257        }
258
259        let max_count = histogram.iter().max().unwrap_or(&0);
260        let scale = if *max_count > 0 { 40.0 / *max_count as f64 } else { 1.0 };
261
262        let mut result = String::new();
263        for (i, &count) in histogram.iter().enumerate() {
264            let bin_start = min_val + i as f64 * bin_width;
265            let bin_end = bin_start + bin_width;
266            let bar_length = (count as f64 * scale) as usize;
267            let bar = "█".repeat(bar_length);
268            result.push_str(&format!(
269                "[{:.2}-{:.2}): {} ({})\n",
270                bin_start, bin_end, bar, count
271            ));
272        }
273
274        result
275    }
276
277    /// ASCII line plot display
278    pub fn ascii_line_plot(&self, x_values: &[f64], y_values: &[f64], title: &str) -> String {
279        if x_values.is_empty() || y_values.is_empty() || x_values.len() != y_values.len() {
280            return "Invalid data for line plot".to_string();
281        }
282
283        let min_y = y_values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
284        let max_y = y_values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
285
286        let mut result = format!("{}\n", title);
287        result.push_str("═".repeat(title.len()).as_str());
288        result.push('\n');
289
290        if (max_y - min_y).abs() < f64::EPSILON {
291            result.push_str(&format!("Constant value: {:.3}\n", min_y));
292            return result;
293        }
294
295        let height = 20;
296        let width = x_values.len().min(80);
297
298        // Sample data if too many points
299        let step = if x_values.len() > width { x_values.len() / width } else { 1 };
300
301        for row in (0..height).rev() {
302            let y_threshold = min_y + (max_y - min_y) * row as f64 / (height - 1) as f64;
303            let mut line = String::new();
304
305            for i in (0..x_values.len()).step_by(step).take(width) {
306                if y_values[i] >= y_threshold {
307                    line.push('*');
308                } else {
309                    line.push(' ');
310                }
311            }
312            result.push_str(&format!("{:8.2} |{}\n", y_threshold, line));
313        }
314
315        result.push_str(&format!("{:8} +{}\n", "", "─".repeat(width)));
316        result
317    }
318}
319
320impl Default for TerminalVisualizer {
321    fn default() -> Self {
322        Self::new()
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_display_histogram_empty_returns_ok() {
        let viz = TerminalVisualizer::new();
        let data = HistogramData {
            values: vec![],
            bins: 10,
            title: "empty".to_string(),
            x_label: String::new(),
            y_label: String::new(),
            density: false,
        };
        assert!(viz.display_histogram(&data).is_ok());
    }

    #[test]
    fn test_display_histogram_with_values_returns_ok() {
        let viz = TerminalVisualizer::new();
        let data = HistogramData {
            values: (0..50).map(|i| i as f64).collect(),
            bins: 5,
            title: "ramp".to_string(),
            x_label: "value".to_string(),
            y_label: "count".to_string(),
            density: false,
        };
        assert!(viz.display_histogram(&data).is_ok());
    }

    #[test]
    fn test_display_histogram_zero_bins_falls_back_to_default() {
        let viz = TerminalVisualizer::new();
        let data = HistogramData {
            values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            bins: 0,
            title: "fallback".to_string(),
            x_label: String::new(),
            y_label: String::new(),
            density: false,
        };
        // Should not panic with zero bins.
        assert!(viz.display_histogram(&data).is_ok());
    }

    #[test]
    fn test_display_statistics_returns_ok() {
        let viz = TerminalVisualizer::new();
        assert!(viz.display_statistics("empty", &[]).is_ok());
        assert!(viz.display_statistics("values", &[1.0, 2.0, 3.0]).is_ok());
    }

    #[test]
    fn test_ascii_histogram_empty_input_message() {
        let viz = TerminalVisualizer::new();
        assert_eq!(viz.ascii_histogram(&[], 5), "No data for histogram");
    }

    #[test]
    fn test_ascii_histogram_constant_values_message() {
        let viz = TerminalVisualizer::new();
        assert_eq!(viz.ascii_histogram(&[2.0, 2.0, 2.0], 4), "All values are 2.000");
    }

    #[test]
    fn test_ascii_histogram_uniform_ramp_fills_bins_equally() {
        let viz = TerminalVisualizer::new();
        let values: Vec<f64> = (0..50).map(f64::from).collect();
        let rendered = viz.ascii_histogram(&values, 5);
        let lines: Vec<&str> = rendered.lines().collect();
        assert_eq!(lines.len(), 5);
        // 0..50 over 5 equal bins puts exactly 10 samples in each bin.
        assert!(lines.iter().all(|line| line.ends_with("(10)")));
    }

    #[test]
    fn test_ascii_line_plot_rejects_invalid_input() {
        let viz = TerminalVisualizer::new();
        assert_eq!(
            viz.ascii_line_plot(&[1.0, 2.0], &[1.0], "t"),
            "Invalid data for line plot"
        );
        assert_eq!(viz.ascii_line_plot(&[], &[], "t"), "Invalid data for line plot");
    }

    #[test]
    fn test_ascii_line_plot_constant_series() {
        let viz = TerminalVisualizer::new();
        assert_eq!(
            viz.ascii_line_plot(&[0.0, 1.0], &[3.0, 3.0], "t"),
            "t\n═\nConstant value: 3.000\n"
        );
    }

    #[test]
    fn test_ascii_line_plot_layout() {
        let viz = TerminalVisualizer::new();
        let xs: Vec<f64> = (0..10).map(f64::from).collect();
        let rendered = viz.ascii_line_plot(&xs, &xs, "ramp");
        // Title, underline, 20 plot rows, and the axis line.
        assert_eq!(rendered.lines().count(), 23);
    }
}