1use super::config::{DownsamplingStrategy, VisualizationConfig};
7use crate::error::{NeuralError, Result};
8use scirs2_core::numeric::Float;
9use scirs2_core::NumAssign;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::fs;
14use std::path::PathBuf;
/// Collects training metrics over time and renders them as HTML plots.
///
/// `F` is the floating-point type used for loss, accuracy and learning-rate
/// values in the recorded [`TrainingMetrics`].
// NOTE(review): the reason for silencing `dead_code` is not visible in this
// file — confirm whether the attribute is still required.
#[allow(dead_code)]
pub struct TrainingVisualizer<F: Float + Debug + NumAssign> {
    /// Recorded metrics in insertion order; may be thinned by downsampling.
    metrics_history: Vec<TrainingMetrics<F>>,
    /// Output location and performance (downsampling) settings.
    config: VisualizationConfig,
    /// Plot configurations registered by name via `add_plot`.
    active_plots: HashMap<String, PlotConfig>,
}
/// One snapshot of training state, as stored by `TrainingVisualizer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics<F: Float + Debug + NumAssign> {
    /// Epoch number; used as the x-axis value of the generated plots.
    pub epoch: usize,
    /// Step counter (not read by the plotting code visible in this file).
    pub step: usize,
    /// Time of the snapshot as text; the format is not validated here.
    pub timestamp: String,
    /// Loss values keyed by series name.
    pub losses: HashMap<String, F>,
    /// Accuracy values keyed by series name.
    pub accuracies: HashMap<String, F>,
    /// Learning rate in effect for this snapshot.
    pub learning_rate: F,
    /// Additional user-defined metrics keyed by name.
    pub custom_metrics: HashMap<String, F>,
    /// Resource-usage and throughput measurements for this snapshot.
    pub system_metrics: SystemMetrics,
}
45
/// Resource-usage and throughput measurements attached to a metrics snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemMetrics {
    /// Memory usage; plotted with the axis label "Memory (MB)".
    pub memory_usage_mb: f64,
    /// GPU memory usage, if available (presumably in MB, per the field name).
    pub gpu_memory_mb: Option<f64>,
    /// CPU utilization; plotted as a percentage on a 0–100 axis.
    pub cpu_utilization: f64,
    /// GPU utilization, if available; plotted on the same 0–100 axis as CPU.
    pub gpu_utilization: Option<f64>,
    /// Duration of one step (presumably in milliseconds, per the field name).
    pub step_duration_ms: f64,
    /// Training throughput; plotted as "Samples per Second".
    pub samples_per_second: f64,
}
62
/// Description of a single plot that can be registered with a
/// `TrainingVisualizer` under a name.
#[derive(Debug, Clone, Serialize)]
pub struct PlotConfig {
    /// Plot title.
    pub title: String,
    /// Horizontal axis settings.
    pub x_axis: AxisConfig,
    /// Vertical axis settings.
    pub y_axis: AxisConfig,
    /// Data series shown in the plot.
    pub series: Vec<SeriesConfig>,
    /// Kind of chart to draw.
    pub plot_type: PlotType,
    /// How new data is combined with data already shown.
    pub update_mode: UpdateMode,
}
79
/// Settings for one plot axis.
#[derive(Debug, Clone, Serialize)]
pub struct AxisConfig {
    /// Axis label text.
    pub label: String,
    /// Scale applied to the axis.
    pub scale: AxisScale,
    /// Optional fixed axis range (presumably `(min, max)` — confirm);
    /// `None` leaves the range unspecified.
    pub range: Option<(f64, f64)>,
    /// Whether grid lines are shown.
    pub show_grid: bool,
    /// Tick settings for this axis.
    pub ticks: TickConfig,
}
94
/// Scale type of a plot axis.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum AxisScale {
    /// Linear scale.
    Linear,
    /// Logarithmic scale.
    Log,
    /// Square-root scale.
    Sqrt,
    /// Scale identified by a free-form name (how the name is interpreted is
    /// not defined in this file).
    Custom(String),
}
107
/// Tick settings for a plot axis.
#[derive(Debug, Clone, Serialize)]
pub struct TickConfig {
    /// Optional spacing between ticks; `None` leaves it unspecified.
    pub interval: Option<f64>,
    /// How tick labels are formatted.
    pub format: TickFormat,
    /// Whether tick labels are shown.
    pub show_labels: bool,
    /// Tick label rotation (presumably in degrees — confirm with the renderer).
    pub rotation: f32,
}
120
/// Formatting applied to tick labels.
#[derive(Debug, Clone, Serialize)]
pub enum TickFormat {
    /// Let the renderer choose the format.
    Auto,
    /// Fixed-point formatting (the value is presumably the number of decimal
    /// places — confirm).
    Fixed(u32),
    /// Scientific notation.
    Scientific,
    /// Percentage formatting.
    Percentage,
    /// Free-form format string (its syntax is not defined in this file).
    Custom(String),
}
135
/// Settings for one data series within a plot.
#[derive(Debug, Clone, Serialize)]
pub struct SeriesConfig {
    /// Display name of the series.
    pub name: String,
    /// Identifier of the data backing the series (how it is resolved is not
    /// shown in this file).
    pub data_source: String,
    /// Line appearance.
    pub style: LineStyleConfig,
    /// Marker appearance.
    pub markers: MarkerConfig,
    /// Series colour as a string (the defaults in this file use `#rrggbb`).
    pub color: String,
    /// Series opacity (the default is `1.0`).
    pub opacity: f32,
}
152
/// Line appearance for a data series.
#[derive(Debug, Clone, Serialize)]
pub struct LineStyleConfig {
    /// Stroke pattern of the line.
    pub style: LineStyle,
    /// Line width (units depend on the renderer).
    pub width: f32,
    /// Whether smoothing is applied to the line.
    pub smoothing: bool,
    /// Smoothing window (presumably a number of data points — confirm).
    pub smoothing_window: usize,
}
164
/// Stroke pattern of a plotted line.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum LineStyle {
    /// Continuous line.
    Solid,
    /// Dashed line.
    Dashed,
    /// Dotted line.
    Dotted,
    /// Alternating dashes and dots.
    DashDot,
}
177
/// Marker appearance for a data series.
#[derive(Debug, Clone, Serialize)]
pub struct MarkerConfig {
    /// Whether markers are drawn.
    pub show: bool,
    /// Marker shape.
    pub shape: MarkerShape,
    /// Marker size (units depend on the renderer).
    pub size: f32,
    /// Marker fill colour as a string.
    pub fill_color: String,
    /// Marker border colour as a string.
    pub border_color: String,
}
192
/// Shape of a data-point marker.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum MarkerShape {
    /// Circular marker.
    Circle,
    /// Square marker.
    Square,
    /// Triangular marker.
    Triangle,
    /// Diamond-shaped marker.
    Diamond,
    /// Cross-shaped marker.
    Cross,
    /// Plus-shaped marker.
    Plus,
}
209
/// Kind of chart described by a `PlotConfig`.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum PlotType {
    /// Line chart.
    Line,
    /// Scatter plot.
    Scatter,
    /// Bar chart.
    Bar,
    /// Area chart.
    Area,
    /// Histogram.
    Histogram,
    /// Box plot.
    Box,
    /// Heatmap.
    Heatmap,
}
228
/// How new data is combined with the data a plot already shows.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum UpdateMode {
    /// Add new data after the existing data.
    Append,
    /// Replace the existing data.
    Replace,
    /// Keep a rolling window (the value is presumably the window size in
    /// points — confirm).
    Rolling(usize),
}
239
240impl<
242 F: Float + Debug + NumAssign + 'static + scirs2_core::numeric::FromPrimitive + Send + Sync,
243 > TrainingVisualizer<F>
244{
245 pub fn new(config: VisualizationConfig) -> Self {
247 Self {
248 metrics_history: Vec::new(),
249 config,
250 active_plots: HashMap::new(),
251 }
252 }
253 pub fn add_metrics(&mut self, metrics: TrainingMetrics<F>) {
255 self.metrics_history.push(metrics);
256 if self.metrics_history.len() > self.config.performance.max_points_per_plot
258 && self.config.performance.enable_downsampling
259 {
260 self.downsample_metrics();
261 }
262 }
263
264 pub fn visualize_training_curves(&self) -> Result<Vec<PathBuf>> {
266 let mut output_files = Vec::new();
267 if let Some(loss_plot) = self.create_loss_plot()? {
269 let loss_path = self.config.output_dir.join("training_loss.html");
270 fs::write(&loss_path, loss_plot)
271 .map_err(|e| NeuralError::IOError(format!("Failed to write loss plot: {}", e)))?;
272 output_files.push(loss_path);
273 }
274
275 if let Some(accuracy_plot) = self.create_accuracy_plot()? {
277 let accuracy_path = self.config.output_dir.join("training_accuracy.html");
278 fs::write(&accuracy_path, accuracy_plot).map_err(|e| {
279 NeuralError::IOError(format!("Failed to write accuracy plot: {}", e))
280 })?;
281 output_files.push(accuracy_path);
282 }
283
284 if let Some(lr_plot) = self.create_learning_rate_plot()? {
286 let lr_path = self.config.output_dir.join("learning_rate.html");
287 fs::write(&lr_path, lr_plot).map_err(|e| {
288 NeuralError::IOError(format!("Failed to write learning rate plot: {}", e))
289 })?;
290 output_files.push(lr_path);
291 }
292
293 if let Some(system_plot) = self.create_system_metrics_plot()? {
295 let system_path = self.config.output_dir.join("system_metrics.html");
296 fs::write(&system_path, system_plot).map_err(|e| {
297 NeuralError::IOError(format!("Failed to write system metrics plot: {}", e))
298 })?;
299 output_files.push(system_path);
300 }
301
302 Ok(output_files)
303 }
304
305 pub fn get_metrics_history(&self) -> &[TrainingMetrics<F>] {
307 &self.metrics_history
308 }
309
310 pub fn clear_history(&mut self) {
312 self.metrics_history.clear();
313 }
314 pub fn add_plot(&mut self, name: String, config: PlotConfig) {
316 self.active_plots.insert(name, config);
317 }
318
319 pub fn remove_plot(&mut self, name: &str) -> Option<PlotConfig> {
321 self.active_plots.remove(name)
322 }
323
324 pub fn update_config(&mut self, config: VisualizationConfig) {
326 self.config = config;
327 }
328
329 fn downsample_metrics(&mut self) {
330 if self.metrics_history.len() <= self.config.performance.max_points_per_plot {
332 return; }
334
335 match self.config.performance.downsampling_strategy {
336 DownsamplingStrategy::Uniform => {
337 let step = self.metrics_history.len() / self.config.performance.max_points_per_plot;
339 if step > 1 {
340 let mut downsampled = Vec::new();
341 for (i, metric) in self.metrics_history.iter().enumerate() {
342 if i % step == 0 {
343 downsampled.push(metric.clone());
344 }
345 }
346 self.metrics_history = downsampled;
347 }
348 }
349 DownsamplingStrategy::LTTB => {
350 self.downsample_lttb();
352 }
353 DownsamplingStrategy::MinMax => {
354 self.downsample_minmax();
356 }
357 DownsamplingStrategy::Statistical => {
358 self.downsample_statistical();
360 }
361 }
362 }
363
364 fn downsample_lttb(&mut self) {
366 let target_points = self.config.performance.max_points_per_plot;
367 if self.metrics_history.len() <= target_points {
368 return;
369 }
370 let bucket_size = self.metrics_history.len() as f64 / target_points as f64;
371 let mut downsampled = Vec::new();
372 downsampled.push(self.metrics_history[0].clone());
374 for bucket in 1..(target_points - 1) {
376 let bucket_start = (bucket as f64 * bucket_size) as usize;
377 let bucket_end =
378 ((bucket + 1) as f64 * bucket_size).min(self.metrics_history.len() as f64) as usize;
379 let next_bucket_start = bucket_end;
381 let next_bucket_end =
382 ((bucket + 2) as f64 * bucket_size).min(self.metrics_history.len() as f64) as usize;
383 let avg_epoch = if next_bucket_end > next_bucket_start {
384 let sum: usize = (next_bucket_start..next_bucket_end)
385 .map(|i| self.metrics_history[i].epoch)
386 .sum();
387 sum as f64 / (next_bucket_end - next_bucket_start) as f64
388 } else {
389 self.metrics_history[self.metrics_history.len() - 1].epoch as f64
390 };
391 let mut max_area = 0.0f64;
393 let mut selected_idx = bucket_start;
394 let prev_epoch = downsampled.last().expect("Operation failed").epoch as f64;
395 for i in bucket_start..bucket_end {
396 let curr_epoch = self.metrics_history[i].epoch as f64;
397 let area = ((prev_epoch - avg_epoch) * (curr_epoch - prev_epoch)).abs();
399 if area > max_area {
400 max_area = area;
401 selected_idx = i;
402 }
403 }
404
405 downsampled.push(self.metrics_history[selected_idx].clone());
406 }
407
408 downsampled.push(self.metrics_history[self.metrics_history.len() - 1].clone());
410 self.metrics_history = downsampled;
411 }
412
413 fn downsample_minmax(&mut self) {
415 let target_points = self.config.performance.max_points_per_plot;
416 if self.metrics_history.len() <= target_points {
417 return;
418 }
419
420 let mut downsampled = Vec::new();
421 let bucket_size = self.metrics_history.len() / (target_points / 2); if bucket_size == 0 {
423 return;
424 }
425
426 for chunk in self.metrics_history.chunks(bucket_size) {
427 if chunk.is_empty() {
428 continue;
429 }
430
431 let mut min_metric = &chunk[0];
433 let mut max_metric = &chunk[0];
434
435 for metric in chunk {
436 let current_value = metric
438 .losses
439 .values()
440 .next()
441 .map(|v| v.to_f64().unwrap_or(0.0))
442 .unwrap_or(metric.epoch as f64);
443 let min_value = min_metric
444 .losses
445 .values()
446 .next()
447 .map(|v| v.to_f64().unwrap_or(0.0))
448 .unwrap_or(min_metric.epoch as f64);
449 let max_value = max_metric
450 .losses
451 .values()
452 .next()
453 .map(|v| v.to_f64().unwrap_or(0.0))
454 .unwrap_or(max_metric.epoch as f64);
455
456 if current_value < min_value {
457 min_metric = metric;
458 }
459 if current_value > max_value {
460 max_metric = metric;
461 }
462 }
463
464 if min_metric.epoch <= max_metric.epoch {
466 downsampled.push(min_metric.clone());
467 if min_metric.epoch != max_metric.epoch {
468 downsampled.push(max_metric.clone());
469 }
470 } else {
471 downsampled.push(max_metric.clone());
472 }
473 }
474 downsampled.sort_by_key(|m| m.epoch);
476
477 if downsampled.len() > target_points {
479 let step = downsampled.len() / target_points;
480 let mut final_downsampled = Vec::new();
481 for (i, metric) in downsampled.iter().enumerate() {
482 if i % step == 0 {
483 final_downsampled.push(metric.clone());
484 }
485 }
486 self.metrics_history = final_downsampled;
487 } else {
488 self.metrics_history = downsampled;
489 }
490 }
491
492 fn downsample_statistical(&mut self) {
494 let target_points = self.config.performance.max_points_per_plot;
495 if self.metrics_history.len() <= target_points {
496 return;
497 }
498
499 let mut downsampled = Vec::new();
500 let mut importance_scores: Vec<(usize, f64)> = Vec::new();
502 for (i, metric) in self.metrics_history.iter().enumerate() {
503 let mut score = 0.0f64;
504 if i > 0 && i < self.metrics_history.len() - 1 {
506 let prev_metric = &self.metrics_history[i - 1];
507 let next_metric = &self.metrics_history[i + 1];
508 for (loss_name, &loss_value) in &metric.losses {
510 if let (Some(&prev_loss), Some(&next_loss)) = (
511 prev_metric.losses.get(loss_name),
512 next_metric.losses.get(loss_name),
513 ) {
514 let prev_val = prev_loss.to_f64().unwrap_or(0.0);
515 let curr_val = loss_value.to_f64().unwrap_or(0.0);
516 let next_val = next_loss.to_f64().unwrap_or(0.0);
517 let curvature = ((next_val - curr_val) - (curr_val - prev_val)).abs();
519 score += curvature;
520 }
521 }
522 }
523
524 if i == 0 || i == self.metrics_history.len() - 1 {
526 score += 1000.0; }
528
529 importance_scores.push((i, score));
530 }
531 importance_scores
533 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
534
535 let mut selected_indices: Vec<usize> = importance_scores
537 .iter()
538 .take(target_points)
539 .map(|(idx, _)| *idx)
540 .collect();
541 selected_indices.sort();
542
543 for &idx in &selected_indices {
544 downsampled.push(self.metrics_history[idx].clone());
545 }
546
547 self.metrics_history = downsampled;
548 }
549
550 fn create_loss_plot(&self) -> Result<Option<String>> {
551 if self.metrics_history.is_empty() {
552 return Ok(None);
553 }
554
555 let mut loss_data = std::collections::HashMap::new();
557 let mut epochs = Vec::new();
558
559 for metric in &self.metrics_history {
560 epochs.push(metric.epoch);
561 for (loss_name, loss_value) in &metric.losses {
562 loss_data
563 .entry(loss_name.clone())
564 .or_insert_with(Vec::new)
565 .push(loss_value.to_f64().unwrap_or(0.0));
566 }
567 }
568
569 if loss_data.is_empty() {
570 return Ok(None);
571 }
572
573 let mut traces = Vec::new();
575 let colors = [
576 "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b",
577 ];
578
579 for (i, (loss_name, values)) in loss_data.iter().enumerate() {
580 let color = colors[i % colors.len()];
581 let epochs_json = serde_json::to_string(&epochs).unwrap_or_default();
582 let values_json = serde_json::to_string(values).unwrap_or_default();
583
584 traces.push(format!(
585 r#"{{
586 x: {},
587 y: {},
588 type: 'scatter',
589 mode: 'lines+markers',
590 name: '{}',
591 line: {{ color: '{}', width: 2 }},
592 marker: {{ size: 6, color: '{}' }}
593 }}"#,
594 epochs_json, values_json, loss_name, color, color
595 ));
596 }
597
598 let traces_str = traces.join(",\n ");
599 let plot_html = format!(
600 r#"
601<!DOCTYPE html>
602<html>
603<head>
604 <title>Training Loss</title>
605 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
606 <style>
607 body {{ font-family: Arial, sans-serif; margin: 20px; }}
608 .plot-container {{ width: 100%; height: 600px; }}
609 </style>
610</head>
611<body>
612 <h2>Training Loss Curves</h2>
613 <div id="lossPlot" class="plot-container"></div>
614 <script>
615 var traces = [
616 {}
617
618 var layout = {{
619 title: {{
620 text: 'Training Loss Over Time',
621 font: {{ size: 18 }}
622 }},
623 xaxis: {{
624 title: 'Epoch',
625 showgrid: true,
626 gridcolor: '#e0e0e0'
627 yaxis: {{
628 title: 'Loss',
629 hovermode: 'x unified',
630 legend: {{
631 x: 1,
632 y: 1,
633 bgcolor: 'rgba(255,255,255,0.8)',
634 bordercolor: '#000',
635 borderwidth: 1
636 plot_bgcolor: '#ffffff',
637 paper_bgcolor: '#ffffff'
638 }};
639 var config = {{
640 responsive: true,
641 displayModeBar: true,
642 modeBarButtonsToRemove: ['pan2d', 'lasso2d', 'select2d']
643 }};
644
645 Plotly.newPlot('lossPlot', traces, layout, config);
646 </script>
647</body>
648</html>"#,
649 traces_str
650 );
651
652 Ok(Some(plot_html))
653 }
654
655 fn create_accuracy_plot(&self) -> Result<Option<String>> {
656 if self.metrics_history.is_empty() {
657 return Ok(None);
658 }
659
660 let mut accuracy_data = std::collections::HashMap::new();
662 let mut epochs = Vec::new();
663
664 for metric in &self.metrics_history {
665 epochs.push(metric.epoch);
666 for (acc_name, acc_value) in &metric.accuracies {
667 accuracy_data
668 .entry(acc_name.clone())
669 .or_insert_with(Vec::new)
670 .push(acc_value.to_f64().unwrap_or(0.0));
671 }
672 }
673
674 if accuracy_data.is_empty() {
675 return Ok(None);
676 }
677
678 let mut traces = Vec::new();
680 let colors = [
681 "#2ca02c", "#ff7f0e", "#1f77b4", "#d62728", "#9467bd", "#8c564b",
682 ];
683
684 for (i, (acc_name, values)) in accuracy_data.iter().enumerate() {
685 let color = colors[i % colors.len()];
686 let epochs_json = serde_json::to_string(&epochs).unwrap_or_default();
687 let values_json = serde_json::to_string(values).unwrap_or_default();
688
689 traces.push(format!(
690 r#"{{
691 x: {},
692 y: {},
693 type: 'scatter',
694 mode: 'lines+markers',
695 name: '{}',
696 line: {{ color: '{}', width: 2 }},
697 marker: {{ size: 6, color: '{}' }}
698 }}"#,
699 epochs_json, values_json, acc_name, color, color
700 ));
701 }
702
703 let traces_str = traces.join(",\n ");
704
705 let plot_html = format!(
706 r#"
707<!DOCTYPE html>
708<html>
709<head>
710 <title>Training Accuracy</title>
711 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
712 <style>
713 body {{ font-family: Arial, sans-serif; margin: 20px; }}
714 .plot-container {{ width: 100%; height: 600px; }}
715 </style>
716</head>
717<body>
718 <h2>Training Accuracy Curves</h2>
719 <div id="accuracyPlot" class="plot-container"></div>
720 <script>
721 var traces = [
722 {}
723 ];
724
725 var layout = {{
726 title: {{
727 text: "Training Accuracy Over Time",
728 font: {{ size: 18 }}
729 }},
730 xaxis: {{
731 title: 'Epoch',
732 showgrid: true,
733 gridcolor: '#e0e0e0'
734 }},
735 yaxis: {{
736 title: 'Accuracy',
737 showgrid: true,
738 gridcolor: '#e0e0e0',
739 range: [0, 1]
740 }},
741 hovermode: 'x unified',
742 legend: {{
743 x: 1,
744 y: 0,
745 bgcolor: 'rgba(255,255,255,0.8)',
746 bordercolor: '#000',
747 borderwidth: 1
748 }},
749 plot_bgcolor: '#ffffff',
750 paper_bgcolor: '#ffffff'
751 }};
752
753 var config = {{
754 responsive: true,
755 displayModeBar: true,
756 modeBarButtonsToRemove: ['pan2d', 'lasso2d', 'select2d']
757 }};
758
759 Plotly.newPlot('accuracyPlot', traces, layout, config);
760 </script>
761</body>
762</html>"#,
763 traces_str
764 );
765
766 Ok(Some(plot_html))
767 }
768 fn create_learning_rate_plot(&self) -> Result<Option<String>> {
769 if self.metrics_history.is_empty() {
770 return Ok(None);
771 }
772
773 let mut learning_rates = Vec::new();
775 let mut epochs = Vec::new();
776
777 for metric in &self.metrics_history {
778 epochs.push(metric.epoch);
779 learning_rates.push(metric.learning_rate.to_f64().unwrap_or(0.0));
780 }
781
782 if learning_rates.is_empty() {
783 return Ok(None);
784 }
785
786 let epochs_json = serde_json::to_string(&epochs).unwrap_or_default();
787 let lr_json = serde_json::to_string(&learning_rates).unwrap_or_default();
788
789 let plot_html = format!(
790 r#"
791<!DOCTYPE html>
792<html>
793<head>
794 <title>Learning Rate Schedule</title>
795 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
796 <style>
797 body {{ font-family: Arial, sans-serif; margin: 20px; }}
798 .plot-container {{ width: 100%; height: 600px; }}
799 </style>
800</head>
801<body>
802 <h2>Learning Rate Schedule</h2>
803 <div id="lrPlot" class="plot-container"></div>
804 <script>
805 var trace = {{
806 x: {},
807 y: {},
808 type: 'scatter',
809 mode: 'lines+markers',
810 name: 'Learning Rate',
811 line: {{ color: '#d62728', width: 3 }},
812 marker: {{ size: 8, color: '#d62728' }}
813 }};
814
815 var layout = {{
816 title: {{
817 text: "Learning Rate Over Time",
818 font: {{ size: 18 }}
819 }},
820 xaxis: {{
821 title: 'Epoch',
822 showgrid: true,
823 gridcolor: '#e0e0e0'
824 }},
825 yaxis: {{
826 title: 'Learning Rate',
827 showgrid: true,
828 gridcolor: '#e0e0e0',
829 type: 'log'
830 }},
831 hovermode: 'x unified',
832 legend: {{
833 x: 1,
834 y: 1,
835 bgcolor: 'rgba(255,255,255,0.8)',
836 bordercolor: '#000',
837 borderwidth: 1
838 }},
839 plot_bgcolor: '#ffffff',
840 paper_bgcolor: '#ffffff'
841 }};
842
843 var config = {{
844 responsive: true,
845 displayModeBar: true,
846 modeBarButtonsToRemove: ['pan2d', 'lasso2d', 'select2d']
847 }};
848
849 Plotly.newPlot('lrPlot', [trace], layout, config);
850 </script>
851</body>
852</html>"#,
853 epochs_json, lr_json
854 );
855
856 Ok(Some(plot_html))
857 }
858 fn create_system_metrics_plot(&self) -> Result<Option<String>> {
859 if self.metrics_history.is_empty() {
860 return Ok(None);
861 }
862
863 let mut memory_usage = Vec::new();
865 let mut cpu_utilization = Vec::new();
866 let mut gpu_utilization = Vec::new();
867 let mut samples_per_second = Vec::new();
868 let mut epochs = Vec::new();
869
870 for metric in &self.metrics_history {
871 epochs.push(metric.epoch);
872 memory_usage.push(metric.system_metrics.memory_usage_mb);
873 cpu_utilization.push(metric.system_metrics.cpu_utilization);
874 if let Some(gpu_util) = metric.system_metrics.gpu_utilization {
875 gpu_utilization.push(gpu_util);
876 }
877 samples_per_second.push(metric.system_metrics.samples_per_second);
878 }
879
880 let epochs_json = serde_json::to_string(&epochs).unwrap_or_default();
881 let memory_json = serde_json::to_string(&memory_usage).unwrap_or_default();
882 let cpu_json = serde_json::to_string(&cpu_utilization).unwrap_or_default();
883 let gpu_json = if !gpu_utilization.is_empty() {
884 serde_json::to_string(&gpu_utilization).unwrap_or_default()
885 } else {
886 "[]".to_string()
887 };
888 let sps_json = serde_json::to_string(&samples_per_second).unwrap_or_default();
889
890 let plot_html = format!(
891 r#"
892<!DOCTYPE html>
893<html>
894<head>
895 <title>System Metrics</title>
896 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
897 <style>
898 body {{ font-family: Arial, sans-serif; margin: 20px; }}
899 .plot-container {{ width: 100%; height: 400px; margin-bottom: 20px; }}
900 </style>
901</head>
902<body>
903 <h2>System Performance Metrics</h2>
904
905 <h3>Memory Usage</h3>
906 <div id="memoryPlot" class="plot-container"></div>
907
908 <h3>CPU & GPU Utilization</h3>
909 <div id="utilizationPlot" class="plot-container"></div>
910
911 <h3>Training Throughput</h3>
912 <div id="throughputPlot" class="plot-container"></div>
913
914 <script>
915 var epochs = {};
916
917 // Memory usage plot
918 var memoryTrace = {{
919 x: epochs,
920 y: {},
921 type: 'scatter',
922 mode: 'lines+markers',
923 name: 'Memory Usage (MB)',
924 line: {{ color: '#ff7f0e', width: 2 }},
925 marker: {{ size: 6, color: '#ff7f0e' }}
926 }};
927
928 var memoryLayout = {{
929 title: "Memory Usage Over Time",
930 xaxis: {{ title: 'Epoch' }},
931 yaxis: {{ title: 'Memory (MB)' }},
932 showlegend: false
933 }};
934
935 var config = {{
936 responsive: true,
937 displayModeBar: true,
938 modeBarButtonsToRemove: ['pan2d', 'lasso2d', 'select2d']
939 }};
940
941 Plotly.newPlot('memoryPlot', [memoryTrace], memoryLayout, config);
942
943 // CPU and GPU utilization plot
944 var traces = [{{
945 x: epochs,
946 y: {},
947 type: 'scatter',
948 mode: 'lines+markers',
949 name: 'CPU Utilization (%)',
950 line: {{ color: '#1f77b4', width: 2 }},
951 marker: {{ size: 6, color: '#1f77b4' }}
952 }}];
953
954 if ({}.length > 0) {{
955 traces.push({{
956 x: epochs,
957 y: {},
958 type: 'scatter',
959 mode: 'lines+markers',
960 name: 'GPU Utilization (%)',
961 line: {{ color: '#2ca02c', width: 2 }},
962 marker: {{ size: 6, color: '#2ca02c' }}
963 }});
964 }}
965
966 var utilizationLayout = {{
967 title: "CPU & GPU Utilization",
968 xaxis: {{ title: 'Epoch' }},
969 yaxis: {{ title: 'Utilization (%)', range: [0, 100] }}
970 }};
971
972 Plotly.newPlot('utilizationPlot', traces, utilizationLayout, config);
973
974 // Throughput plot
975 var throughputTrace = {{
976 x: epochs,
977 y: {},
978 type: 'scatter',
979 mode: 'lines+markers',
980 name: 'Samples/Second',
981 line: {{ color: '#9467bd', width: 2 }},
982 marker: {{ size: 6, color: '#9467bd' }}
983 }};
984
985 var throughputLayout = {{
986 title: "Training Throughput",
987 xaxis: {{ title: 'Epoch' }},
988 yaxis: {{ title: 'Samples per Second' }},
989 showlegend: false
990 }};
991
992 Plotly.newPlot('throughputPlot', [throughputTrace], throughputLayout, config);
993 </script>
994</body>
995</html>"#,
996 epochs_json, memory_json, cpu_json, gpu_json, gpu_json, sps_json
997 );
998
999 Ok(Some(plot_html))
1000 }
1001}
1002
1003impl Default for PlotConfig {
1005 fn default() -> Self {
1006 Self {
1007 title: "Training Metrics".to_string(),
1008 x_axis: AxisConfig::default(),
1009 y_axis: AxisConfig::default(),
1010 series: Vec::new(),
1011 plot_type: PlotType::Line,
1012 update_mode: UpdateMode::Append,
1013 }
1014 }
1015}
1016
1017impl Default for AxisConfig {
1018 fn default() -> Self {
1019 Self {
1020 label: "".to_string(),
1021 scale: AxisScale::Linear,
1022 range: None,
1023 show_grid: true,
1024 ticks: TickConfig::default(),
1025 }
1026 }
1027}
1028
1029impl Default for TickConfig {
1030 fn default() -> Self {
1031 Self {
1032 interval: None,
1033 format: TickFormat::Auto,
1034 show_labels: true,
1035 rotation: 0.0,
1036 }
1037 }
1038}
1039
1040impl Default for SeriesConfig {
1041 fn default() -> Self {
1042 Self {
1043 name: "Series".to_string(),
1044 data_source: "".to_string(),
1045 style: LineStyleConfig::default(),
1046 markers: MarkerConfig::default(),
1047 color: "#1f77b4".to_string(), opacity: 1.0,
1049 }
1050 }
1051}
1052
1053impl Default for LineStyleConfig {
1054 fn default() -> Self {
1055 Self {
1056 style: LineStyle::Solid,
1057 width: 2.0,
1058 smoothing: false,
1059 smoothing_window: 5,
1060 }
1061 }
1062}
1063
1064impl Default for MarkerConfig {
1065 fn default() -> Self {
1066 Self {
1067 show: false,
1068 shape: MarkerShape::Circle,
1069 size: 6.0,
1070 fill_color: "#1f77b4".to_string(),
1071 border_color: "#1f77b4".to_string(),
1072 }
1073 }
1074}
1075
1076impl Default for SystemMetrics {
1077 fn default() -> Self {
1078 Self {
1079 memory_usage_mb: 0.0,
1080 gpu_memory_mb: None,
1081 cpu_utilization: 0.0,
1082 gpu_utilization: None,
1083 step_duration_ms: 0.0,
1084 samples_per_second: 0.0,
1085 }
1086 }
1087}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal metrics record for `epoch` with one loss and one
    /// accuracy series.
    fn sample_metrics(epoch: usize) -> TrainingMetrics<f32> {
        TrainingMetrics {
            epoch,
            step: epoch * 100,
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            losses: HashMap::from([("train_loss".to_string(), 0.5)]),
            accuracies: HashMap::from([("train_acc".to_string(), 0.8)]),
            learning_rate: 0.001,
            custom_metrics: HashMap::new(),
            system_metrics: SystemMetrics::default(),
        }
    }

    #[test]
    fn test_training_visualizer_creation() {
        let config = VisualizationConfig::default();
        let visualizer = TrainingVisualizer::<f32>::new(config);
        assert!(visualizer.metrics_history.is_empty());
        assert!(visualizer.active_plots.is_empty());
    }

    #[test]
    fn test_add_metrics() {
        let config = VisualizationConfig::default();
        let mut visualizer = TrainingVisualizer::<f32>::new(config);
        visualizer.add_metrics(sample_metrics(1));
        assert_eq!(visualizer.metrics_history.len(), 1);
        assert_eq!(visualizer.get_metrics_history()[0].epoch, 1);
    }

    #[test]
    fn test_plot_config_defaults() {
        let config = PlotConfig::default();
        assert_eq!(config.title, "Training Metrics");
        assert_eq!(config.plot_type, PlotType::Line);
        assert_eq!(config.update_mode, UpdateMode::Append);
    }

    #[test]
    fn test_axis_scale_variants() {
        assert_eq!(AxisScale::Linear, AxisScale::Linear);
        assert_eq!(AxisScale::Log, AxisScale::Log);
        assert_eq!(AxisScale::Sqrt, AxisScale::Sqrt);
        let custom = AxisScale::Custom("symlog".to_string());
        match custom {
            AxisScale::Custom(name) => assert_eq!(name, "symlog"),
            _ => panic!("Expected custom scale"),
        }
    }

    #[test]
    fn test_markershapes() {
        let shapes = [
            MarkerShape::Circle,
            MarkerShape::Square,
            MarkerShape::Triangle,
            MarkerShape::Diamond,
            MarkerShape::Cross,
            MarkerShape::Plus,
        ];
        assert_eq!(shapes.len(), 6);
        assert_eq!(shapes[0], MarkerShape::Circle);
    }

    #[test]
    fn test_plot_types() {
        let types = [
            PlotType::Line,
            PlotType::Scatter,
            PlotType::Bar,
            PlotType::Area,
            PlotType::Histogram,
            PlotType::Box,
            PlotType::Heatmap,
        ];
        assert_eq!(types.len(), 7);
        assert_eq!(types[0], PlotType::Line);
    }

    #[test]
    fn test_update_modes() {
        let append = UpdateMode::Append;
        let replace = UpdateMode::Replace;
        let rolling = UpdateMode::Rolling(100);
        assert_eq!(append, UpdateMode::Append);
        assert_eq!(replace, UpdateMode::Replace);
        match rolling {
            UpdateMode::Rolling(size) => assert_eq!(size, 100),
            _ => panic!("Expected rolling update mode"),
        }
    }

    #[test]
    fn test_clear_history() {
        let config = VisualizationConfig::default();
        let mut visualizer = TrainingVisualizer::<f32>::new(config);
        // Record something first so that clearing is actually observable.
        visualizer.add_metrics(sample_metrics(1));
        assert!(!visualizer.metrics_history.is_empty());
        visualizer.clear_history();
        assert!(visualizer.metrics_history.is_empty());
        assert!(visualizer.get_metrics_history().is_empty());
    }

    #[test]
    fn test_plot_management() {
        let config = VisualizationConfig::default();
        let mut visualizer = TrainingVisualizer::<f32>::new(config);
        let plot_config = PlotConfig::default();
        visualizer.add_plot("test_plot".to_string(), plot_config);
        assert!(visualizer.active_plots.contains_key("test_plot"));
        let removed = visualizer.remove_plot("test_plot");
        assert!(removed.is_some());
        assert!(!visualizer.active_plots.contains_key("test_plot"));
        // Removing a name that is no longer registered yields `None`.
        assert!(visualizer.remove_plot("test_plot").is_none());
    }
}