1use crate::error::{ScirsError, ScirsResult};
8use ndarray::{Array1, ArrayView1}; use scirs2_core::error_context;
10use std::collections::HashMap;
11use std::fs::File;
12use std::io::Write;
13use std::path::Path;
14
/// Recorded history of an optimization run.
///
/// `add_point` appends to `nit`, `parameters`, `function_values` and `times` together, so
/// those four vectors stay the same length when only that method fills them.
/// `gradient_norms`, `step_sizes` and `custom_metrics` are filled by separate calls and are
/// not length-checked against the per-point vectors.
#[derive(Debug, Clone)]
pub struct OptimizationTrajectory {
    /// Parameter vector at each recorded point (an owned copy of the view given to `add_point`).
    pub parameters: Vec<Array1<f64>>,
    /// Objective function value at each recorded point.
    pub function_values: Vec<f64>,
    /// Gradient norms appended via `add_gradient_norm`.
    pub gradient_norms: Vec<f64>,
    /// Step sizes appended via `add_step_size`.
    pub step_sizes: Vec<f64>,
    /// Named scalar series appended via `add_custom_metric`, keyed by metric name.
    pub custom_metrics: HashMap<String, Vec<f64>>,
    /// Iteration number at each recorded point; its length is what `len()` reports.
    pub nit: Vec<usize>,
    /// Timestamp at each recorded point. `tracking::TrajectoryTracker` stores seconds elapsed
    /// since the tracker was created; other callers may choose other units.
    pub times: Vec<f64>,
}
33
34impl OptimizationTrajectory {
35 pub fn new() -> Self {
37 Self {
38 parameters: Vec::new(),
39 function_values: Vec::new(),
40 gradient_norms: Vec::new(),
41 step_sizes: Vec::new(),
42 custom_metrics: HashMap::new(),
43 nit: Vec::new(),
44 times: Vec::new(),
45 }
46 }
47
48 pub fn add_point(
50 &mut self,
51 iteration: usize,
52 params: &ArrayView1<f64>,
53 function_value: f64,
54 time: f64,
55 ) {
56 self.nit.push(iteration);
57 self.parameters.push(params.to_owned());
58 self.function_values.push(function_value);
59 self.times.push(time);
60 }
61
62 pub fn add_gradient_norm(&mut self, grad_norm: f64) {
64 self.gradient_norms.push(grad_norm);
65 }
66
67 pub fn add_step_size(&mut self, step_size: f64) {
69 self.step_sizes.push(step_size);
70 }
71
72 pub fn add_custom_metric(&mut self, name: &str, value: f64) {
74 self.custom_metrics
75 .entry(name.to_string())
76 .or_insert_with(Vec::new)
77 .push(value);
78 }
79
80 pub fn len(&self) -> usize {
82 self.nit.len()
83 }
84
85 pub fn is_empty(&self) -> bool {
87 self.nit.is_empty()
88 }
89
90 pub fn final_parameters(&self) -> Option<&Array1<f64>> {
92 self.parameters.last()
93 }
94
95 pub fn final_function_value(&self) -> Option<f64> {
97 self.function_values.last().copied()
98 }
99
100 pub fn convergence_rate(&self) -> Option<f64> {
102 if self.function_values.len() < 3 {
103 return None;
104 }
105
106 let n = self.function_values.len();
107 let mut rates = Vec::new();
108
109 for i in 1..(n - 1) {
110 let f_current = self.function_values[i];
111 let f_next = self.function_values[i + 1];
112 let f_prev = self.function_values[i - 1];
113
114 if (f_current - f_next).abs() > 1e-14 && (f_prev - f_current).abs() > 1e-14 {
115 let rate = (f_current - f_next).abs() / (f_prev - f_current).abs();
116 if rate.is_finite() && rate > 0.0 {
117 rates.push(rate);
118 }
119 }
120 }
121
122 if rates.is_empty() {
123 None
124 } else {
125 Some(rates.iter().sum::<f64>() / rates.len() as f64)
126 }
127 }
128}
129
130impl Default for OptimizationTrajectory {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
/// Options controlling how an `OptimizationVisualizer` renders its output.
#[derive(Debug, Clone)]
pub struct VisualizationConfig {
    /// Format chosen by `plot_convergence` / `plot_parameter_trajectory`.
    pub format: OutputFormat,
    /// Plot width in pixels (SVG `width` attribute; HTML plot container width).
    pub width: u32,
    /// Plot height in pixels (SVG `height` attribute; HTML plot container height).
    pub height: u32,
    /// Optional plot title. When `None`, SVG output has no title and HTML output uses a
    /// built-in default title.
    pub title: Option<String>,
    /// Whether SVG output draws background grid lines.
    pub show_grid: bool,
    /// Whether the function-value axis is logarithmic (natural-log transform in SVG, Plotly
    /// `log` axis type in HTML).
    pub log_scale_y: bool,
    /// Colour scheme selection.
    /// NOTE(review): not read by any rendering code in this file; colours are hard-coded.
    pub color_scheme: ColorScheme,
    /// Forwarded to Plotly's `showlegend` layout option in HTML output.
    pub show_legend: bool,
    /// Optional extra styling.
    /// NOTE(review): not read by any rendering code in this file.
    pub custom_style: Option<String>,
}
158
159impl Default for VisualizationConfig {
160 fn default() -> Self {
161 Self {
162 format: OutputFormat::Svg,
163 width: 800,
164 height: 600,
165 title: None,
166 show_grid: true,
167 log_scale_y: false,
168 color_scheme: ColorScheme::Default,
169 show_legend: true,
170 custom_style: None,
171 }
172 }
173}
174
/// File format produced by the `OptimizationVisualizer` plotting methods.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OutputFormat {
    /// Standalone SVG document.
    Svg,
    /// PNG image; the plotting methods currently return `NotImplementedError` for it.
    Png,
    /// HTML page that draws the plot with Plotly, loaded from the Plotly CDN.
    Html,
    /// Raw per-iteration data written as comma-separated values.
    Data,
}
183
/// Named colour schemes selectable through `VisualizationConfig::color_scheme`.
///
/// NOTE(review): no rendering code in this file consults this value yet; plot colours are
/// hard-coded in the generated SVG/HTML.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ColorScheme {
    Default,
    Viridis,
    Plasma,
    Scientific,
    Monochrome,
}
193
/// Renders an `OptimizationTrajectory` to SVG, HTML or CSV files according to its
/// `VisualizationConfig`.
pub struct OptimizationVisualizer {
    /// Rendering options used by every plotting method.
    config: VisualizationConfig,
}
198
199impl OptimizationVisualizer {
200 pub fn new() -> Self {
202 Self {
203 config: VisualizationConfig::default(),
204 }
205 }
206
207 pub fn with_config(config: VisualizationConfig) -> Self {
209 Self { config }
210 }
211
212 pub fn plot_convergence(
214 &self,
215 trajectory: &OptimizationTrajectory,
216 output_path: &Path,
217 ) -> ScirsResult<()> {
218 if trajectory.is_empty() {
219 return Err(ScirsError::InvalidInput(error_context!("Empty trajectory")));
220 }
221
222 match self.config.format {
223 OutputFormat::Svg => self.plot_convergence_svg(trajectory, output_path),
224 OutputFormat::Html => self.plot_convergence_html(trajectory, output_path),
225 OutputFormat::Data => self.export_convergence_data(trajectory, output_path),
226 _ => Err(ScirsError::NotImplementedError(error_context!(
227 "PNG output not yet implemented"
228 ))),
229 }
230 }
231
232 pub fn plot_parameter_trajectory(
234 &self,
235 trajectory: &OptimizationTrajectory,
236 output_path: &Path,
237 ) -> ScirsResult<()> {
238 if trajectory.is_empty() {
239 return Err(ScirsError::InvalidInput(error_context!("Empty trajectory")));
240 }
241
242 if trajectory.parameters[0].len() != 2 {
243 return Err(ScirsError::InvalidInput(error_context!(
244 "Parameter trajectory visualization only supports 2D problems"
245 )));
246 }
247
248 match self.config.format {
249 OutputFormat::Svg => self.plot_trajectory_svg(trajectory, output_path),
250 OutputFormat::Html => self.plot_trajectory_html(trajectory, output_path),
251 OutputFormat::Data => self.export_trajectory_data(trajectory, output_path),
252 _ => Err(ScirsError::NotImplementedError(error_context!(
253 "PNG output not yet implemented"
254 ))),
255 }
256 }
257
258 pub fn create_optimization_report(
260 &self,
261 trajectory: &OptimizationTrajectory,
262 output_dir: &Path,
263 ) -> ScirsResult<()> {
264 std::fs::create_dir_all(output_dir)?;
265
266 let convergence_path = output_dir.join("convergence.svg");
268 self.plot_convergence(trajectory, &convergence_path)?;
269
270 if !trajectory.parameters.is_empty() && trajectory.parameters[0].len() == 2 {
272 let trajectory_path = output_dir.join("trajectory.svg");
273 self.plot_parameter_trajectory(trajectory, &trajectory_path)?;
274 }
275
276 let summary_path = output_dir.join("summary.html");
278 self.generate_summary_report(trajectory, &summary_path)?;
279
280 let data_path = output_dir.join("data.csv");
282 self.export_convergence_data(trajectory, &data_path)?;
283
284 Ok(())
285 }
286
287 fn generate_summary_report(
289 &self,
290 trajectory: &OptimizationTrajectory,
291 output_path: &Path,
292 ) -> ScirsResult<()> {
293 let mut file = File::create(output_path)?;
294
295 let html_content = format!(
296 r#"<!DOCTYPE html>
297<html>
298<head>
299 <title>Optimization Summary</title>
300 <style>
301 body {{ font-family: Arial, sans-serif; margin: 20px; }}
302 .metric {{ margin: 10px 0; }}
303 .value {{ font-weight: bold; color: #2E86AB; }}
304 table {{ border-collapse: collapse; width: 100%; }}
305 th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
306 th {{ background-color: #f2f2f2; }}
307 </style>
308</head>
309<body>
310 <h1>Optimization Summary Report</h1>
311
312 <h2>Basic Statistics</h2>
313 <div class="metric">Total Iterations: <span class="value">{}</span></div>
314 <div class="metric">Final Function Value: <span class="value">{:.6e}</span></div>
315 <div class="metric">Initial Function Value: <span class="value">{:.6e}</span></div>
316 <div class="metric">Function Improvement: <span class="value">{:.6e}</span></div>
317 <div class="metric">Total Runtime: <span class="value">{:.3}s</span></div>
318 {}
319
320 <h2>Convergence Analysis</h2>
321 <table>
322 <tr><th>Metric</th><th>Value</th></tr>
323 <tr><td>Convergence Rate</td><td>{}</td></tr>
324 <tr><td>Average Iteration Time</td><td>{:.6}s</td></tr>
325 <tr><td>Function Evaluations per Second</td><td>{:.2}</td></tr>
326 </table>
327
328 {}
329</body>
330</html>"#,
331 trajectory.len(),
332 trajectory.final_function_value().unwrap_or(0.0),
333 trajectory.function_values.first().cloned().unwrap_or(0.0),
334 trajectory.function_values.first().cloned().unwrap_or(0.0)
335 - trajectory.final_function_value().unwrap_or(0.0),
336 trajectory.times.last().cloned().unwrap_or(0.0),
337 if !trajectory.gradient_norms.is_empty() {
338 format!("<div class=\"metric\">Final Gradient Norm: <span class=\"value\">{:.6e}</span></div>",
339 trajectory.gradient_norms.last().cloned().unwrap_or(0.0))
340 } else {
341 String::new()
342 },
343 trajectory
344 .convergence_rate()
345 .map(|r| format!("{:.6}", r))
346 .unwrap_or_else(|| "N/A".to_string()),
347 if trajectory.len() > 1 && !trajectory.times.is_empty() {
348 trajectory.times.last().cloned().unwrap_or(0.0) / trajectory.len() as f64
349 } else {
350 0.0
351 },
352 if !trajectory.times.is_empty() && trajectory.times.last().cloned().unwrap_or(0.0) > 0.0
353 {
354 trajectory.len() as f64 / trajectory.times.last().cloned().unwrap_or(1.0)
355 } else {
356 0.0
357 },
358 self.generate_custom_metrics_table(trajectory)
359 );
360
361 file.write_all(html_content.as_bytes())?;
362 Ok(())
363 }
364
365 fn generate_custom_metrics_table(&self, trajectory: &OptimizationTrajectory) -> String {
366 if trajectory.custom_metrics.is_empty() {
367 return String::new();
368 }
369
370 let mut table = String::from("<h2>Custom Metrics</h2>\n<table>\n<tr><th>Metric</th><th>Final Value</th><th>Min</th><th>Max</th><th>Mean</th></tr>\n");
371
372 for (name, values) in &trajectory.custom_metrics {
373 if let Some(final_val) = values.last() {
374 let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
375 let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
376 let mean_val = values.iter().sum::<f64>() / values.len() as f64;
377
378 table.push_str(&format!(
379 "<tr><td>{}</td><td>{:.6e}</td><td>{:.6e}</td><td>{:.6e}</td><td>{:.6e}</td></tr>\n",
380 name, final_val, min_val, max_val, mean_val
381 ));
382 }
383 }
384 table.push_str("</table>\n");
385 table
386 }
387
388 fn plot_convergence_svg(
389 &self,
390 trajectory: &OptimizationTrajectory,
391 output_path: &Path,
392 ) -> ScirsResult<()> {
393 let mut file = File::create(output_path)?;
394
395 let width = self.config.width;
396 let height = self.config.height;
397 let margin = 60;
398 let plot_width = width - 2 * margin;
399 let plot_height = height - 2 * margin;
400
401 let min_y = if self.config.log_scale_y {
402 trajectory
403 .function_values
404 .iter()
405 .filter(|&&v| v > 0.0)
406 .cloned()
407 .fold(f64::INFINITY, f64::min)
408 .ln()
409 } else {
410 trajectory
411 .function_values
412 .iter()
413 .cloned()
414 .fold(f64::INFINITY, f64::min)
415 };
416
417 let max_y = if self.config.log_scale_y {
418 trajectory
419 .function_values
420 .iter()
421 .filter(|&&v| v > 0.0)
422 .cloned()
423 .fold(f64::NEG_INFINITY, f64::max)
424 .ln()
425 } else {
426 trajectory
427 .function_values
428 .iter()
429 .cloned()
430 .fold(f64::NEG_INFINITY, f64::max)
431 };
432
433 let max_x = trajectory.nit.len() as f64;
434
435 let mut svg_content = format!(
436 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
437 <defs>
438 <style>
439 .axis {{ stroke: #333; stroke-width: 1; }}
440 .grid {{ stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }}
441 .line {{ fill: none; stroke: #2E86AB; stroke-width: 2; }}
442 .text {{ font-family: Arial, sans-serif; font-size: 12px; fill: #333; }}
443 .title {{ font-family: Arial, sans-serif; font-size: 16px; fill: #333; font-weight: bold; }}
444 </style>
445 </defs>
446"#,
447 width, height
448 );
449
450 if self.config.show_grid {
452 for i in 0..=10 {
453 let x = margin as f64 + (i as f64 / 10.0) * plot_width as f64;
454 svg_content.push_str(&format!(
455 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
456"#,
457 x,
458 margin,
459 x,
460 height - margin
461 ));
462 }
463
464 for i in 0..=10 {
465 let y = margin as f64 + (i as f64 / 10.0) * plot_height as f64;
466 svg_content.push_str(&format!(
467 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
468"#,
469 margin,
470 y,
471 width - margin,
472 y
473 ));
474 }
475 }
476
477 svg_content.push_str(&format!(
479 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
480 <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
481"#,
482 margin,
483 height - margin,
484 width - margin,
485 height - margin, margin,
487 margin,
488 margin,
489 height - margin ));
491
492 svg_content.push_str(" <polyline points=\"");
494 for (i, &f_val) in trajectory.function_values.iter().enumerate() {
495 let x = margin as f64 + (i as f64 / max_x) * plot_width as f64;
496 let y_val = if self.config.log_scale_y && f_val > 0.0 {
497 f_val.ln()
498 } else {
499 f_val
500 };
501 let y = height as f64
502 - margin as f64
503 - ((y_val - min_y) / (max_y - min_y)) * plot_height as f64;
504 svg_content.push_str(&format!("{},{} ", x, y));
505 }
506 svg_content.push_str("\" class=\"line\" />\n");
507
508 if let Some(ref title) = self.config.title {
510 svg_content.push_str(&format!(
511 r#" <text x="{}" y="30" text-anchor="middle" class="title">{}</text>
512"#,
513 width / 2,
514 title
515 ));
516 }
517
518 svg_content.push_str(&format!(
520 r#" <text x="{}" y="{}" text-anchor="middle" class="text">Iteration</text>
521 <text x="20" y="{}" text-anchor="middle" class="text" transform="rotate(-90 20 {})">Function Value{}</text>
522"#,
523 width / 2, height - 10,
524 height / 2, height / 2,
525 if self.config.log_scale_y { " (log)" } else { "" }
526 ));
527
528 svg_content.push_str("</svg>");
529
530 file.write_all(svg_content.as_bytes())?;
531 Ok(())
532 }
533
534 fn plot_convergence_html(
535 &self,
536 trajectory: &OptimizationTrajectory,
537 output_path: &Path,
538 ) -> ScirsResult<()> {
539 let mut file = File::create(output_path)?;
540
541 let html_content = format!(
542 r#"<!DOCTYPE html>
543<html>
544<head>
545 <title>Optimization Convergence</title>
546 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
547</head>
548<body>
549 <div id="convergence-plot" style="width:{}px;height:{}px;"></div>
550 <script>
551 var trace = {{
552 x: [{}],
553 y: [{}],
554 type: 'scatter',
555 mode: 'lines',
556 name: 'Function Value',
557 line: {{ color: '#2E86AB', width: 2 }}
558 }};
559
560 var layout = {{
561 title: '{}',
562 xaxis: {{ title: 'Iteration' }},
563 yaxis: {{
564 title: 'Function Value',
565 type: '{}'
566 }},
567 showlegend: {}
568 }};
569
570 Plotly.newPlot('convergence-plot', [trace], layout);
571 </script>
572</body>
573</html>"#,
574 self.config.width,
575 self.config.height,
576 trajectory
577 .nit
578 .iter()
579 .map(|i| i.to_string())
580 .collect::<Vec<_>>()
581 .join(","),
582 trajectory
583 .function_values
584 .iter()
585 .map(|f| f.to_string())
586 .collect::<Vec<_>>()
587 .join(","),
588 self.config
589 .title
590 .as_deref()
591 .unwrap_or("Optimization Convergence"),
592 if self.config.log_scale_y {
593 "log"
594 } else {
595 "linear"
596 },
597 self.config.show_legend
598 );
599
600 file.write_all(html_content.as_bytes())?;
601 Ok(())
602 }
603
604 fn plot_trajectory_svg(
605 &self,
606 trajectory: &OptimizationTrajectory,
607 output_path: &Path,
608 ) -> ScirsResult<()> {
609 let mut file = File::create(output_path)?;
610
611 let width = self.config.width;
612 let height = self.config.height;
613 let margin = 60;
614 let plot_width = width - 2 * margin;
615 let plot_height = height - 2 * margin;
616
617 let x_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
618 let y_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
619
620 let min_x = x_coords.iter().cloned().fold(f64::INFINITY, f64::min);
621 let max_x = x_coords.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
622 let min_y = y_coords.iter().cloned().fold(f64::INFINITY, f64::min);
623 let max_y = y_coords.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
624
625 let mut svg_content = format!(
626 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
627 <defs>
628 <style>
629 .axis {{ stroke: #333; stroke-width: 1; }}
630 .grid {{ stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }}
631 .trajectory {{ fill: none; stroke: #2E86AB; stroke-width: 2; }}
632 .start {{ fill: #4CAF50; stroke: #333; stroke-width: 1; }}
633 .end {{ fill: #F44336; stroke: #333; stroke-width: 1; }}
634 .text {{ font-family: Arial, sans-serif; font-size: 12px; fill: #333; }}
635 .title {{ font-family: Arial, sans-serif; font-size: 16px; fill: #333; font-weight: bold; }}
636 </style>
637 </defs>
638"#,
639 width, height
640 );
641
642 if self.config.show_grid {
644 for i in 0..=10 {
645 let x = margin as f64 + (i as f64 / 10.0) * plot_width as f64;
646 svg_content.push_str(&format!(
647 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
648"#,
649 x,
650 margin,
651 x,
652 height - margin
653 ));
654 }
655
656 for i in 0..=10 {
657 let y = margin as f64 + (i as f64 / 10.0) * plot_height as f64;
658 svg_content.push_str(&format!(
659 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
660"#,
661 margin,
662 y,
663 width - margin,
664 y
665 ));
666 }
667 }
668
669 svg_content.push_str(&format!(
671 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
672 <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
673"#,
674 margin,
675 height - margin,
676 width - margin,
677 height - margin,
678 margin,
679 margin,
680 margin,
681 height - margin
682 ));
683
684 svg_content.push_str(" <polyline points=\"");
686 for (x_val, y_val) in x_coords.iter().zip(y_coords.iter()) {
687 let x = margin as f64 + ((x_val - min_x) / (max_x - min_x)) * plot_width as f64;
688 let y = height as f64
689 - margin as f64
690 - ((y_val - min_y) / (max_y - min_y)) * plot_height as f64;
691 svg_content.push_str(&format!("{},{} ", x, y));
692 }
693 svg_content.push_str("\" class=\"trajectory\" />\n");
694
695 if !x_coords.is_empty() {
697 let start_x =
698 margin as f64 + ((x_coords[0] - min_x) / (max_x - min_x)) * plot_width as f64;
699 let start_y = height as f64
700 - margin as f64
701 - ((y_coords[0] - min_y) / (max_y - min_y)) * plot_height as f64;
702
703 let end_x = margin as f64
704 + ((x_coords.last().unwrap() - min_x) / (max_x - min_x)) * plot_width as f64;
705 let end_y = height as f64
706 - margin as f64
707 - ((y_coords.last().unwrap() - min_y) / (max_y - min_y)) * plot_height as f64;
708
709 svg_content.push_str(&format!(
710 r#" <circle cx="{}" cy="{}" r="5" class="start" />
711 <circle cx="{}" cy="{}" r="5" class="end" />
712"#,
713 start_x, start_y, end_x, end_y
714 ));
715 }
716
717 if let Some(ref title) = self.config.title {
719 svg_content.push_str(&format!(
720 r#" <text x="{}" y="30" text-anchor="middle" class="title">{}</text>
721"#,
722 width / 2,
723 title
724 ));
725 }
726
727 svg_content.push_str("</svg>");
728
729 file.write_all(svg_content.as_bytes())?;
730 Ok(())
731 }
732
733 fn plot_trajectory_html(
734 &self,
735 trajectory: &OptimizationTrajectory,
736 output_path: &Path,
737 ) -> ScirsResult<()> {
738 let mut file = File::create(output_path)?;
739
740 let x_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
741 let y_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
742
743 let html_content = format!(
744 r#"<!DOCTYPE html>
745<html>
746<head>
747 <title>Parameter Trajectory</title>
748 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
749</head>
750<body>
751 <div id="trajectory-plot" style="width:{}px;height:{}px;"></div>
752 <script>
753 var trace = {{
754 x: [{}],
755 y: [{}],
756 type: 'scatter',
757 mode: 'lines+markers',
758 name: 'Trajectory',
759 line: {{ color: '#2E86AB', width: 2 }},
760 marker: {{
761 size: [{}],
762 color: [{}],
763 colorscale: 'Viridis',
764 showscale: true
765 }}
766 }};
767
768 var layout = {{
769 title: '{}',
770 xaxis: {{ title: 'Parameter 1' }},
771 yaxis: {{ title: 'Parameter 2' }},
772 showlegend: {}
773 }};
774
775 Plotly.newPlot('trajectory-plot', [trace], layout);
776 </script>
777</body>
778</html>"#,
779 self.config.width,
780 self.config.height,
781 x_coords
782 .iter()
783 .map(|x| x.to_string())
784 .collect::<Vec<_>>()
785 .join(","),
786 y_coords
787 .iter()
788 .map(|y| y.to_string())
789 .collect::<Vec<_>>()
790 .join(","),
791 (0..x_coords.len())
792 .map(|i| if i == 0 {
793 "10"
794 } else if i == x_coords.len() - 1 {
795 "10"
796 } else {
797 "6"
798 })
799 .collect::<Vec<_>>()
800 .join(","),
801 (0..x_coords.len())
802 .map(|i| i.to_string())
803 .collect::<Vec<_>>()
804 .join(","),
805 self.config
806 .title
807 .as_deref()
808 .unwrap_or("Parameter Trajectory"),
809 self.config.show_legend
810 );
811
812 file.write_all(html_content.as_bytes())?;
813 Ok(())
814 }
815
816 fn export_convergence_data(
817 &self,
818 trajectory: &OptimizationTrajectory,
819 output_path: &Path,
820 ) -> ScirsResult<()> {
821 let mut file = File::create(output_path)?;
822
823 let mut header = "iteration,function_value,time".to_string();
825 if !trajectory.gradient_norms.is_empty() {
826 header.push_str(",gradient_norm");
827 }
828 if !trajectory.step_sizes.is_empty() {
829 header.push_str(",step_size");
830 }
831
832 if !trajectory.parameters.is_empty() {
834 for i in 0..trajectory.parameters[0].len() {
835 header.push_str(&format!(",param_{}", i));
836 }
837 }
838
839 for name in trajectory.custom_metrics.keys() {
841 header.push_str(&format!(",{}", name));
842 }
843 header.push('\n');
844
845 file.write_all(header.as_bytes())?;
846
847 for i in 0..trajectory.len() {
849 let mut row = format!(
850 "{},{},{}",
851 trajectory.nit[i], trajectory.function_values[i], trajectory.times[i]
852 );
853
854 if i < trajectory.gradient_norms.len() {
855 row.push_str(&format!(",{}", trajectory.gradient_norms[i]));
856 } else if !trajectory.gradient_norms.is_empty() {
857 row.push_str(",");
858 }
859
860 if i < trajectory.step_sizes.len() {
861 row.push_str(&format!(",{}", trajectory.step_sizes[i]));
862 } else if !trajectory.step_sizes.is_empty() {
863 row.push_str(",");
864 }
865
866 if i < trajectory.parameters.len() {
868 for param in trajectory.parameters[i].iter() {
869 row.push_str(&format!(",{}", param));
870 }
871 }
872
873 for name in trajectory.custom_metrics.keys() {
875 if let Some(values) = trajectory.custom_metrics.get(name) {
876 if i < values.len() {
877 row.push_str(&format!(",{}", values[i]));
878 } else {
879 row.push_str(",");
880 }
881 }
882 }
883
884 row.push('\n');
885 file.write_all(row.as_bytes())?;
886 }
887
888 Ok(())
889 }
890
891 fn export_trajectory_data(
892 &self,
893 trajectory: &OptimizationTrajectory,
894 output_path: &Path,
895 ) -> ScirsResult<()> {
896 self.export_convergence_data(trajectory, output_path)
897 }
898}
899
900impl Default for OptimizationVisualizer {
901 fn default() -> Self {
902 Self::new()
903 }
904}
905
906pub mod tracking {
908 use super::OptimizationTrajectory;
909 use ndarray::ArrayView1;
910 use std::time::Instant;
911
912 pub struct TrajectoryTracker {
914 trajectory: OptimizationTrajectory,
915 start_time: Instant,
916 }
917
918 impl TrajectoryTracker {
919 pub fn new() -> Self {
921 Self {
922 trajectory: OptimizationTrajectory::new(),
923 start_time: Instant::now(),
924 }
925 }
926
927 pub fn record(&mut self, iteration: usize, params: &ArrayView1<f64>, function_value: f64) {
929 let elapsed = self.start_time.elapsed().as_secs_f64();
930 self.trajectory
931 .add_point(iteration, params, function_value, elapsed);
932 }
933
934 pub fn record_gradient_norm(&mut self, grad_norm: f64) {
936 self.trajectory.add_gradient_norm(grad_norm);
937 }
938
939 pub fn record_step_size(&mut self, step_size: f64) {
941 self.trajectory.add_step_size(step_size);
942 }
943
944 pub fn record_custom_metric(&mut self, name: &str, value: f64) {
946 self.trajectory.add_custom_metric(name, value);
947 }
948
949 pub fn trajectory(&self) -> &OptimizationTrajectory {
951 &self.trajectory
952 }
953
954 pub fn into_trajectory(self) -> OptimizationTrajectory {
956 self.trajectory
957 }
958 }
959
960 impl Default for TrajectoryTracker {
961 fn default() -> Self {
962 Self::new()
963 }
964 }
965}
966
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_trajectory_creation() {
        let mut trajectory = OptimizationTrajectory::new();
        assert!(trajectory.is_empty());

        let params = array![1.0, 2.0];
        trajectory.add_point(0, &params.view(), 5.0, 0.1);

        assert_eq!(trajectory.len(), 1);
        assert_eq!(trajectory.final_function_value(), Some(5.0));
    }

    #[test]
    fn test_convergence_rate_calculation() {
        let mut trajectory = OptimizationTrajectory::new();

        let function_values = vec![10.0, 5.0, 2.5, 1.25, 0.625];
        for (i, &f_val) in function_values.iter().enumerate() {
            let params = array![i as f64, i as f64];
            trajectory.add_point(i, &params.view(), f_val, i as f64 * 0.1);
        }

        // Each decrease is exactly half the previous one, so every ratio is exactly 0.5.
        assert_eq!(trajectory.convergence_rate(), Some(0.5));
    }

    #[test]
    fn test_convergence_rate_needs_three_points() {
        let mut trajectory = OptimizationTrajectory::new();
        let params = array![0.0, 0.0];
        trajectory.add_point(0, &params.view(), 2.0, 0.0);
        trajectory.add_point(1, &params.view(), 1.0, 0.1);
        assert_eq!(trajectory.convergence_rate(), None);
    }

    #[test]
    fn test_custom_metrics_accumulate_per_name() {
        let mut trajectory = OptimizationTrajectory::new();
        trajectory.add_custom_metric("alpha", 1.0);
        trajectory.add_custom_metric("alpha", 2.0);
        trajectory.add_custom_metric("beta", 3.0);
        assert_eq!(trajectory.custom_metrics["alpha"], vec![1.0, 2.0]);
        assert_eq!(trajectory.custom_metrics["beta"], vec![3.0]);
    }

    #[test]
    fn test_visualization_config() {
        let config = VisualizationConfig {
            format: OutputFormat::Svg,
            width: 1000,
            height: 800,
            title: Some("Test Plot".to_string()),
            show_grid: true,
            log_scale_y: true,
            color_scheme: ColorScheme::Viridis,
            show_legend: false,
            custom_style: None,
        };

        let visualizer = OptimizationVisualizer::with_config(config);
        assert_eq!(visualizer.config.width, 1000);
        assert_eq!(visualizer.config.height, 800);
    }

    #[test]
    fn test_plot_convergence_rejects_empty_trajectory() {
        let visualizer = OptimizationVisualizer::new();
        let result = visualizer.plot_convergence(
            &OptimizationTrajectory::new(),
            std::path::Path::new("unused-convergence.svg"),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_png_output_is_not_implemented() {
        let mut trajectory = OptimizationTrajectory::new();
        let params = array![0.0, 0.0];
        trajectory.add_point(0, &params.view(), 1.0, 0.0);

        let visualizer = OptimizationVisualizer::with_config(VisualizationConfig {
            format: OutputFormat::Png,
            ..VisualizationConfig::default()
        });
        let result =
            visualizer.plot_convergence(&trajectory, std::path::Path::new("unused-convergence.png"));
        assert!(result.is_err());
    }

    #[test]
    fn test_trajectory_tracker() {
        let mut tracker = tracking::TrajectoryTracker::new();

        let params1 = array![0.0, 0.0];
        let params2 = array![1.0, 1.0];

        tracker.record(0, &params1.view(), 10.0);
        tracker.record_gradient_norm(2.5);
        tracker.record_step_size(0.1);

        tracker.record(1, &params2.view(), 5.0);
        tracker.record_gradient_norm(1.5);
        tracker.record_step_size(0.2);

        let trajectory = tracker.trajectory();
        assert_eq!(trajectory.len(), 2);
        assert_eq!(trajectory.gradient_norms.len(), 2);
        assert_eq!(trajectory.step_sizes.len(), 2);
        assert_eq!(trajectory.final_function_value(), Some(5.0));
    }
}