1use crate::error::{ScirsError, ScirsResult};
8use scirs2_core::error_context;
9use scirs2_core::ndarray::{Array1, ArrayView1}; use std::collections::HashMap;
11use std::fs::File;
12use std::io::Write;
13use std::path::Path;
14
15#[derive(Debug, Clone)]
17pub struct OptimizationTrajectory {
18 pub parameters: Vec<Array1<f64>>,
20 pub function_values: Vec<f64>,
22 pub gradient_norms: Vec<f64>,
24 pub step_sizes: Vec<f64>,
26 pub custom_metrics: HashMap<String, Vec<f64>>,
28 pub nit: Vec<usize>,
30 pub times: Vec<f64>,
32}
33
34impl OptimizationTrajectory {
35 pub fn new() -> Self {
37 Self {
38 parameters: Vec::new(),
39 function_values: Vec::new(),
40 gradient_norms: Vec::new(),
41 step_sizes: Vec::new(),
42 custom_metrics: HashMap::new(),
43 nit: Vec::new(),
44 times: Vec::new(),
45 }
46 }
47
48 pub fn add_point(
50 &mut self,
51 iteration: usize,
52 params: &ArrayView1<f64>,
53 function_value: f64,
54 time: f64,
55 ) {
56 self.nit.push(iteration);
57 self.parameters.push(params.to_owned());
58 self.function_values.push(function_value);
59 self.times.push(time);
60 }
61
62 pub fn add_gradient_norm(&mut self, grad_norm: f64) {
64 self.gradient_norms.push(grad_norm);
65 }
66
67 pub fn add_step_size(&mut self, step_size: f64) {
69 self.step_sizes.push(step_size);
70 }
71
72 pub fn add_custom_metric(&mut self, name: &str, value: f64) {
74 self.custom_metrics
75 .entry(name.to_string())
76 .or_default()
77 .push(value);
78 }
79
80 pub fn len(&self) -> usize {
82 self.nit.len()
83 }
84
85 pub fn is_empty(&self) -> bool {
87 self.nit.is_empty()
88 }
89
90 pub fn final_parameters(&self) -> Option<&Array1<f64>> {
92 self.parameters.last()
93 }
94
95 pub fn final_function_value(&self) -> Option<f64> {
97 self.function_values.last().copied()
98 }
99
100 pub fn convergence_rate(&self) -> Option<f64> {
102 if self.function_values.len() < 3 {
103 return None;
104 }
105
106 let n = self.function_values.len();
107 let mut rates = Vec::new();
108
109 for i in 1..(n - 1) {
110 let f_current = self.function_values[i];
111 let f_next = self.function_values[i + 1];
112 let f_prev = self.function_values[i - 1];
113
114 if (f_current - f_next).abs() > 1e-14 && (f_prev - f_current).abs() > 1e-14 {
115 let rate = (f_current - f_next).abs() / (f_prev - f_current).abs();
116 if rate.is_finite() && rate > 0.0 {
117 rates.push(rate);
118 }
119 }
120 }
121
122 if rates.is_empty() {
123 None
124 } else {
125 Some(rates.iter().sum::<f64>() / rates.len() as f64)
126 }
127 }
128}
129
130impl Default for OptimizationTrajectory {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
/// Rendering options shared by every plot produced by [`OptimizationVisualizer`].
#[derive(Debug, Clone)]
pub struct VisualizationConfig {
    /// Output format used by the plotting methods.
    pub format: OutputFormat,
    /// Canvas width in pixels.
    pub width: u32,
    /// Canvas height in pixels.
    pub height: u32,
    /// Optional plot title; a format-specific default is used when `None`.
    pub title: Option<String>,
    /// Whether SVG plots draw a background grid.
    pub show_grid: bool,
    /// Whether the function-value axis uses a logarithmic scale.
    pub log_scale_y: bool,
    /// Colour scheme selection.
    /// NOTE(review): not read by the renderers in this file.
    pub color_scheme: ColorScheme,
    /// Whether HTML plots show a legend.
    pub show_legend: bool,
    /// Optional custom style text.
    /// NOTE(review): not read by the renderers in this file.
    pub custom_style: Option<String>,
}

impl Default for VisualizationConfig {
    /// 800x600 SVG output with grid and legend, linear y axis, no title.
    fn default() -> Self {
        Self {
            format: OutputFormat::Svg,
            width: 800,
            height: 600,
            title: None,
            show_grid: true,
            log_scale_y: false,
            color_scheme: ColorScheme::Default,
            show_legend: true,
            custom_style: None,
        }
    }
}

/// File format produced by the plotting methods.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OutputFormat {
    /// Standalone SVG image.
    Svg,
    /// PNG image (not implemented yet; plotting returns an error).
    Png,
    /// HTML page rendered with Plotly.
    Html,
    /// Raw CSV data.
    Data,
}

