1use super::config::{DownsamplingStrategy, VisualizationConfig};
7use crate::error::{NeuralError, Result};
8
9use num_traits::Float;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::fs;
14use std::path::PathBuf;
15
/// Records training metric snapshots and renders them to standalone HTML plot files.
///
/// Holds the recorded [`TrainingMetrics`] history, the [`VisualizationConfig`] controlling the
/// output directory and downsampling limits, and user-registered [`PlotConfig`]s keyed by name.
// NOTE(review): every field below is used by this type's methods in this file, so this `allow`
// may be unnecessary — confirm before removing.
#[allow(dead_code)]
pub struct TrainingVisualizer<F: Float + Debug> {
    /// Recorded snapshots in insertion order; thinned when downsampling is enabled.
    metrics_history: Vec<TrainingMetrics<F>>,
    /// Output directory and performance/downsampling settings.
    config: VisualizationConfig,
    /// Plots registered through `add_plot`, keyed by name.
    active_plots: HashMap<String, PlotConfig>,
}
26
/// One snapshot of training metrics, as passed to `TrainingVisualizer::add_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics<F: Float + Debug> {
    /// Epoch at which the snapshot was taken.
    pub epoch: usize,
    /// Training step at which the snapshot was taken.
    pub step: usize,
    /// Timestamp as free-form text. No format is enforced; this file's tests use
    /// RFC 3339 (e.g. `2024-01-01T00:00:00Z`).
    pub timestamp: String,
    /// Loss values keyed by name (e.g. `"train_loss"`).
    pub losses: HashMap<String, F>,
    /// Accuracy values keyed by name (e.g. `"train_acc"`).
    pub accuracies: HashMap<String, F>,
    /// Learning rate recorded with this snapshot.
    pub learning_rate: F,
    /// Additional user-defined metrics keyed by name.
    pub custom_metrics: HashMap<String, F>,
    /// Resource-usage measurements recorded with this snapshot.
    pub system_metrics: SystemMetrics,
}
47
/// Resource-usage measurements attached to a [`TrainingMetrics`] snapshot.
///
/// Units are only suggested by the field-name suffixes; nothing in this file enforces them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemMetrics {
    /// Memory in use, presumably in megabytes.
    pub memory_usage_mb: f64,
    /// GPU memory in use (presumably megabytes); presumably `None` when no GPU reading exists.
    pub gpu_memory_mb: Option<f64>,
    /// CPU utilization. NOTE(review): fraction vs. percent is not defined here — confirm.
    pub cpu_utilization: f64,
    /// GPU utilization, if available; same scale caveat as `cpu_utilization`.
    pub gpu_utilization: Option<f64>,
    /// Duration of the step, presumably in milliseconds.
    pub step_duration_ms: f64,
    /// Throughput in samples per second.
    pub samples_per_second: f64,
}
64
/// Configuration for a single plot, registered by name via `TrainingVisualizer::add_plot`.
///
/// NOTE(review): registered configs are stored, but no rendering code in this file reads them yet.
#[derive(Debug, Clone, Serialize)]
pub struct PlotConfig {
    /// Plot title.
    pub title: String,
    /// Horizontal axis settings.
    pub x_axis: AxisConfig,
    /// Vertical axis settings.
    pub y_axis: AxisConfig,
    /// Data series drawn on the plot.
    pub series: Vec<SeriesConfig>,
    /// Kind of chart to draw.
    pub plot_type: PlotType,
    /// How the plot incorporates new data.
    pub update_mode: UpdateMode,
}
81
/// Configuration for one plot axis.
#[derive(Debug, Clone, Serialize)]
pub struct AxisConfig {
    /// Axis title text.
    pub label: String,
    /// Scale used to map values onto the axis.
    pub scale: AxisScale,
    /// Fixed axis range, presumably `(min, max)`; `None` leaves the range unset.
    pub range: Option<(f64, f64)>,
    /// Whether grid lines are drawn for this axis.
    pub show_grid: bool,
    /// Tick placement and labeling settings.
    pub ticks: TickConfig,
}
96
/// Scale applied to an axis.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum AxisScale {
    /// Linear scale.
    Linear,
    /// Logarithmic scale.
    Log,
    /// Square-root scale.
    Sqrt,
    /// A scale identified by name (this file's tests use `"symlog"`); not interpreted here.
    Custom(String),
}
109
/// Tick mark settings for an axis.
#[derive(Debug, Clone, Serialize)]
pub struct TickConfig {
    /// Fixed spacing between ticks; `None` leaves spacing unset (presumably automatic).
    pub interval: Option<f64>,
    /// How tick labels are formatted.
    pub format: TickFormat,
    /// Whether tick labels are shown.
    pub show_labels: bool,
    /// Tick label rotation, presumably in degrees.
    pub rotation: f32,
}
122
/// Formatting applied to tick labels.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum TickFormat {
    /// Presumably lets the renderer choose a format.
    Auto,
    /// Fixed-point notation; the value is presumably the number of decimal places.
    Fixed(u32),
    /// Scientific notation.
    Scientific,
    /// Percentage notation.
    Percentage,
    /// A custom format string; its syntax is not defined in this file.
    Custom(String),
}
137
/// One data series drawn within a [`PlotConfig`].
#[derive(Debug, Clone, Serialize)]
pub struct SeriesConfig {
    /// Display name of the series.
    pub name: String,
    /// Identifier of the data to plot — presumably a metric key; not resolved in this file.
    pub data_source: String,
    /// Line styling.
    pub style: LineStyleConfig,
    /// Point marker styling.
    pub markers: MarkerConfig,
    /// Series color as a string such as `"#1f77b4"`.
    pub color: String,
    /// Series opacity, presumably in `0.0..=1.0`.
    pub opacity: f32,
}
154
/// Line appearance for a series.
#[derive(Debug, Clone, Serialize)]
pub struct LineStyleConfig {
    /// Dash pattern of the line.
    pub style: LineStyle,
    /// Line width (units not defined in this file).
    pub width: f32,
    /// Whether the series should be smoothed before drawing.
    pub smoothing: bool,
    /// Smoothing window size, presumably in points; presumably only relevant when `smoothing` is set.
    pub smoothing_window: usize,
}
167
/// Dash pattern used to draw a series line.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum LineStyle {
    /// Continuous line.
    Solid,
    /// Dashed line.
    Dashed,
    /// Dotted line.
    Dotted,
    /// Alternating dashes and dots.
    DashDot,
}
180
/// Point marker appearance for a series.
#[derive(Debug, Clone, Serialize)]
pub struct MarkerConfig {
    /// Whether markers are drawn.
    pub show: bool,
    /// Marker shape.
    pub shape: MarkerShape,
    /// Marker size (units not defined in this file).
    pub size: f32,
    /// Fill color as a string such as `"#1f77b4"`.
    pub fill_color: String,
    /// Border color as a string such as `"#1f77b4"`.
    pub border_color: String,
}
195
/// Shape used for series point markers.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum MarkerShape {
    /// Circle marker.
    Circle,
    /// Square marker.
    Square,
    /// Triangle marker.
    Triangle,
    /// Diamond marker.
    Diamond,
    /// Cross (x) marker.
    Cross,
    /// Plus (+) marker.
    Plus,
}
212
/// Kind of chart a [`PlotConfig`] describes.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum PlotType {
    /// Line chart.
    Line,
    /// Scatter plot.
    Scatter,
    /// Bar chart.
    Bar,
    /// Filled area chart.
    Area,
    /// Histogram.
    Histogram,
    /// Box plot.
    Box,
    /// Heatmap.
    Heatmap,
}
231
/// How a plot presumably incorporates new data; no code in this file interprets it yet.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum UpdateMode {
    /// Add new points to the existing data.
    Append,
    /// Replace the existing data with the new data.
    Replace,
    /// Keep a rolling window of the given size (presumably a point count).
    Rolling(usize),
}
242
243impl<F: Float + Debug + 'static + num_traits::FromPrimitive + Send + Sync> TrainingVisualizer<F> {
246 pub fn new(config: VisualizationConfig) -> Self {
248 Self {
249 metrics_history: Vec::new(),
250 config,
251 active_plots: HashMap::new(),
252 }
253 }
254
255 pub fn add_metrics(&mut self, metrics: TrainingMetrics<F>) {
257 self.metrics_history.push(metrics);
258
259 if self.metrics_history.len() > self.config.performance.max_points_per_plot
261 && self.config.performance.enable_downsampling
262 {
263 self.downsample_metrics();
264 }
265 }
266
267 pub fn visualize_training_curves(&self) -> Result<Vec<PathBuf>> {
269 let mut output_files = Vec::new();
270
271 if let Some(loss_plot) = self.create_loss_plot()? {
273 let loss_path = self.config.output_dir.join("training_loss.html");
274 fs::write(&loss_path, loss_plot)
275 .map_err(|e| NeuralError::IOError(format!("Failed to write loss plot: {}", e)))?;
276 output_files.push(loss_path);
277 }
278
279 if let Some(accuracy_plot) = self.create_accuracy_plot()? {
281 let accuracy_path = self.config.output_dir.join("training_accuracy.html");
282 fs::write(&accuracy_path, accuracy_plot).map_err(|e| {
283 NeuralError::IOError(format!("Failed to write accuracy plot: {}", e))
284 })?;
285 output_files.push(accuracy_path);
286 }
287
288 if let Some(lr_plot) = self.create_learning_rate_plot()? {
290 let lr_path = self.config.output_dir.join("learning_rate.html");
291 fs::write(&lr_path, lr_plot).map_err(|e| {
292 NeuralError::IOError(format!("Failed to write learning rate plot: {}", e))
293 })?;
294 output_files.push(lr_path);
295 }
296
297 if let Some(system_plot) = self.create_system_metrics_plot()? {
299 let system_path = self.config.output_dir.join("system_metrics.html");
300 fs::write(&system_path, system_plot).map_err(|e| {
301 NeuralError::IOError(format!("Failed to write system metrics plot: {}", e))
302 })?;
303 output_files.push(system_path);
304 }
305
306 Ok(output_files)
307 }
308
309 pub fn get_metrics_history(&self) -> &[TrainingMetrics<F>] {
311 &self.metrics_history
312 }
313
314 pub fn clear_history(&mut self) {
316 self.metrics_history.clear();
317 }
318
319 pub fn add_plot(&mut self, name: String, config: PlotConfig) {
321 self.active_plots.insert(name, config);
322 }
323
324 pub fn remove_plot(&mut self, name: &str) -> Option<PlotConfig> {
326 self.active_plots.remove(name)
327 }
328
329 pub fn update_config(&mut self, config: VisualizationConfig) {
331 self.config = config;
332 }
333
334 fn downsample_metrics(&mut self) {
335 match self.config.performance.downsampling_strategy {
337 DownsamplingStrategy::Uniform => {
338 let step = self.metrics_history.len() / self.config.performance.max_points_per_plot;
340 if step > 1 {
341 let mut downsampled = Vec::new();
342 for (i, metric) in self.metrics_history.iter().enumerate() {
343 if i % step == 0 {
344 downsampled.push(metric.clone());
345 }
346 }
347 self.metrics_history = downsampled;
348 }
349 }
350 _ => {
351 if self.metrics_history.len() > self.config.performance.max_points_per_plot {
353 let start =
354 self.metrics_history.len() - self.config.performance.max_points_per_plot;
355 self.metrics_history.drain(0..start);
356 }
357 }
358 }
359 }
360
361 fn create_loss_plot(&self) -> Result<Option<String>> {
362 if self.metrics_history.is_empty() {
363 return Ok(None);
364 }
365
366 let plot_html = r#"
369<!DOCTYPE html>
370<html>
371<head>
372 <title>Training Loss</title>
373 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
374</head>
375<body>
376 <div id="lossPlot" style="width:100%;height:500px;"></div>
377 <script>
378 // TODO: Implement actual loss curve plotting
379 var trace = {
380 x: [1, 2, 3, 4],
381 y: [0.8, 0.6, 0.4, 0.3],
382 type: 'scatter',
383 name: 'Training Loss'
384 };
385
386 var layout = {
387 title: 'Training Loss Over Time',
388 xaxis: { title: 'Epoch' },
389 yaxis: { title: 'Loss' }
390 };
391
392 Plotly.newPlot('lossPlot', [trace], layout);
393 </script>
394</body>
395</html>"#;
396
397 Ok(Some(plot_html.to_string()))
398 }
399
400 fn create_accuracy_plot(&self) -> Result<Option<String>> {
401 if self.metrics_history.is_empty() {
402 return Ok(None);
403 }
404
405 Ok(None)
407 }
408
409 fn create_learning_rate_plot(&self) -> Result<Option<String>> {
410 if self.metrics_history.is_empty() {
411 return Ok(None);
412 }
413
414 Ok(None)
416 }
417
418 fn create_system_metrics_plot(&self) -> Result<Option<String>> {
419 if self.metrics_history.is_empty() {
420 return Ok(None);
421 }
422
423 Ok(None)
425 }
426}
427
428impl Default for PlotConfig {
431 fn default() -> Self {
432 Self {
433 title: "Training Metrics".to_string(),
434 x_axis: AxisConfig::default(),
435 y_axis: AxisConfig::default(),
436 series: Vec::new(),
437 plot_type: PlotType::Line,
438 update_mode: UpdateMode::Append,
439 }
440 }
441}
442
443impl Default for AxisConfig {
444 fn default() -> Self {
445 Self {
446 label: "".to_string(),
447 scale: AxisScale::Linear,
448 range: None,
449 show_grid: true,
450 ticks: TickConfig::default(),
451 }
452 }
453}
454
455impl Default for TickConfig {
456 fn default() -> Self {
457 Self {
458 interval: None,
459 format: TickFormat::Auto,
460 show_labels: true,
461 rotation: 0.0,
462 }
463 }
464}
465
466impl Default for SeriesConfig {
467 fn default() -> Self {
468 Self {
469 name: "Series".to_string(),
470 data_source: "".to_string(),
471 style: LineStyleConfig::default(),
472 markers: MarkerConfig::default(),
473 color: "#1f77b4".to_string(), opacity: 1.0,
475 }
476 }
477}
478
479impl Default for LineStyleConfig {
480 fn default() -> Self {
481 Self {
482 style: LineStyle::Solid,
483 width: 2.0,
484 smoothing: false,
485 smoothing_window: 5,
486 }
487 }
488}
489
490impl Default for MarkerConfig {
491 fn default() -> Self {
492 Self {
493 show: false,
494 shape: MarkerShape::Circle,
495 size: 6.0,
496 fill_color: "#1f77b4".to_string(),
497 border_color: "#1f77b4".to_string(),
498 }
499 }
500}
501
502impl Default for SystemMetrics {
503 fn default() -> Self {
504 Self {
505 memory_usage_mb: 0.0,
506 gpu_memory_mb: None,
507 cpu_utilization: 0.0,
508 gpu_utilization: None,
509 step_duration_ms: 0.0,
510 samples_per_second: 0.0,
511 }
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Visualizer over `f32` metrics using the default configuration.
    fn default_visualizer() -> TrainingVisualizer<f32> {
        TrainingVisualizer::<f32>::new(VisualizationConfig::default())
    }

    /// The single snapshot recorded by the history tests.
    fn sample_metrics() -> TrainingMetrics<f32> {
        TrainingMetrics {
            epoch: 1,
            step: 100,
            timestamp: String::from("2024-01-01T00:00:00Z"),
            losses: HashMap::from([(String::from("train_loss"), 0.5)]),
            accuracies: HashMap::from([(String::from("train_acc"), 0.8)]),
            learning_rate: 0.001,
            custom_metrics: HashMap::new(),
            system_metrics: SystemMetrics::default(),
        }
    }

    #[test]
    fn test_training_visualizer_creation() {
        let visualizer = default_visualizer();

        assert!(visualizer.metrics_history.is_empty());
        assert!(visualizer.active_plots.is_empty());
    }

    #[test]
    fn test_add_metrics() {
        let mut visualizer = default_visualizer();

        visualizer.add_metrics(sample_metrics());

        assert_eq!(visualizer.metrics_history.len(), 1);
    }

    #[test]
    fn test_plot_config_defaults() {
        let PlotConfig {
            title,
            plot_type,
            update_mode,
            ..
        } = PlotConfig::default();

        assert_eq!(title, "Training Metrics");
        assert_eq!(plot_type, PlotType::Line);
        assert_eq!(update_mode, UpdateMode::Append);
    }

    #[test]
    fn test_axis_scale_variants() {
        for scale in [AxisScale::Linear, AxisScale::Log, AxisScale::Sqrt] {
            assert_eq!(scale.clone(), scale);
        }

        if let AxisScale::Custom(name) = AxisScale::Custom("symlog".to_string()) {
            assert_eq!(name, "symlog");
        } else {
            panic!("Expected custom scale");
        }
    }

    #[test]
    fn test_marker_shapes() {
        let shapes = [
            MarkerShape::Circle,
            MarkerShape::Square,
            MarkerShape::Triangle,
            MarkerShape::Diamond,
            MarkerShape::Cross,
            MarkerShape::Plus,
        ];

        assert_eq!(shapes.len(), 6);
        assert_eq!(shapes.first(), Some(&MarkerShape::Circle));
    }

    #[test]
    fn test_plot_types() {
        let types = [
            PlotType::Line,
            PlotType::Scatter,
            PlotType::Bar,
            PlotType::Area,
            PlotType::Histogram,
            PlotType::Box,
            PlotType::Heatmap,
        ];

        assert_eq!(types.len(), 7);
        assert_eq!(types.first(), Some(&PlotType::Line));
    }

    #[test]
    fn test_update_modes() {
        assert_eq!(UpdateMode::Append, UpdateMode::Append);
        assert_eq!(UpdateMode::Replace, UpdateMode::Replace);

        if let UpdateMode::Rolling(size) = UpdateMode::Rolling(100) {
            assert_eq!(size, 100);
        } else {
            panic!("Expected rolling update mode");
        }
    }

    #[test]
    fn test_clear_history() {
        let mut visualizer = default_visualizer();

        visualizer.add_metrics(sample_metrics());
        assert_eq!(visualizer.metrics_history.len(), 1);

        visualizer.clear_history();
        assert!(visualizer.metrics_history.is_empty());
    }

    #[test]
    fn test_plot_management() {
        let mut visualizer = default_visualizer();
        let plot_name = "test_plot";

        visualizer.add_plot(plot_name.to_string(), PlotConfig::default());
        assert!(visualizer.active_plots.contains_key(plot_name));

        assert!(visualizer.remove_plot(plot_name).is_some());
        assert!(!visualizer.active_plots.contains_key(plot_name));
    }
}