scirs2_neural/visualization/
training.rs

//! Training metrics and curve visualization for neural networks.
//!
//! This module provides tools for recording and visualizing training progress,
//! including loss curves, accuracy metrics, learning-rate schedules, and system performance.
5
6use super::config::{DownsamplingStrategy, VisualizationConfig};
7use crate::error::{NeuralError, Result};
8
9use num_traits::Float;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::fs;
14use std::path::PathBuf;
15
16/// Training metrics visualizer
17#[allow(dead_code)]
18pub struct TrainingVisualizer<F: Float + Debug> {
19    /// Training history
20    metrics_history: Vec<TrainingMetrics<F>>,
21    /// Visualization configuration
22    config: VisualizationConfig,
23    /// Active plots
24    active_plots: HashMap<String, PlotConfig>,
25}
26
27/// Training metrics for a single epoch/step
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct TrainingMetrics<F: Float + Debug> {
30    /// Epoch number
31    pub epoch: usize,
32    /// Step number within epoch
33    pub step: usize,
34    /// Timestamp
35    pub timestamp: String,
36    /// Loss values
37    pub losses: HashMap<String, F>,
38    /// Accuracy metrics
39    pub accuracies: HashMap<String, F>,
40    /// Learning rate
41    pub learning_rate: F,
42    /// Other custom metrics
43    pub custom_metrics: HashMap<String, F>,
44    /// System metrics
45    pub system_metrics: SystemMetrics,
46}
47
48/// System performance metrics during training
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SystemMetrics {
51    /// Memory usage in MB
52    pub memory_usage_mb: f64,
53    /// GPU memory usage in MB (if available)
54    pub gpu_memory_mb: Option<f64>,
55    /// CPU utilization percentage
56    pub cpu_utilization: f64,
57    /// GPU utilization percentage (if available)
58    pub gpu_utilization: Option<f64>,
59    /// Training step duration in milliseconds
60    pub step_duration_ms: f64,
61    /// Samples processed per second
62    pub samples_per_second: f64,
63}
64
65/// Plot configuration
66#[derive(Debug, Clone, Serialize)]
67pub struct PlotConfig {
68    /// Plot title
69    pub title: String,
70    /// X-axis configuration
71    pub x_axis: AxisConfig,
72    /// Y-axis configuration
73    pub y_axis: AxisConfig,
74    /// Series to plot
75    pub series: Vec<SeriesConfig>,
76    /// Plot type
77    pub plot_type: PlotType,
78    /// Update mode
79    pub update_mode: UpdateMode,
80}
81
82/// Axis configuration
83#[derive(Debug, Clone, Serialize)]
84pub struct AxisConfig {
85    /// Axis label
86    pub label: String,
87    /// Axis scale
88    pub scale: AxisScale,
89    /// Range (None for auto)
90    pub range: Option<(f64, f64)>,
91    /// Show grid lines
92    pub show_grid: bool,
93    /// Tick configuration
94    pub ticks: TickConfig,
95}
96
97/// Axis scale type
98#[derive(Debug, Clone, PartialEq, Serialize)]
99pub enum AxisScale {
100    /// Linear scale
101    Linear,
102    /// Logarithmic scale
103    Log,
104    /// Square root scale
105    Sqrt,
106    /// Custom scale
107    Custom(String),
108}
109
110/// Tick configuration
111#[derive(Debug, Clone, Serialize)]
112pub struct TickConfig {
113    /// Tick interval (None for auto)
114    pub interval: Option<f64>,
115    /// Tick format
116    pub format: TickFormat,
117    /// Show tick labels
118    pub show_labels: bool,
119    /// Tick rotation angle
120    pub rotation: f32,
121}
122
123/// Tick format options
124#[derive(Debug, Clone, PartialEq, Serialize)]
125pub enum TickFormat {
126    /// Automatic formatting
127    Auto,
128    /// Fixed decimal places
129    Fixed(u32),
130    /// Scientific notation
131    Scientific,
132    /// Percentage
133    Percentage,
134    /// Custom format string
135    Custom(String),
136}
137
138/// Data series configuration
139#[derive(Debug, Clone, Serialize)]
140pub struct SeriesConfig {
141    /// Series name
142    pub name: String,
143    /// Data source (metric name)
144    pub data_source: String,
145    /// Line style
146    pub style: LineStyleConfig,
147    /// Marker style
148    pub markers: MarkerConfig,
149    /// Series color
150    pub color: String,
151    /// Series opacity
152    pub opacity: f32,
153}
154
155/// Line style configuration for series
156#[derive(Debug, Clone, Serialize)]
157pub struct LineStyleConfig {
158    /// Line style
159    pub style: LineStyle,
160    /// Line width
161    pub width: f32,
162    /// Smoothing enabled
163    pub smoothing: bool,
164    /// Smoothing window size
165    pub smoothing_window: usize,
166}
167
168/// Line style options (re-exported from network module)
169#[derive(Debug, Clone, PartialEq, Serialize)]
170pub enum LineStyle {
171    /// Solid line
172    Solid,
173    /// Dashed line
174    Dashed,
175    /// Dotted line
176    Dotted,
177    /// Dash-dot line
178    DashDot,
179}
180
181/// Marker configuration for data points
182#[derive(Debug, Clone, Serialize)]
183pub struct MarkerConfig {
184    /// Show markers
185    pub show: bool,
186    /// Marker shape
187    pub shape: MarkerShape,
188    /// Marker size
189    pub size: f32,
190    /// Marker fill color
191    pub fill_color: String,
192    /// Marker border color
193    pub border_color: String,
194}
195
196/// Marker shape options
197#[derive(Debug, Clone, PartialEq, Serialize)]
198pub enum MarkerShape {
199    /// Circle marker
200    Circle,
201    /// Square marker
202    Square,
203    /// Triangle marker
204    Triangle,
205    /// Diamond marker
206    Diamond,
207    /// Cross marker
208    Cross,
209    /// Plus marker
210    Plus,
211}
212
213/// Plot type options
214#[derive(Debug, Clone, PartialEq, Serialize)]
215pub enum PlotType {
216    /// Line plot
217    Line,
218    /// Scatter plot
219    Scatter,
220    /// Bar plot
221    Bar,
222    /// Area plot
223    Area,
224    /// Histogram
225    Histogram,
226    /// Box plot
227    Box,
228    /// Heatmap
229    Heatmap,
230}
231
232/// Update mode for plots
233#[derive(Debug, Clone, PartialEq, Serialize)]
234pub enum UpdateMode {
235    /// Append new data
236    Append,
237    /// Replace all data
238    Replace,
239    /// Rolling window
240    Rolling(usize),
241}
242
243// Implementation for TrainingVisualizer
244
245impl<F: Float + Debug + 'static + num_traits::FromPrimitive + Send + Sync> TrainingVisualizer<F> {
246    /// Create a new training visualizer
247    pub fn new(config: VisualizationConfig) -> Self {
248        Self {
249            metrics_history: Vec::new(),
250            config,
251            active_plots: HashMap::new(),
252        }
253    }
254
255    /// Add training metrics for visualization
256    pub fn add_metrics(&mut self, metrics: TrainingMetrics<F>) {
257        self.metrics_history.push(metrics);
258
259        // Apply downsampling if needed
260        if self.metrics_history.len() > self.config.performance.max_points_per_plot
261            && self.config.performance.enable_downsampling
262        {
263            self.downsample_metrics();
264        }
265    }
266
267    /// Generate training curves visualization
268    pub fn visualize_training_curves(&self) -> Result<Vec<PathBuf>> {
269        let mut output_files = Vec::new();
270
271        // Generate loss curves
272        if let Some(loss_plot) = self.create_loss_plot()? {
273            let loss_path = self.config.output_dir.join("training_loss.html");
274            fs::write(&loss_path, loss_plot)
275                .map_err(|e| NeuralError::IOError(format!("Failed to write loss plot: {}", e)))?;
276            output_files.push(loss_path);
277        }
278
279        // Generate accuracy curves
280        if let Some(accuracy_plot) = self.create_accuracy_plot()? {
281            let accuracy_path = self.config.output_dir.join("training_accuracy.html");
282            fs::write(&accuracy_path, accuracy_plot).map_err(|e| {
283                NeuralError::IOError(format!("Failed to write accuracy plot: {}", e))
284            })?;
285            output_files.push(accuracy_path);
286        }
287
288        // Generate learning rate plot
289        if let Some(lr_plot) = self.create_learning_rate_plot()? {
290            let lr_path = self.config.output_dir.join("learning_rate.html");
291            fs::write(&lr_path, lr_plot).map_err(|e| {
292                NeuralError::IOError(format!("Failed to write learning rate plot: {}", e))
293            })?;
294            output_files.push(lr_path);
295        }
296
297        // Generate system metrics plot
298        if let Some(system_plot) = self.create_system_metrics_plot()? {
299            let system_path = self.config.output_dir.join("system_metrics.html");
300            fs::write(&system_path, system_plot).map_err(|e| {
301                NeuralError::IOError(format!("Failed to write system metrics plot: {}", e))
302            })?;
303            output_files.push(system_path);
304        }
305
306        Ok(output_files)
307    }
308
309    /// Get the current metrics history
310    pub fn get_metrics_history(&self) -> &[TrainingMetrics<F>] {
311        &self.metrics_history
312    }
313
314    /// Clear the metrics history
315    pub fn clear_history(&mut self) {
316        self.metrics_history.clear();
317    }
318
319    /// Add a custom plot configuration
320    pub fn add_plot(&mut self, name: String, config: PlotConfig) {
321        self.active_plots.insert(name, config);
322    }
323
324    /// Remove a plot configuration
325    pub fn remove_plot(&mut self, name: &str) -> Option<PlotConfig> {
326        self.active_plots.remove(name)
327    }
328
329    /// Update the visualization configuration
330    pub fn update_config(&mut self, config: VisualizationConfig) {
331        self.config = config;
332    }
333
334    fn downsample_metrics(&mut self) {
335        // TODO: Implement downsampling based on strategy
336        match self.config.performance.downsampling_strategy {
337            DownsamplingStrategy::Uniform => {
338                // Keep every nth point
339                let step = self.metrics_history.len() / self.config.performance.max_points_per_plot;
340                if step > 1 {
341                    let mut downsampled = Vec::new();
342                    for (i, metric) in self.metrics_history.iter().enumerate() {
343                        if i % step == 0 {
344                            downsampled.push(metric.clone());
345                        }
346                    }
347                    self.metrics_history = downsampled;
348                }
349            }
350            _ => {
351                // For now, just truncate to max size
352                if self.metrics_history.len() > self.config.performance.max_points_per_plot {
353                    let start =
354                        self.metrics_history.len() - self.config.performance.max_points_per_plot;
355                    self.metrics_history.drain(0..start);
356                }
357            }
358        }
359    }
360
361    fn create_loss_plot(&self) -> Result<Option<String>> {
362        if self.metrics_history.is_empty() {
363            return Ok(None);
364        }
365
366        // TODO: Implement actual plotting library integration
367        // For now, return a placeholder HTML
368        let plot_html = r#"
369<!DOCTYPE html>
370<html>
371<head>
372    <title>Training Loss</title>
373    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
374</head>
375<body>
376    <div id="lossPlot" style="width:100%;height:500px;"></div>
377    <script>
378        // TODO: Implement actual loss curve plotting
379        var trace = {
380            x: [1, 2, 3, 4],
381            y: [0.8, 0.6, 0.4, 0.3],
382            type: 'scatter',
383            name: 'Training Loss'
384        };
385        
386        var layout = {
387            title: 'Training Loss Over Time',
388            xaxis: { title: 'Epoch' },
389            yaxis: { title: 'Loss' }
390        };
391        
392        Plotly.newPlot('lossPlot', [trace], layout);
393    </script>
394</body>
395</html>"#;
396
397        Ok(Some(plot_html.to_string()))
398    }
399
400    fn create_accuracy_plot(&self) -> Result<Option<String>> {
401        if self.metrics_history.is_empty() {
402            return Ok(None);
403        }
404
405        // TODO: Implement accuracy plotting
406        Ok(None)
407    }
408
409    fn create_learning_rate_plot(&self) -> Result<Option<String>> {
410        if self.metrics_history.is_empty() {
411            return Ok(None);
412        }
413
414        // TODO: Implement learning rate plotting
415        Ok(None)
416    }
417
418    fn create_system_metrics_plot(&self) -> Result<Option<String>> {
419        if self.metrics_history.is_empty() {
420            return Ok(None);
421        }
422
423        // TODO: Implement system metrics plotting
424        Ok(None)
425    }
426}
427
428// Default implementations for configuration types
429
430impl Default for PlotConfig {
431    fn default() -> Self {
432        Self {
433            title: "Training Metrics".to_string(),
434            x_axis: AxisConfig::default(),
435            y_axis: AxisConfig::default(),
436            series: Vec::new(),
437            plot_type: PlotType::Line,
438            update_mode: UpdateMode::Append,
439        }
440    }
441}
442
443impl Default for AxisConfig {
444    fn default() -> Self {
445        Self {
446            label: "".to_string(),
447            scale: AxisScale::Linear,
448            range: None,
449            show_grid: true,
450            ticks: TickConfig::default(),
451        }
452    }
453}
454
455impl Default for TickConfig {
456    fn default() -> Self {
457        Self {
458            interval: None,
459            format: TickFormat::Auto,
460            show_labels: true,
461            rotation: 0.0,
462        }
463    }
464}
465
466impl Default for SeriesConfig {
467    fn default() -> Self {
468        Self {
469            name: "Series".to_string(),
470            data_source: "".to_string(),
471            style: LineStyleConfig::default(),
472            markers: MarkerConfig::default(),
473            color: "#1f77b4".to_string(), // Default blue
474            opacity: 1.0,
475        }
476    }
477}
478
479impl Default for LineStyleConfig {
480    fn default() -> Self {
481        Self {
482            style: LineStyle::Solid,
483            width: 2.0,
484            smoothing: false,
485            smoothing_window: 5,
486        }
487    }
488}
489
490impl Default for MarkerConfig {
491    fn default() -> Self {
492        Self {
493            show: false,
494            shape: MarkerShape::Circle,
495            size: 6.0,
496            fill_color: "#1f77b4".to_string(),
497            border_color: "#1f77b4".to_string(),
498        }
499    }
500}
501
502impl Default for SystemMetrics {
503    fn default() -> Self {
504        Self {
505            memory_usage_mb: 0.0,
506            gpu_memory_mb: None,
507            cpu_utilization: 0.0,
508            gpu_utilization: None,
509            step_duration_ms: 0.0,
510            samples_per_second: 0.0,
511        }
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Visualizer over `f32` using the default configuration.
    fn default_visualizer() -> TrainingVisualizer<f32> {
        TrainingVisualizer::new(VisualizationConfig::default())
    }

    /// A single representative metrics record shared by the history tests.
    fn sample_metrics() -> TrainingMetrics<f32> {
        TrainingMetrics {
            epoch: 1,
            step: 100,
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            losses: HashMap::from([("train_loss".to_string(), 0.5)]),
            accuracies: HashMap::from([("train_acc".to_string(), 0.8)]),
            learning_rate: 0.001,
            custom_metrics: HashMap::new(),
            system_metrics: SystemMetrics::default(),
        }
    }

    #[test]
    fn test_training_visualizer_creation() {
        let visualizer = default_visualizer();

        assert!(visualizer.metrics_history.is_empty());
        assert!(visualizer.active_plots.is_empty());
    }

    #[test]
    fn test_add_metrics() {
        let mut visualizer = default_visualizer();

        visualizer.add_metrics(sample_metrics());

        assert_eq!(visualizer.metrics_history.len(), 1);
    }

    #[test]
    fn test_plot_config_defaults() {
        let PlotConfig {
            title,
            plot_type,
            update_mode,
            ..
        } = PlotConfig::default();

        assert_eq!(title, "Training Metrics");
        assert_eq!(plot_type, PlotType::Line);
        assert_eq!(update_mode, UpdateMode::Append);
    }

    #[test]
    fn test_axis_scale_variants() {
        for scale in [AxisScale::Linear, AxisScale::Log, AxisScale::Sqrt] {
            assert_eq!(scale, scale.clone());
        }

        if let AxisScale::Custom(name) = AxisScale::Custom("symlog".to_string()) {
            assert_eq!(name, "symlog");
        } else {
            panic!("Expected custom scale");
        }
    }

    #[test]
    fn test_marker_shapes() {
        let shapes = [
            MarkerShape::Circle,
            MarkerShape::Square,
            MarkerShape::Triangle,
            MarkerShape::Diamond,
            MarkerShape::Cross,
            MarkerShape::Plus,
        ];

        assert_eq!(shapes.len(), 6);
        assert_eq!(shapes.first(), Some(&MarkerShape::Circle));
    }

    #[test]
    fn test_plot_types() {
        let types = [
            PlotType::Line,
            PlotType::Scatter,
            PlotType::Bar,
            PlotType::Area,
            PlotType::Histogram,
            PlotType::Box,
            PlotType::Heatmap,
        ];

        assert_eq!(types.len(), 7);
        assert_eq!(types.first(), Some(&PlotType::Line));
    }

    #[test]
    fn test_update_modes() {
        assert_eq!(UpdateMode::Append, UpdateMode::Append);
        assert_eq!(UpdateMode::Replace, UpdateMode::Replace);

        if let UpdateMode::Rolling(size) = UpdateMode::Rolling(100) {
            assert_eq!(size, 100);
        } else {
            panic!("Expected rolling update mode");
        }
    }

    #[test]
    fn test_clear_history() {
        let mut visualizer = default_visualizer();

        visualizer.add_metrics(sample_metrics());
        assert_eq!(visualizer.metrics_history.len(), 1);

        visualizer.clear_history();
        assert!(visualizer.metrics_history.is_empty());
    }

    #[test]
    fn test_plot_management() {
        let mut visualizer = default_visualizer();
        let key = "test_plot";

        visualizer.add_plot(key.to_string(), PlotConfig::default());
        assert!(visualizer.active_plots.contains_key(key));

        let removed = visualizer.remove_plot(key);
        assert!(removed.is_some());
        assert!(!visualizer.active_plots.contains_key(key));
    }
}