Skip to main content

trustformers_debug/visualization/
mod.rs

//! Visualization module for TrustformeRS debugging tools.
//!
//! This module has been refactored into focused submodules to comply with the
//! 2000-line policy. The original `visualization.rs` (2843 lines) has been split into:
//!
//! - `types` - basic visualization types, enums, and data structures
//! - `modern_plotting` - plotting submodule (its items are re-exported from here)
//! - further submodules (terminal, video, etc.), to be created as needed
//!
//! This file itself defines the simplified `DebugVisualizer` and the text-based
//! `TerminalVisualizer`.
8
9pub mod modern_plotting;
10pub mod types;
11
12// Re-export main types for backward compatibility
13pub use modern_plotting::*;
14pub use types::*;
15
16use anyhow::Result;
17use std::path::Path;
18
19/// Main debug visualizer (simplified version)
20#[derive(Debug)]
21pub struct DebugVisualizer {
22    config: VisualizationConfig,
23}
24
25impl DebugVisualizer {
26    pub fn new(config: VisualizationConfig) -> Self {
27        Self { config }
28    }
29
30    pub fn with_default_config() -> Self {
31        Self::new(VisualizationConfig::default())
32    }
33
34    /// Create a simple line plot
35    pub fn create_line_plot(&self, data: &PlotData) -> Result<String> {
36        // Placeholder implementation
37        Ok(format!("Line plot '{}' created successfully", data.title))
38    }
39
40    /// Create a heatmap visualization
41    pub fn create_heatmap(&self, data: &HeatmapData) -> Result<String> {
42        // Placeholder implementation
43        Ok(format!("Heatmap '{}' created successfully", data.title))
44    }
45
46    /// Create a histogram
47    pub fn create_histogram(&self, data: &HistogramData) -> Result<String> {
48        // Placeholder implementation
49        Ok(format!("Histogram '{}' created successfully", data.title))
50    }
51
52    /// Plot tensor distribution
53    pub fn plot_tensor_distribution(
54        &self,
55        name: &str,
56        values: &[f64],
57        bins: usize,
58    ) -> Result<String> {
59        let data = HistogramData {
60            values: values.to_vec(),
61            bins,
62            title: format!("{} Distribution", name),
63            x_label: "Value".to_string(),
64            y_label: "Frequency".to_string(),
65            density: false,
66        };
67        self.create_histogram(&data)
68    }
69
70    /// Plot training metrics
71    pub fn plot_training_metrics(
72        &mut self,
73        steps: &[f64],
74        losses: &[f64],
75        accuracies: Option<&[f64]>,
76    ) -> Result<String> {
77        let mut plot_data = PlotData {
78            x_values: steps.to_vec(),
79            y_values: losses.to_vec(),
80            labels: vec!["Loss".to_string()],
81            title: "Training Metrics".to_string(),
82            x_label: "Steps".to_string(),
83            y_label: "Value".to_string(),
84        };
85
86        if let Some(acc) = accuracies {
87            plot_data.y_values.extend_from_slice(acc);
88            plot_data.labels.push("Accuracy".to_string());
89        }
90
91        self.create_line_plot(&plot_data)
92    }
93
94    /// Plot gradient flow
95    pub fn plot_gradient_flow(
96        &self,
97        layer_name: &str,
98        steps: &[f64],
99        gradient_norms: &[f64],
100    ) -> Result<String> {
101        let data = PlotData {
102            x_values: steps.to_vec(),
103            y_values: gradient_norms.to_vec(),
104            labels: vec![format!("{} Gradient Flow", layer_name)],
105            title: format!("Gradient Flow - {}", layer_name),
106            x_label: "Steps".to_string(),
107            y_label: "Gradient Norm".to_string(),
108        };
109        self.create_line_plot(&data)
110    }
111
112    /// Plot tensor heatmap
113    pub fn plot_tensor_heatmap(&self, name: &str, values: &[Vec<f64>]) -> Result<String> {
114        let data = HeatmapData {
115            values: values.to_vec(),
116            x_labels: (0..values.first().map_or(0, |row| row.len()))
117                .map(|i| i.to_string())
118                .collect(),
119            y_labels: (0..values.len()).map(|i| i.to_string()).collect(),
120            title: format!("{} Heatmap", name),
121            color_bar_label: "Value".to_string(),
122        };
123        self.create_heatmap(&data)
124    }
125
126    /// Plot activation patterns
127    pub fn plot_activation_patterns(
128        &self,
129        layer_name: &str,
130        inputs: &[f64],
131        outputs: &[f64],
132    ) -> Result<String> {
133        let data = PlotData {
134            x_values: inputs.to_vec(),
135            y_values: outputs.to_vec(),
136            labels: vec![format!("{} Activation", layer_name)],
137            title: format!("Activation Pattern - {}", layer_name),
138            x_label: "Input".to_string(),
139            y_label: "Output".to_string(),
140        };
141        self.create_line_plot(&data)
142    }
143
144    /// Get plot names
145    pub fn get_plot_names(&self) -> Vec<String> {
146        // Return some default plot names for demonstration
147        vec![
148            "tensor_distribution".to_string(),
149            "training_metrics".to_string(),
150            "gradient_flow".to_string(),
151            "activation_patterns".to_string(),
152        ]
153    }
154
155    /// Create dashboard
156    pub fn create_dashboard(&mut self, plot_names: &[String]) -> Result<String> {
157        let dashboard_path = Path::new(&self.config.output_directory).join("dashboard.html");
158        std::fs::create_dir_all(&self.config.output_directory)?;
159
160        let mut html_content =
161            String::from("<html><head><title>Debug Dashboard</title></head><body>");
162        html_content.push_str("<h1>TrustformeRS Debug Dashboard</h1>");
163
164        for plot_name in plot_names {
165            html_content.push_str(&format!(
166                "<div><h2>{}</h2><p>Plot: {}</p></div>",
167                plot_name, plot_name
168            ));
169        }
170
171        html_content.push_str("</body></html>");
172        std::fs::write(&dashboard_path, html_content)?;
173
174        Ok(dashboard_path.to_string_lossy().to_string())
175    }
176
177    /// Export plot data
178    pub fn export_plot_data(&self, plot_name: &str, export_path: &Path) -> Result<()> {
179        std::fs::create_dir_all(export_path.parent().unwrap_or(Path::new(".")))?;
180        let data = format!("Plot data for: {}", plot_name);
181        std::fs::write(export_path, data)?;
182        Ok(())
183    }
184
185    /// Save visualization to file
186    pub fn save_to_file(&self, filename: &str) -> Result<()> {
187        // Placeholder implementation
188        let output_path = Path::new(&self.config.output_directory).join(filename);
189        std::fs::create_dir_all(&self.config.output_directory)?;
190        std::fs::write(output_path, "placeholder visualization content")?;
191        Ok(())
192    }
193}
194
/// Simple terminal-based visualizer that renders data as plain text.
///
/// The type carries no state: its `ascii_*` methods return the rendered text
/// as a `String`, and its `display_*` methods print to stdout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TerminalVisualizer;
197
198impl TerminalVisualizer {
199    pub fn new() -> Self {
200        Self
201    }
202
203    /// Display simple text-based histogram in terminal
204    pub fn display_histogram(&self, data: &HistogramData) -> Result<()> {
205        println!("Terminal Histogram: {}", data.title);
206        println!("Data points: {}", data.values.len());
207        // Placeholder for actual terminal histogram
208        Ok(())
209    }
210
211    /// Display simple text-based statistics
212    pub fn display_statistics(&self, label: &str, values: &[f64]) -> Result<()> {
213        if values.is_empty() {
214            println!("{}: No data", label);
215            return Ok(());
216        }
217
218        let mean = values.iter().sum::<f64>() / values.len() as f64;
219        let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
220        let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
221
222        println!(
223            "{}: mean={:.3}, min={:.3}, max={:.3}",
224            label, mean, min, max
225        );
226        Ok(())
227    }
228
229    /// ASCII histogram display
230    pub fn ascii_histogram(&self, values: &[f64], bins: usize) -> String {
231        if values.is_empty() {
232            return "No data for histogram".to_string();
233        }
234
235        let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
236        let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
237
238        if (max_val - min_val).abs() < f64::EPSILON {
239            return format!("All values are {:.3}", min_val);
240        }
241
242        let mut histogram = vec![0; bins];
243        let bin_width = (max_val - min_val) / bins as f64;
244
245        for &value in values {
246            let bin_index = ((value - min_val) / bin_width).floor() as usize;
247            let bin_index = bin_index.min(bins - 1);
248            histogram[bin_index] += 1;
249        }
250
251        let max_count = histogram.iter().max().unwrap_or(&0);
252        let scale = if *max_count > 0 { 40.0 / *max_count as f64 } else { 1.0 };
253
254        let mut result = String::new();
255        for (i, &count) in histogram.iter().enumerate() {
256            let bin_start = min_val + i as f64 * bin_width;
257            let bin_end = bin_start + bin_width;
258            let bar_length = (count as f64 * scale) as usize;
259            let bar = "█".repeat(bar_length);
260            result.push_str(&format!(
261                "[{:.2}-{:.2}): {} ({})\n",
262                bin_start, bin_end, bar, count
263            ));
264        }
265
266        result
267    }
268
269    /// ASCII line plot display
270    pub fn ascii_line_plot(&self, x_values: &[f64], y_values: &[f64], title: &str) -> String {
271        if x_values.is_empty() || y_values.is_empty() || x_values.len() != y_values.len() {
272            return "Invalid data for line plot".to_string();
273        }
274
275        let min_y = y_values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
276        let max_y = y_values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
277
278        let mut result = format!("{}\n", title);
279        result.push_str("═".repeat(title.len()).as_str());
280        result.push('\n');
281
282        if (max_y - min_y).abs() < f64::EPSILON {
283            result.push_str(&format!("Constant value: {:.3}\n", min_y));
284            return result;
285        }
286
287        let height = 20;
288        let width = x_values.len().min(80);
289
290        // Sample data if too many points
291        let step = if x_values.len() > width { x_values.len() / width } else { 1 };
292
293        for row in (0..height).rev() {
294            let y_threshold = min_y + (max_y - min_y) * row as f64 / (height - 1) as f64;
295            let mut line = String::new();
296
297            for i in (0..x_values.len()).step_by(step).take(width) {
298                if y_values[i] >= y_threshold {
299                    line.push('*');
300                } else {
301                    line.push(' ');
302                }
303            }
304            result.push_str(&format!("{:8.2} |{}\n", y_threshold, line));
305        }
306
307        result.push_str(&format!("{:8} +{}\n", "", "─".repeat(width)));
308        result
309    }
310}
311
312impl Default for TerminalVisualizer {
313    fn default() -> Self {
314        Self::new()
315    }
316}