/// Named colour schemes selectable through [`VisualizationConfig::color_scheme`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ColorScheme {
    Default,
    Viridis,
    Plasma,
    Scientific,
    Monochrome,
}

/// Renders [`OptimizationTrajectory`] data as SVG, HTML or CSV files.
///
/// Derives `Debug` and `Clone` so it can be logged and duplicated like the
/// configuration it wraps.
#[derive(Debug, Clone)]
pub struct OptimizationVisualizer {
    /// Rendering options applied to every output.
    config: VisualizationConfig,
}
198
199impl OptimizationVisualizer {
200 pub fn new() -> Self {
202 Self {
203 config: VisualizationConfig::default(),
204 }
205 }
206
207 pub fn with_config(config: VisualizationConfig) -> Self {
209 Self { config }
210 }
211
212 pub fn plot_convergence(
214 &self,
215 trajectory: &OptimizationTrajectory,
216 output_path: &Path,
217 ) -> ScirsResult<()> {
218 if trajectory.is_empty() {
219 return Err(ScirsError::InvalidInput(error_context!("Empty trajectory")));
220 }
221
222 match self.config.format {
223 OutputFormat::Svg => self.plot_convergence_svg(trajectory, output_path),
224 OutputFormat::Html => self.plot_convergence_html(trajectory, output_path),
225 OutputFormat::Data => self.export_convergence_data(trajectory, output_path),
226 _ => Err(ScirsError::NotImplementedError(error_context!(
227 "PNG output not yet implemented"
228 ))),
229 }
230 }
231
232 pub fn plot_parameter_trajectory(
234 &self,
235 trajectory: &OptimizationTrajectory,
236 output_path: &Path,
237 ) -> ScirsResult<()> {
238 if trajectory.is_empty() {
239 return Err(ScirsError::InvalidInput(error_context!("Empty trajectory")));
240 }
241
242 if trajectory.parameters[0].len() != 2 {
243 return Err(ScirsError::InvalidInput(error_context!(
244 "Parameter trajectory visualization only supports 2D problems"
245 )));
246 }
247
248 match self.config.format {
249 OutputFormat::Svg => self.plot_trajectory_svg(trajectory, output_path),
250 OutputFormat::Html => self.plot_trajectory_html(trajectory, output_path),
251 OutputFormat::Data => self.export_trajectory_data(trajectory, output_path),
252 _ => Err(ScirsError::NotImplementedError(error_context!(
253 "PNG output not yet implemented"
254 ))),
255 }
256 }
257
258 pub fn create_optimization_report(
260 &self,
261 trajectory: &OptimizationTrajectory,
262 output_dir: &Path,
263 ) -> ScirsResult<()> {
264 std::fs::create_dir_all(output_dir)?;
265
266 let convergence_path = output_dir.join("convergence.svg");
268 self.plot_convergence(trajectory, &convergence_path)?;
269
270 if !trajectory.parameters.is_empty() && trajectory.parameters[0].len() == 2 {
272 let trajectory_path = output_dir.join("trajectory.svg");
273 self.plot_parameter_trajectory(trajectory, &trajectory_path)?;
274 }
275
276 let summary_path = output_dir.join("summary.html");
278 self.generate_summary_report(trajectory, &summary_path)?;
279
280 let data_path = output_dir.join("data.csv");
282 self.export_convergence_data(trajectory, &data_path)?;
283
284 Ok(())
285 }
286
287 fn generate_summary_report(
289 &self,
290 trajectory: &OptimizationTrajectory,
291 output_path: &Path,
292 ) -> ScirsResult<()> {
293 let mut file = File::create(output_path)?;
294
295 let html_content = format!(
296 r#"<!DOCTYPE html>
297<html>
298<head>
299 <title>Optimization Summary</title>
300 <style>
301 body {{ font-family: Arial, sans-serif; margin: 20px; }}
302 .metric {{ margin: 10px 0; }}
303 .value {{ font-weight: bold; color: #2E86AB; }}
304 table {{ border-collapse: collapse; width: 100%; }}
305 th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
306 th {{ background-color: #f2f2f2; }}
307 </style>
308</head>
309<body>
310 <h1>Optimization Summary Report</h1>
311
312 <h2>Basic Statistics</h2>
313 <div class="metric">Total Iterations: <span class="value">{}</span></div>
314 <div class="metric">Final Function Value: <span class="value">{:.6e}</span></div>
315 <div class="metric">Initial Function Value: <span class="value">{:.6e}</span></div>
316 <div class="metric">Function Improvement: <span class="value">{:.6e}</span></div>
317 <div class="metric">Total Runtime: <span class="value">{:.3}s</span></div>
318 {}
319
320 <h2>Convergence Analysis</h2>
321 <table>
322 <tr><th>Metric</th><th>Value</th></tr>
323 <tr><td>Convergence Rate</td><td>{}</td></tr>
324 <tr><td>Average Iteration Time</td><td>{:.6}s</td></tr>
325 <tr><td>Function Evaluations per Second</td><td>{:.2}</td></tr>
326 </table>
327
328 {}
329</body>
330</html>"#,
331 trajectory.len(),
332 trajectory.final_function_value().unwrap_or(0.0),
333 trajectory.function_values.first().cloned().unwrap_or(0.0),
334 trajectory.function_values.first().cloned().unwrap_or(0.0)
335 - trajectory.final_function_value().unwrap_or(0.0),
336 trajectory.times.last().cloned().unwrap_or(0.0),
337 if !trajectory.gradient_norms.is_empty() {
338 format!("<div class=\"metric\">Final Gradient Norm: <span class=\"value\">{:.6e}</span></div>",
339 trajectory.gradient_norms.last().cloned().unwrap_or(0.0))
340 } else {
341 String::new()
342 },
343 trajectory
344 .convergence_rate()
345 .map(|r| format!("{:.6}", r))
346 .unwrap_or_else(|| "N/A".to_string()),
347 if trajectory.len() > 1 && !trajectory.times.is_empty() {
348 trajectory.times.last().cloned().unwrap_or(0.0) / trajectory.len() as f64
349 } else {
350 0.0
351 },
352 if !trajectory.times.is_empty() && trajectory.times.last().cloned().unwrap_or(0.0) > 0.0
353 {
354 trajectory.len() as f64 / trajectory.times.last().cloned().unwrap_or(1.0)
355 } else {
356 0.0
357 },
358 self.generate_custom_metrics_table(trajectory)
359 );
360
361 file.write_all(html_content.as_bytes())?;
362 Ok(())
363 }
364
365 fn generate_custom_metrics_table(&self, trajectory: &OptimizationTrajectory) -> String {
366 if trajectory.custom_metrics.is_empty() {
367 return String::new();
368 }
369
370 let mut table = String::from("<h2>Custom Metrics</h2>\n<table>\n<tr><th>Metric</th><th>Final Value</th><th>Min</th><th>Max</th><th>Mean</th></tr>\n");
371
372 for (name, values) in &trajectory.custom_metrics {
373 if let Some(final_val) = values.last() {
374 let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
375 let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
376 let mean_val = values.iter().sum::<f64>() / values.len() as f64;
377
378 table.push_str(&format!(
379 "<tr><td>{}</td><td>{:.6e}</td><td>{:.6e}</td><td>{:.6e}</td><td>{:.6e}</td></tr>\n",
380 name, final_val, min_val, max_val, mean_val
381 ));
382 }
383 }
384 table.push_str("</table>\n");
385 table
386 }
387
388 fn plot_convergence_svg(
389 &self,
390 trajectory: &OptimizationTrajectory,
391 output_path: &Path,
392 ) -> ScirsResult<()> {
393 let mut file = File::create(output_path)?;
394
395 let width = self.config.width;
396 let height = self.config.height;
397 let margin = 60;
398 let plot_width = width - 2 * margin;
399 let plot_height = height - 2 * margin;
400
401 let min_y = if self.config.log_scale_y {
402 trajectory
403 .function_values
404 .iter()
405 .filter(|&&v| v > 0.0)
406 .cloned()
407 .fold(f64::INFINITY, f64::min)
408 .ln()
409 } else {
410 trajectory
411 .function_values
412 .iter()
413 .cloned()
414 .fold(f64::INFINITY, f64::min)
415 };
416
417 let max_y = if self.config.log_scale_y {
418 trajectory
419 .function_values
420 .iter()
421 .filter(|&&v| v > 0.0)
422 .cloned()
423 .fold(f64::NEG_INFINITY, f64::max)
424 .ln()
425 } else {
426 trajectory
427 .function_values
428 .iter()
429 .cloned()
430 .fold(f64::NEG_INFINITY, f64::max)
431 };
432
433 let max_x = trajectory.nit.len() as f64;
434
435 let mut svg_content = format!(
436 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
437 <defs>
438 <style>
439 .axis {{ stroke: #333; stroke-width: 1; }}
440 .grid {{ stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }}
441 .line {{ fill: none; stroke: #2E86AB; stroke-width: 2; }}
442 .text {{ font-family: Arial, sans-serif; font-size: 12px; fill: #333; }}
443 .title {{ font-family: Arial, sans-serif; font-size: 16px; fill: #333; font-weight: bold; }}
444 </style>
445 </defs>
446"#,
447 width, height
448 );
449
450 if self.config.show_grid {
452 for i in 0..=10 {
453 let x = margin as f64 + (i as f64 / 10.0) * plot_width as f64;
454 svg_content.push_str(&format!(
455 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
456"#,
457 x,
458 margin,
459 x,
460 height - margin
461 ));
462 }
463
464 for i in 0..=10 {
465 let y = margin as f64 + (i as f64 / 10.0) * plot_height as f64;
466 svg_content.push_str(&format!(
467 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
468"#,
469 margin,
470 y,
471 width - margin,
472 y
473 ));
474 }
475 }
476
477 svg_content.push_str(&format!(
479 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
480 <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
481"#,
482 margin,
483 height - margin,
484 width - margin,
485 height - margin, margin,
487 margin,
488 margin,
489 height - margin ));
491
492 svg_content.push_str(" <polyline points=\"");
494 for (i, &f_val) in trajectory.function_values.iter().enumerate() {
495 let x = margin as f64 + (i as f64 / max_x) * plot_width as f64;
496 let y_val = if self.config.log_scale_y && f_val > 0.0 {
497 f_val.ln()
498 } else {
499 f_val
500 };
501 let y = height as f64
502 - margin as f64
503 - ((y_val - min_y) / (max_y - min_y)) * plot_height as f64;
504 svg_content.push_str(&format!("{},{} ", x, y));
505 }
506 svg_content.push_str("\" class=\"line\" />\n");
507
508 if let Some(ref title) = self.config.title {
510 svg_content.push_str(&format!(
511 r#" <text x="{}" y="30" text-anchor="middle" class="title">{}</text>
512"#,
513 width / 2,
514 title
515 ));
516 }
517
518 svg_content.push_str(&format!(
520 r#" <text x="{}" y="{}" text-anchor="middle" class="text">Iteration</text>
521 <text x="20" y="{}" text-anchor="middle" class="text" transform="rotate(-90 20 {})">Function Value{}</text>
522"#,
523 width / 2, height - 10,
524 height / 2, height / 2,
525 if self.config.log_scale_y { " (log)" } else { "" }
526 ));
527
528 svg_content.push_str("</svg>");
529
530 file.write_all(svg_content.as_bytes())?;
531 Ok(())
532 }
533
534 fn plot_convergence_html(
535 &self,
536 trajectory: &OptimizationTrajectory,
537 output_path: &Path,
538 ) -> ScirsResult<()> {
539 let mut file = File::create(output_path)?;
540
541 let html_content = format!(
542 r#"<!DOCTYPE html>
543<html>
544<head>
545 <title>Optimization Convergence</title>
546 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
547</head>
548<body>
549 <div id="convergence-plot" style="width:{}px;height:{}px;"></div>
550 <script>
551 var trace = {{
552 x: [{}],
553 y: [{}],
554 type: 'scatter',
555 mode: 'lines',
556 name: 'Function Value',
557 line: {{ color: '#2E86AB', width: 2 }}
558 }};
559
560 var layout = {{
561 title: '{}',
562 xaxis: {{ title: 'Iteration' }},
563 yaxis: {{
564 title: 'Function Value',
565 type: '{}'
566 }},
567 showlegend: {}
568 }};
569
570 Plotly.newPlot('convergence-plot', [trace], layout);
571 </script>
572</body>
573</html>"#,
574 self.config.width,
575 self.config.height,
576 trajectory
577 .nit
578 .iter()
579 .map(|i| i.to_string())
580 .collect::<Vec<_>>()
581 .join(","),
582 trajectory
583 .function_values
584 .iter()
585 .map(|f| f.to_string())
586 .collect::<Vec<_>>()
587 .join(","),
588 self.config
589 .title
590 .as_deref()
591 .unwrap_or("Optimization Convergence"),
592 if self.config.log_scale_y {
593 "log"
594 } else {
595 "linear"
596 },
597 self.config.show_legend
598 );
599
600 file.write_all(html_content.as_bytes())?;
601 Ok(())
602 }
603
604 fn plot_trajectory_svg(
605 &self,
606 trajectory: &OptimizationTrajectory,
607 output_path: &Path,
608 ) -> ScirsResult<()> {
609 let mut file = File::create(output_path)?;
610
611 let width = self.config.width;
612 let height = self.config.height;
613 let margin = 60;
614 let plot_width = width - 2 * margin;
615 let plot_height = height - 2 * margin;
616
617 let x_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
618 let y_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
619
620 let min_x = x_coords.iter().cloned().fold(f64::INFINITY, f64::min);
621 let max_x = x_coords.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
622 let min_y = y_coords.iter().cloned().fold(f64::INFINITY, f64::min);
623 let max_y = y_coords.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
624
625 let mut svg_content = format!(
626 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
627 <defs>
628 <style>
629 .axis {{ stroke: #333; stroke-width: 1; }}
630 .grid {{ stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }}
631 .trajectory {{ fill: none; stroke: #2E86AB; stroke-width: 2; }}
632 .start {{ fill: #4CAF50; stroke: #333; stroke-width: 1; }}
633 .end {{ fill: #F44336; stroke: #333; stroke-width: 1; }}
634 .text {{ font-family: Arial, sans-serif; font-size: 12px; fill: #333; }}
635 .title {{ font-family: Arial, sans-serif; font-size: 16px; fill: #333; font-weight: bold; }}
636 </style>
637 </defs>
638"#,
639 width, height
640 );
641
642 if self.config.show_grid {
644 for i in 0..=10 {
645 let x = margin as f64 + (i as f64 / 10.0) * plot_width as f64;
646 svg_content.push_str(&format!(
647 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
648"#,
649 x,
650 margin,
651 x,
652 height - margin
653 ));
654 }
655
656 for i in 0..=10 {
657 let y = margin as f64 + (i as f64 / 10.0) * plot_height as f64;
658 svg_content.push_str(&format!(
659 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
660"#,
661 margin,
662 y,
663 width - margin,
664 y
665 ));
666 }
667 }
668
669 svg_content.push_str(&format!(
671 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
672 <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
673"#,
674 margin,
675 height - margin,
676 width - margin,
677 height - margin,
678 margin,
679 margin,
680 margin,
681 height - margin
682 ));
683
684 svg_content.push_str(" <polyline points=\"");
686 for (x_val, y_val) in x_coords.iter().zip(y_coords.iter()) {
687 let x = margin as f64 + ((x_val - min_x) / (max_x - min_x)) * plot_width as f64;
688 let y = height as f64
689 - margin as f64
690 - ((y_val - min_y) / (max_y - min_y)) * plot_height as f64;
691 svg_content.push_str(&format!("{},{} ", x, y));
692 }
693 svg_content.push_str("\" class=\"trajectory\" />\n");
694
695 if !x_coords.is_empty() {
697 let start_x =
698 margin as f64 + ((x_coords[0] - min_x) / (max_x - min_x)) * plot_width as f64;
699 let start_y = height as f64
700 - margin as f64
701 - ((y_coords[0] - min_y) / (max_y - min_y)) * plot_height as f64;
702
703 let end_x = margin as f64
704 + ((x_coords.last().expect("Operation failed") - min_x) / (max_x - min_x))
705 * plot_width as f64;
706 let end_y = height as f64
707 - margin as f64
708 - ((y_coords.last().expect("Operation failed") - min_y) / (max_y - min_y))
709 * plot_height as f64;
710
711 svg_content.push_str(&format!(
712 r#" <circle cx="{}" cy="{}" r="5" class="start" />
713 <circle cx="{}" cy="{}" r="5" class="end" />
714"#,
715 start_x, start_y, end_x, end_y
716 ));
717 }
718
719 if let Some(ref title) = self.config.title {
721 svg_content.push_str(&format!(
722 r#" <text x="{}" y="30" text-anchor="middle" class="title">{}</text>
723"#,
724 width / 2,
725 title
726 ));
727 }
728
729 svg_content.push_str("</svg>");
730
731 file.write_all(svg_content.as_bytes())?;
732 Ok(())
733 }
734
735 fn plot_trajectory_html(
736 &self,
737 trajectory: &OptimizationTrajectory,
738 output_path: &Path,
739 ) -> ScirsResult<()> {
740 let mut file = File::create(output_path)?;
741
742 let x_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
743 let y_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
744
745 let html_content = format!(
746 r#"<!DOCTYPE html>
747<html>
748<head>
749 <title>Parameter Trajectory</title>
750 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
751</head>
752<body>
753 <div id="trajectory-plot" style="width:{}px;height:{}px;"></div>
754 <script>
755 var trace = {{
756 x: [{}],
757 y: [{}],
758 type: 'scatter',
759 mode: 'lines+markers',
760 name: 'Trajectory',
761 line: {{ color: '#2E86AB', width: 2 }},
762 marker: {{
763 size: [{}],
764 color: [{}],
765 colorscale: 'Viridis',
766 showscale: true
767 }}
768 }};
769
770 var layout = {{
771 title: '{}',
772 xaxis: {{ title: 'Parameter 1' }},
773 yaxis: {{ title: 'Parameter 2' }},
774 showlegend: {}
775 }};
776
777 Plotly.newPlot('trajectory-plot', [trace], layout);
778 </script>
779</body>
780</html>"#,
781 self.config.width,
782 self.config.height,
783 x_coords
784 .iter()
785 .map(|x| x.to_string())
786 .collect::<Vec<_>>()
787 .join(","),
788 y_coords
789 .iter()
790 .map(|y| y.to_string())
791 .collect::<Vec<_>>()
792 .join(","),
793 (0..x_coords.len())
794 .map(|i| if i == 0 {
795 "10"
796 } else if i == x_coords.len() - 1 {
797 "10"
798 } else {
799 "6"
800 })
801 .collect::<Vec<_>>()
802 .join(","),
803 (0..x_coords.len())
804 .map(|i| i.to_string())
805 .collect::<Vec<_>>()
806 .join(","),
807 self.config
808 .title
809 .as_deref()
810 .unwrap_or("Parameter Trajectory"),
811 self.config.show_legend
812 );
813
814 file.write_all(html_content.as_bytes())?;
815 Ok(())
816 }
817
818 fn export_convergence_data(
819 &self,
820 trajectory: &OptimizationTrajectory,
821 output_path: &Path,
822 ) -> ScirsResult<()> {
823 let mut file = File::create(output_path)?;
824
825 let mut header = "iteration,function_value,time".to_string();
827 if !trajectory.gradient_norms.is_empty() {
828 header.push_str(",gradient_norm");
829 }
830 if !trajectory.step_sizes.is_empty() {
831 header.push_str(",step_size");
832 }
833
834 if !trajectory.parameters.is_empty() {
836 for i in 0..trajectory.parameters[0].len() {
837 header.push_str(&format!(",param_{}", i));
838 }
839 }
840
841 for name in trajectory.custom_metrics.keys() {
843 header.push_str(&format!(",{}", name));
844 }
845 header.push('\n');
846
847 file.write_all(header.as_bytes())?;
848
849 for i in 0..trajectory.len() {
851 let mut row = format!(
852 "{},{},{}",
853 trajectory.nit[i], trajectory.function_values[i], trajectory.times[i]
854 );
855
856 if i < trajectory.gradient_norms.len() {
857 row.push_str(&format!(",{}", trajectory.gradient_norms[i]));
858 } else if !trajectory.gradient_norms.is_empty() {
859 row.push(',');
860 }
861
862 if i < trajectory.step_sizes.len() {
863 row.push_str(&format!(",{}", trajectory.step_sizes[i]));
864 } else if !trajectory.step_sizes.is_empty() {
865 row.push(',');
866 }
867
868 if i < trajectory.parameters.len() {
870 for param in trajectory.parameters[i].iter() {
871 row.push_str(&format!(",{}", param));
872 }
873 }
874
875 for name in trajectory.custom_metrics.keys() {
877 if let Some(values) = trajectory.custom_metrics.get(name) {
878 if i < values.len() {
879 row.push_str(&format!(",{}", values[i]));
880 } else {
881 row.push(',');
882 }
883 }
884 }
885
886 row.push('\n');
887 file.write_all(row.as_bytes())?;
888 }
889
890 Ok(())
891 }
892
893 fn export_trajectory_data(
894 &self,
895 trajectory: &OptimizationTrajectory,
896 output_path: &Path,
897 ) -> ScirsResult<()> {
898 self.export_convergence_data(trajectory, output_path)
899 }
900}
901
902impl Default for OptimizationVisualizer {
903 fn default() -> Self {
904 Self::new()
905 }
906}
907
908pub mod tracking {
910 use super::OptimizationTrajectory;
911 use scirs2_core::ndarray::ArrayView1;
912 use std::time::Instant;
913
914 pub struct TrajectoryTracker {
916 trajectory: OptimizationTrajectory,
917 start_time: Instant,
918 }
919
920 impl TrajectoryTracker {
921 pub fn new() -> Self {
923 Self {
924 trajectory: OptimizationTrajectory::new(),
925 start_time: Instant::now(),
926 }
927 }
928
929 pub fn record(&mut self, iteration: usize, params: &ArrayView1<f64>, function_value: f64) {
931 let elapsed = self.start_time.elapsed().as_secs_f64();
932 self.trajectory
933 .add_point(iteration, params, function_value, elapsed);
934 }
935
936 pub fn record_gradient_norm(&mut self, grad_norm: f64) {
938 self.trajectory.add_gradient_norm(grad_norm);
939 }
940
941 pub fn record_step_size(&mut self, step_size: f64) {
943 self.trajectory.add_step_size(step_size);
944 }
945
946 pub fn record_custom_metric(&mut self, name: &str, value: f64) {
948 self.trajectory.add_custom_metric(name, value);
949 }
950
951 pub fn trajectory(&self) -> &OptimizationTrajectory {
953 &self.trajectory
954 }
955
956 pub fn into_trajectory(self) -> OptimizationTrajectory {
958 self.trajectory
959 }
960 }
961
962 impl Default for TrajectoryTracker {
963 fn default() -> Self {
964 Self::new()
965 }
966 }
967}
968
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_trajectory_creation() {
        let mut trajectory = OptimizationTrajectory::new();
        assert!(trajectory.is_empty());

        let params = array![1.0, 2.0];
        trajectory.add_point(0, &params.view(), 5.0, 0.1);

        assert_eq!(trajectory.len(), 1);
        assert_eq!(trajectory.final_function_value(), Some(5.0));
    }

    #[test]
    fn test_convergence_rate_calculation() {
        let mut trajectory = OptimizationTrajectory::new();

        // Geometric decrease by a factor of 2 => ratio of successive drops is 0.5.
        let function_values = vec![10.0, 5.0, 2.5, 1.25, 0.625];
        for (i, &f_val) in function_values.iter().enumerate() {
            let params = array![i as f64, i as f64];
            trajectory.add_point(i, &params.view(), f_val, i as f64 * 0.1);
        }

        let rate = trajectory
            .convergence_rate()
            .expect("a geometric sequence should produce a convergence rate");
        assert!((rate - 0.5).abs() < 0.1);
    }

    #[test]
    fn test_convergence_rate_needs_three_points() {
        let mut trajectory = OptimizationTrajectory::new();
        let params = array![0.0, 0.0];
        trajectory.add_point(0, &params.view(), 2.0, 0.0);
        trajectory.add_point(1, &params.view(), 1.0, 0.1);
        assert_eq!(trajectory.convergence_rate(), None);
    }

    #[test]
    fn test_custom_metrics_accumulate() {
        let mut trajectory = OptimizationTrajectory::new();
        trajectory.add_custom_metric("penalty", 1.0);
        trajectory.add_custom_metric("penalty", 0.5);
        assert_eq!(trajectory.custom_metrics.get("penalty"), Some(&vec![1.0, 0.5]));
    }

    #[test]
    fn test_visualization_config() {
        let config = VisualizationConfig {
            format: OutputFormat::Svg,
            width: 1000,
            height: 800,
            title: Some("Test Plot".to_string()),
            show_grid: true,
            log_scale_y: true,
            color_scheme: ColorScheme::Viridis,
            show_legend: false,
            custom_style: None,
        };

        let visualizer = OptimizationVisualizer::with_config(config);
        assert_eq!(visualizer.config.width, 1000);
        assert_eq!(visualizer.config.height, 800);
    }

    #[test]
    fn test_plot_convergence_rejects_empty_trajectory() {
        let visualizer = OptimizationVisualizer::new();
        let path = std::env::temp_dir().join("scirs_viz_unused_convergence.svg");
        assert!(visualizer
            .plot_convergence(&OptimizationTrajectory::new(), &path)
            .is_err());
    }

    #[test]
    fn test_export_convergence_csv() {
        let mut trajectory = OptimizationTrajectory::new();
        let params = array![1.0, 2.0];
        trajectory.add_point(0, &params.view(), 10.0, 0.5);
        trajectory.add_gradient_norm(2.5);

        let visualizer = OptimizationVisualizer::with_config(VisualizationConfig {
            format: OutputFormat::Data,
            ..VisualizationConfig::default()
        });
        let path = std::env::temp_dir().join(format!(
            "scirs_viz_export_{}.csv",
            std::process::id()
        ));

        assert!(visualizer.plot_convergence(&trajectory, &path).is_ok());
        let contents = std::fs::read_to_string(&path).unwrap_or_default();
        let _ = std::fs::remove_file(&path);

        assert_eq!(
            contents,
            "iteration,function_value,time,gradient_norm,param_0,param_1\n0,10,0.5,2.5,1,2\n"
        );
    }

    #[test]
    fn test_trajectory_tracker() {
        let mut tracker = tracking::TrajectoryTracker::new();

        let params1 = array![0.0, 0.0];
        let params2 = array![1.0, 1.0];

        tracker.record(0, &params1.view(), 10.0);
        tracker.record_gradient_norm(2.5);
        tracker.record_step_size(0.1);

        tracker.record(1, &params2.view(), 5.0);
        tracker.record_gradient_norm(1.5);
        tracker.record_step_size(0.2);

        let trajectory = tracker.trajectory();
        assert_eq!(trajectory.len(), 2);
        assert_eq!(trajectory.gradient_norms.len(), 2);
        assert_eq!(trajectory.step_sizes.len(), 2);
        assert_eq!(trajectory.final_function_value(), Some(5.0));
    }
}