1use crate::error::{ScirsError, ScirsResult};
8use scirs2_core::error_context;
9use scirs2_core::ndarray::{Array1, ArrayView1}; use std::collections::HashMap;
11use std::fs::File;
12use std::io::Write;
13use std::path::Path;
14
/// Recorded history of an optimization run.
///
/// [`OptimizationTrajectory::add_point`] pushes one entry each onto `nit`,
/// `parameters`, `function_values` and `times`; the remaining series
/// (`gradient_norms`, `step_sizes`, `custom_metrics`) are appended by their
/// own methods and may therefore have a different length.
#[derive(Debug, Clone)]
pub struct OptimizationTrajectory {
    /// Parameter vector recorded with each point.
    pub parameters: Vec<Array1<f64>>,
    /// Objective value recorded with each point.
    pub function_values: Vec<f64>,
    /// Gradient norms, appended via `add_gradient_norm`.
    pub gradient_norms: Vec<f64>,
    /// Step sizes, appended via `add_step_size`.
    pub step_sizes: Vec<f64>,
    /// Named user-defined series, appended via `add_custom_metric`.
    pub custom_metrics: HashMap<String, Vec<f64>>,
    /// Iteration number of each recorded point; its length defines `len()`.
    pub nit: Vec<usize>,
    /// Timestamp of each point (the `tracking` module records elapsed seconds).
    pub times: Vec<f64>,
}
33
34impl OptimizationTrajectory {
35 pub fn new() -> Self {
37 Self {
38 parameters: Vec::new(),
39 function_values: Vec::new(),
40 gradient_norms: Vec::new(),
41 step_sizes: Vec::new(),
42 custom_metrics: HashMap::new(),
43 nit: Vec::new(),
44 times: Vec::new(),
45 }
46 }
47
48 pub fn add_point(
50 &mut self,
51 iteration: usize,
52 params: &ArrayView1<f64>,
53 function_value: f64,
54 time: f64,
55 ) {
56 self.nit.push(iteration);
57 self.parameters.push(params.to_owned());
58 self.function_values.push(function_value);
59 self.times.push(time);
60 }
61
62 pub fn add_gradient_norm(&mut self, grad_norm: f64) {
64 self.gradient_norms.push(grad_norm);
65 }
66
67 pub fn add_step_size(&mut self, step_size: f64) {
69 self.step_sizes.push(step_size);
70 }
71
72 pub fn add_custom_metric(&mut self, name: &str, value: f64) {
74 self.custom_metrics
75 .entry(name.to_string())
76 .or_default()
77 .push(value);
78 }
79
80 pub fn len(&self) -> usize {
82 self.nit.len()
83 }
84
85 pub fn is_empty(&self) -> bool {
87 self.nit.is_empty()
88 }
89
90 pub fn final_parameters(&self) -> Option<&Array1<f64>> {
92 self.parameters.last()
93 }
94
95 pub fn final_function_value(&self) -> Option<f64> {
97 self.function_values.last().copied()
98 }
99
100 pub fn convergence_rate(&self) -> Option<f64> {
102 if self.function_values.len() < 3 {
103 return None;
104 }
105
106 let n = self.function_values.len();
107 let mut rates = Vec::new();
108
109 for i in 1..(n - 1) {
110 let f_current = self.function_values[i];
111 let f_next = self.function_values[i + 1];
112 let f_prev = self.function_values[i - 1];
113
114 if (f_current - f_next).abs() > 1e-14 && (f_prev - f_current).abs() > 1e-14 {
115 let rate = (f_current - f_next).abs() / (f_prev - f_current).abs();
116 if rate.is_finite() && rate > 0.0 {
117 rates.push(rate);
118 }
119 }
120 }
121
122 if rates.is_empty() {
123 None
124 } else {
125 Some(rates.iter().sum::<f64>() / rates.len() as f64)
126 }
127 }
128}
129
130impl Default for OptimizationTrajectory {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
/// Rendering options consulted by [`OptimizationVisualizer`].
#[derive(Debug, Clone)]
pub struct VisualizationConfig {
    /// Output format; the plotting entry points dispatch on it.
    pub format: OutputFormat,
    /// Output width (PNG pixels, SVG user units, HTML `px`).
    pub width: u32,
    /// Output height (PNG pixels, SVG user units, HTML `px`).
    pub height: u32,
    /// Optional plot title.
    pub title: Option<String>,
    /// Whether the SVG renderers draw background grid lines.
    pub show_grid: bool,
    /// Whether convergence plots use a logarithmic y axis.
    pub log_scale_y: bool,
    /// Colour scheme. NOTE(review): no renderer in this file reads it yet.
    pub color_scheme: ColorScheme,
    /// Forwarded to Plotly's `showlegend` in the HTML outputs.
    pub show_legend: bool,
    /// Custom style string. NOTE(review): no renderer in this file reads it yet.
    pub custom_style: Option<String>,
}
158
159impl Default for VisualizationConfig {
160 fn default() -> Self {
161 Self {
162 format: OutputFormat::Svg,
163 width: 800,
164 height: 600,
165 title: None,
166 show_grid: true,
167 log_scale_y: false,
168 color_scheme: ColorScheme::Default,
169 show_legend: true,
170 custom_style: None,
171 }
172 }
173}
174
/// File format produced by the [`OptimizationVisualizer`] plotting methods.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so formats can be used
/// as map keys or collected into sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputFormat {
    /// Scalable Vector Graphics document.
    Svg,
    /// 8-bit-per-channel RGB PNG image.
    Png,
    /// Standalone HTML page that renders the chart with Plotly.
    Html,
    /// Comma-separated values with the raw trajectory data.
    Data,
}
183
/// Named colour scheme for plots.
///
/// NOTE(review): no renderer in this file reads the scheme yet. It derives
/// `Eq` and `Hash` so schemes can be used as map keys or set members.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ColorScheme {
    /// The default palette.
    Default,
    /// Viridis palette.
    Viridis,
    /// Plasma palette.
    Plasma,
    /// Scientific palette.
    Scientific,
    /// Single-colour (greyscale) palette.
    Monochrome,
}
193
/// CRC-32 (reflected polynomial `0xEDB88320`, initial value and final XOR
/// `0xFFFFFFFF`), the checksum PNG chunks require.
///
/// Each table entry is derived on the fly with eight shift/XOR rounds rather
/// than read from a precomputed lookup table.
fn crc32(data: &[u8]) -> u32 {
    const POLY: u32 = 0xEDB8_8320;

    let register = data.iter().fold(u32::MAX, |acc, &byte| {
        let mut term = (acc ^ u32::from(byte)) & 0xFF;
        for _ in 0..8 {
            term = if term & 1 == 0 {
                term >> 1
            } else {
                (term >> 1) ^ POLY
            };
        }
        (acc >> 8) ^ term
    });

    !register
}
216
/// Adler-32 checksum (RFC 1950), used as the zlib stream trailer.
///
/// Both running sums are reduced modulo 65521 after every byte, so neither
/// can overflow `u32`.
fn adler32(data: &[u8]) -> u32 {
    const MOD_ADLER: u32 = 65521;

    let (low, high) = data.iter().fold((1u32, 0u32), |(low, high), &byte| {
        let low = (low + u32::from(byte)) % MOD_ADLER;
        (low, (high + low) % MOD_ADLER)
    });

    (high << 16) | low
}
226
/// Encodes `v` as four big-endian bytes, PNG's integer byte order.
#[inline]
fn be32(v: u32) -> [u8; 4] {
    u32::to_be_bytes(v)
}
232
233fn png_chunk(chunk_type: &[u8; 4], data: &[u8]) -> Vec<u8> {
235 let mut chunk = Vec::with_capacity(12 + data.len());
236 chunk.extend_from_slice(&be32(data.len() as u32));
237 chunk.extend_from_slice(chunk_type);
238 chunk.extend_from_slice(data);
239 let crc = crc32(&chunk[4..]);
240 chunk.extend_from_slice(&be32(crc));
241 chunk
242}
243
244fn deflate_stored(scanlines: &[u8]) -> Vec<u8> {
246 const MAX_BLOCK: usize = 65535;
247 let mut out = Vec::new();
248 out.extend_from_slice(&[0x78, 0x01]);
250
251 let adler = adler32(scanlines);
252 let total = scanlines.len();
253 let mut offset = 0;
254 while offset < total {
255 let end = (offset + MAX_BLOCK).min(total);
256 let block = &scanlines[offset..end];
257 let last = if end == total { 1u8 } else { 0u8 };
258 let len = block.len() as u16;
259 let nlen = !len;
260 out.push(last);
261 out.extend_from_slice(&len.to_le_bytes());
262 out.extend_from_slice(&nlen.to_le_bytes());
263 out.extend_from_slice(block);
264 offset = end;
265 }
266 out.extend_from_slice(&adler.to_be_bytes());
267 out
268}
269
270fn write_png(path: &Path, pixels: &[u8], width: usize, height: usize) -> ScirsResult<()> {
272 use crate::error::ScirsError;
273
274 let mut scanlines = Vec::with_capacity(height * (1 + width * 3));
276 for row in 0..height {
277 scanlines.push(0x00); let start = row * width * 3;
279 scanlines.extend_from_slice(&pixels[start..start + width * 3]);
280 }
281
282 let idat_data = deflate_stored(&scanlines);
283
284 let mut ihdr_data = Vec::with_capacity(13);
286 ihdr_data.extend_from_slice(&be32(width as u32));
287 ihdr_data.extend_from_slice(&be32(height as u32));
288 ihdr_data.extend_from_slice(&[8, 2, 0, 0, 0]);
289
290 let mut png_bytes = Vec::new();
291 png_bytes.extend_from_slice(&[137, 80, 78, 71, 13, 10, 26, 10]);
293 png_bytes.extend(png_chunk(b"IHDR", &ihdr_data));
294 png_bytes.extend(png_chunk(b"IDAT", &idat_data));
295 png_bytes.extend(png_chunk(b"IEND", &[]));
296
297 let mut file = File::create(path)
298 .map_err(|e| ScirsError::IoError(error_context!(format!("PNG create: {e}"))))?;
299 file.write_all(&png_bytes)
300 .map_err(|e| ScirsError::IoError(error_context!(format!("PNG write: {e}"))))?;
301 Ok(())
302}
303
/// Rasterises a straight line from `(x0, y0)` to `(x1, y1)` into an RGB
/// buffer using Bresenham's integer line algorithm.
///
/// `pixels` is row-major with three bytes per pixel and `width` pixels per
/// row. Both endpoints are drawn. Points that fall outside the
/// `width × height` canvas are skipped, so lines are clipped at the edges.
fn png_draw_line(
    pixels: &mut [u8],
    width: usize,
    height: usize,
    x0: usize,
    y0: usize,
    x1: usize,
    y1: usize,
    r: u8,
    g: u8,
    b: u8,
) {
    let (target_x, target_y) = (x1 as isize, y1 as isize);
    let (mut cur_x, mut cur_y) = (x0 as isize, y0 as isize);
    let (cols, rows) = (width as isize, height as isize);

    let delta_x = (target_x - cur_x).abs();
    let delta_y = -(target_y - cur_y).abs();
    let step_x: isize = if cur_x < target_x { 1 } else { -1 };
    let step_y: isize = if cur_y < target_y { 1 } else { -1 };
    let mut error = delta_x + delta_y;

    loop {
        if (0..cols).contains(&cur_x) && (0..rows).contains(&cur_y) {
            let base = (cur_y as usize * width + cur_x as usize) * 3;
            pixels[base..base + 3].copy_from_slice(&[r, g, b]);
        }
        if (cur_x, cur_y) == (target_x, target_y) {
            break;
        }

        let twice_error = 2 * error;
        if twice_error >= delta_y {
            error += delta_y;
            cur_x += step_x;
        }
        if twice_error <= delta_x {
            error += delta_x;
            cur_y += step_y;
        }
    }
}
345
/// Renders an [`OptimizationTrajectory`] to files in the format selected by
/// its [`VisualizationConfig`].
pub struct OptimizationVisualizer {
    /// Rendering options read by every plotting method.
    config: VisualizationConfig,
}
350
351impl OptimizationVisualizer {
352 pub fn new() -> Self {
354 Self {
355 config: VisualizationConfig::default(),
356 }
357 }
358
359 pub fn with_config(config: VisualizationConfig) -> Self {
361 Self { config }
362 }
363
364 pub fn plot_convergence(
366 &self,
367 trajectory: &OptimizationTrajectory,
368 output_path: &Path,
369 ) -> ScirsResult<()> {
370 if trajectory.is_empty() {
371 return Err(ScirsError::InvalidInput(error_context!("Empty trajectory")));
372 }
373
374 match self.config.format {
375 OutputFormat::Svg => self.plot_convergence_svg(trajectory, output_path),
376 OutputFormat::Html => self.plot_convergence_html(trajectory, output_path),
377 OutputFormat::Data => self.export_convergence_data(trajectory, output_path),
378 OutputFormat::Png => self.plot_convergence_png(trajectory, output_path),
379 }
380 }
381
382 pub fn plot_parameter_trajectory(
384 &self,
385 trajectory: &OptimizationTrajectory,
386 output_path: &Path,
387 ) -> ScirsResult<()> {
388 if trajectory.is_empty() {
389 return Err(ScirsError::InvalidInput(error_context!("Empty trajectory")));
390 }
391
392 if trajectory.parameters[0].len() != 2 {
393 return Err(ScirsError::InvalidInput(error_context!(
394 "Parameter trajectory visualization only supports 2D problems"
395 )));
396 }
397
398 match self.config.format {
399 OutputFormat::Svg => self.plot_trajectory_svg(trajectory, output_path),
400 OutputFormat::Html => self.plot_trajectory_html(trajectory, output_path),
401 OutputFormat::Data => self.export_trajectory_data(trajectory, output_path),
402 OutputFormat::Png => self.plot_trajectory_png(trajectory, output_path),
403 }
404 }
405
406 pub fn create_optimization_report(
408 &self,
409 trajectory: &OptimizationTrajectory,
410 output_dir: &Path,
411 ) -> ScirsResult<()> {
412 std::fs::create_dir_all(output_dir)?;
413
414 let convergence_path = output_dir.join("convergence.svg");
416 self.plot_convergence(trajectory, &convergence_path)?;
417
418 if !trajectory.parameters.is_empty() && trajectory.parameters[0].len() == 2 {
420 let trajectory_path = output_dir.join("trajectory.svg");
421 self.plot_parameter_trajectory(trajectory, &trajectory_path)?;
422 }
423
424 let summary_path = output_dir.join("summary.html");
426 self.generate_summary_report(trajectory, &summary_path)?;
427
428 let data_path = output_dir.join("data.csv");
430 self.export_convergence_data(trajectory, &data_path)?;
431
432 Ok(())
433 }
434
435 fn generate_summary_report(
437 &self,
438 trajectory: &OptimizationTrajectory,
439 output_path: &Path,
440 ) -> ScirsResult<()> {
441 let mut file = File::create(output_path)?;
442
443 let html_content = format!(
444 r#"<!DOCTYPE html>
445<html>
446<head>
447 <title>Optimization Summary</title>
448 <style>
449 body {{ font-family: Arial, sans-serif; margin: 20px; }}
450 .metric {{ margin: 10px 0; }}
451 .value {{ font-weight: bold; color: #2E86AB; }}
452 table {{ border-collapse: collapse; width: 100%; }}
453 th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
454 th {{ background-color: #f2f2f2; }}
455 </style>
456</head>
457<body>
458 <h1>Optimization Summary Report</h1>
459
460 <h2>Basic Statistics</h2>
461 <div class="metric">Total Iterations: <span class="value">{}</span></div>
462 <div class="metric">Final Function Value: <span class="value">{:.6e}</span></div>
463 <div class="metric">Initial Function Value: <span class="value">{:.6e}</span></div>
464 <div class="metric">Function Improvement: <span class="value">{:.6e}</span></div>
465 <div class="metric">Total Runtime: <span class="value">{:.3}s</span></div>
466 {}
467
468 <h2>Convergence Analysis</h2>
469 <table>
470 <tr><th>Metric</th><th>Value</th></tr>
471 <tr><td>Convergence Rate</td><td>{}</td></tr>
472 <tr><td>Average Iteration Time</td><td>{:.6}s</td></tr>
473 <tr><td>Function Evaluations per Second</td><td>{:.2}</td></tr>
474 </table>
475
476 {}
477</body>
478</html>"#,
479 trajectory.len(),
480 trajectory.final_function_value().unwrap_or(0.0),
481 trajectory.function_values.first().cloned().unwrap_or(0.0),
482 trajectory.function_values.first().cloned().unwrap_or(0.0)
483 - trajectory.final_function_value().unwrap_or(0.0),
484 trajectory.times.last().cloned().unwrap_or(0.0),
485 if !trajectory.gradient_norms.is_empty() {
486 format!("<div class=\"metric\">Final Gradient Norm: <span class=\"value\">{:.6e}</span></div>",
487 trajectory.gradient_norms.last().cloned().unwrap_or(0.0))
488 } else {
489 String::new()
490 },
491 trajectory
492 .convergence_rate()
493 .map(|r| format!("{:.6}", r))
494 .unwrap_or_else(|| "N/A".to_string()),
495 if trajectory.len() > 1 && !trajectory.times.is_empty() {
496 trajectory.times.last().cloned().unwrap_or(0.0) / trajectory.len() as f64
497 } else {
498 0.0
499 },
500 if !trajectory.times.is_empty() && trajectory.times.last().cloned().unwrap_or(0.0) > 0.0
501 {
502 trajectory.len() as f64 / trajectory.times.last().cloned().unwrap_or(1.0)
503 } else {
504 0.0
505 },
506 self.generate_custom_metrics_table(trajectory)
507 );
508
509 file.write_all(html_content.as_bytes())?;
510 Ok(())
511 }
512
513 fn generate_custom_metrics_table(&self, trajectory: &OptimizationTrajectory) -> String {
514 if trajectory.custom_metrics.is_empty() {
515 return String::new();
516 }
517
518 let mut table = String::from("<h2>Custom Metrics</h2>\n<table>\n<tr><th>Metric</th><th>Final Value</th><th>Min</th><th>Max</th><th>Mean</th></tr>\n");
519
520 for (name, values) in &trajectory.custom_metrics {
521 if let Some(final_val) = values.last() {
522 let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
523 let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
524 let mean_val = values.iter().sum::<f64>() / values.len() as f64;
525
526 table.push_str(&format!(
527 "<tr><td>{}</td><td>{:.6e}</td><td>{:.6e}</td><td>{:.6e}</td><td>{:.6e}</td></tr>\n",
528 name, final_val, min_val, max_val, mean_val
529 ));
530 }
531 }
532 table.push_str("</table>\n");
533 table
534 }
535
536 fn plot_convergence_svg(
537 &self,
538 trajectory: &OptimizationTrajectory,
539 output_path: &Path,
540 ) -> ScirsResult<()> {
541 let mut file = File::create(output_path)?;
542
543 let width = self.config.width;
544 let height = self.config.height;
545 let margin = 60;
546 let plot_width = width - 2 * margin;
547 let plot_height = height - 2 * margin;
548
549 let min_y = if self.config.log_scale_y {
550 trajectory
551 .function_values
552 .iter()
553 .filter(|&&v| v > 0.0)
554 .cloned()
555 .fold(f64::INFINITY, f64::min)
556 .ln()
557 } else {
558 trajectory
559 .function_values
560 .iter()
561 .cloned()
562 .fold(f64::INFINITY, f64::min)
563 };
564
565 let max_y = if self.config.log_scale_y {
566 trajectory
567 .function_values
568 .iter()
569 .filter(|&&v| v > 0.0)
570 .cloned()
571 .fold(f64::NEG_INFINITY, f64::max)
572 .ln()
573 } else {
574 trajectory
575 .function_values
576 .iter()
577 .cloned()
578 .fold(f64::NEG_INFINITY, f64::max)
579 };
580
581 let max_x = trajectory.nit.len() as f64;
582
583 let mut svg_content = format!(
584 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
585 <defs>
586 <style>
587 .axis {{ stroke: #333; stroke-width: 1; }}
588 .grid {{ stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }}
589 .line {{ fill: none; stroke: #2E86AB; stroke-width: 2; }}
590 .text {{ font-family: Arial, sans-serif; font-size: 12px; fill: #333; }}
591 .title {{ font-family: Arial, sans-serif; font-size: 16px; fill: #333; font-weight: bold; }}
592 </style>
593 </defs>
594"#,
595 width, height
596 );
597
598 if self.config.show_grid {
600 for i in 0..=10 {
601 let x = margin as f64 + (i as f64 / 10.0) * plot_width as f64;
602 svg_content.push_str(&format!(
603 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
604"#,
605 x,
606 margin,
607 x,
608 height - margin
609 ));
610 }
611
612 for i in 0..=10 {
613 let y = margin as f64 + (i as f64 / 10.0) * plot_height as f64;
614 svg_content.push_str(&format!(
615 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
616"#,
617 margin,
618 y,
619 width - margin,
620 y
621 ));
622 }
623 }
624
625 svg_content.push_str(&format!(
627 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
628 <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
629"#,
630 margin,
631 height - margin,
632 width - margin,
633 height - margin, margin,
635 margin,
636 margin,
637 height - margin ));
639
640 svg_content.push_str(" <polyline points=\"");
642 for (i, &f_val) in trajectory.function_values.iter().enumerate() {
643 let x = margin as f64 + (i as f64 / max_x) * plot_width as f64;
644 let y_val = if self.config.log_scale_y && f_val > 0.0 {
645 f_val.ln()
646 } else {
647 f_val
648 };
649 let y = height as f64
650 - margin as f64
651 - ((y_val - min_y) / (max_y - min_y)) * plot_height as f64;
652 svg_content.push_str(&format!("{},{} ", x, y));
653 }
654 svg_content.push_str("\" class=\"line\" />\n");
655
656 if let Some(ref title) = self.config.title {
658 svg_content.push_str(&format!(
659 r#" <text x="{}" y="30" text-anchor="middle" class="title">{}</text>
660"#,
661 width / 2,
662 title
663 ));
664 }
665
666 svg_content.push_str(&format!(
668 r#" <text x="{}" y="{}" text-anchor="middle" class="text">Iteration</text>
669 <text x="20" y="{}" text-anchor="middle" class="text" transform="rotate(-90 20 {})">Function Value{}</text>
670"#,
671 width / 2, height - 10,
672 height / 2, height / 2,
673 if self.config.log_scale_y { " (log)" } else { "" }
674 ));
675
676 svg_content.push_str("</svg>");
677
678 file.write_all(svg_content.as_bytes())?;
679 Ok(())
680 }
681
682 fn plot_convergence_html(
683 &self,
684 trajectory: &OptimizationTrajectory,
685 output_path: &Path,
686 ) -> ScirsResult<()> {
687 let mut file = File::create(output_path)?;
688
689 let html_content = format!(
690 r#"<!DOCTYPE html>
691<html>
692<head>
693 <title>Optimization Convergence</title>
694 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
695</head>
696<body>
697 <div id="convergence-plot" style="width:{}px;height:{}px;"></div>
698 <script>
699 var trace = {{
700 x: [{}],
701 y: [{}],
702 type: 'scatter',
703 mode: 'lines',
704 name: 'Function Value',
705 line: {{ color: '#2E86AB', width: 2 }}
706 }};
707
708 var layout = {{
709 title: '{}',
710 xaxis: {{ title: 'Iteration' }},
711 yaxis: {{
712 title: 'Function Value',
713 type: '{}'
714 }},
715 showlegend: {}
716 }};
717
718 Plotly.newPlot('convergence-plot', [trace], layout);
719 </script>
720</body>
721</html>"#,
722 self.config.width,
723 self.config.height,
724 trajectory
725 .nit
726 .iter()
727 .map(|i| i.to_string())
728 .collect::<Vec<_>>()
729 .join(","),
730 trajectory
731 .function_values
732 .iter()
733 .map(|f| f.to_string())
734 .collect::<Vec<_>>()
735 .join(","),
736 self.config
737 .title
738 .as_deref()
739 .unwrap_or("Optimization Convergence"),
740 if self.config.log_scale_y {
741 "log"
742 } else {
743 "linear"
744 },
745 self.config.show_legend
746 );
747
748 file.write_all(html_content.as_bytes())?;
749 Ok(())
750 }
751
752 fn plot_trajectory_svg(
753 &self,
754 trajectory: &OptimizationTrajectory,
755 output_path: &Path,
756 ) -> ScirsResult<()> {
757 let mut file = File::create(output_path)?;
758
759 let width = self.config.width;
760 let height = self.config.height;
761 let margin = 60;
762 let plot_width = width - 2 * margin;
763 let plot_height = height - 2 * margin;
764
765 let x_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
766 let y_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
767
768 let min_x = x_coords.iter().cloned().fold(f64::INFINITY, f64::min);
769 let max_x = x_coords.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
770 let min_y = y_coords.iter().cloned().fold(f64::INFINITY, f64::min);
771 let max_y = y_coords.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
772
773 let mut svg_content = format!(
774 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
775 <defs>
776 <style>
777 .axis {{ stroke: #333; stroke-width: 1; }}
778 .grid {{ stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }}
779 .trajectory {{ fill: none; stroke: #2E86AB; stroke-width: 2; }}
780 .start {{ fill: #4CAF50; stroke: #333; stroke-width: 1; }}
781 .end {{ fill: #F44336; stroke: #333; stroke-width: 1; }}
782 .text {{ font-family: Arial, sans-serif; font-size: 12px; fill: #333; }}
783 .title {{ font-family: Arial, sans-serif; font-size: 16px; fill: #333; font-weight: bold; }}
784 </style>
785 </defs>
786"#,
787 width, height
788 );
789
790 if self.config.show_grid {
792 for i in 0..=10 {
793 let x = margin as f64 + (i as f64 / 10.0) * plot_width as f64;
794 svg_content.push_str(&format!(
795 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
796"#,
797 x,
798 margin,
799 x,
800 height - margin
801 ));
802 }
803
804 for i in 0..=10 {
805 let y = margin as f64 + (i as f64 / 10.0) * plot_height as f64;
806 svg_content.push_str(&format!(
807 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
808"#,
809 margin,
810 y,
811 width - margin,
812 y
813 ));
814 }
815 }
816
817 svg_content.push_str(&format!(
819 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
820 <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
821"#,
822 margin,
823 height - margin,
824 width - margin,
825 height - margin,
826 margin,
827 margin,
828 margin,
829 height - margin
830 ));
831
832 svg_content.push_str(" <polyline points=\"");
834 for (x_val, y_val) in x_coords.iter().zip(y_coords.iter()) {
835 let x = margin as f64 + ((x_val - min_x) / (max_x - min_x)) * plot_width as f64;
836 let y = height as f64
837 - margin as f64
838 - ((y_val - min_y) / (max_y - min_y)) * plot_height as f64;
839 svg_content.push_str(&format!("{},{} ", x, y));
840 }
841 svg_content.push_str("\" class=\"trajectory\" />\n");
842
843 if !x_coords.is_empty() {
845 let start_x =
846 margin as f64 + ((x_coords[0] - min_x) / (max_x - min_x)) * plot_width as f64;
847 let start_y = height as f64
848 - margin as f64
849 - ((y_coords[0] - min_y) / (max_y - min_y)) * plot_height as f64;
850
851 let end_x = margin as f64
852 + ((x_coords.last().expect("Operation failed") - min_x) / (max_x - min_x))
853 * plot_width as f64;
854 let end_y = height as f64
855 - margin as f64
856 - ((y_coords.last().expect("Operation failed") - min_y) / (max_y - min_y))
857 * plot_height as f64;
858
859 svg_content.push_str(&format!(
860 r#" <circle cx="{}" cy="{}" r="5" class="start" />
861 <circle cx="{}" cy="{}" r="5" class="end" />
862"#,
863 start_x, start_y, end_x, end_y
864 ));
865 }
866
867 if let Some(ref title) = self.config.title {
869 svg_content.push_str(&format!(
870 r#" <text x="{}" y="30" text-anchor="middle" class="title">{}</text>
871"#,
872 width / 2,
873 title
874 ));
875 }
876
877 svg_content.push_str("</svg>");
878
879 file.write_all(svg_content.as_bytes())?;
880 Ok(())
881 }
882
883 fn plot_trajectory_html(
884 &self,
885 trajectory: &OptimizationTrajectory,
886 output_path: &Path,
887 ) -> ScirsResult<()> {
888 let mut file = File::create(output_path)?;
889
890 let x_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
891 let y_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
892
893 let html_content = format!(
894 r#"<!DOCTYPE html>
895<html>
896<head>
897 <title>Parameter Trajectory</title>
898 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
899</head>
900<body>
901 <div id="trajectory-plot" style="width:{}px;height:{}px;"></div>
902 <script>
903 var trace = {{
904 x: [{}],
905 y: [{}],
906 type: 'scatter',
907 mode: 'lines+markers',
908 name: 'Trajectory',
909 line: {{ color: '#2E86AB', width: 2 }},
910 marker: {{
911 size: [{}],
912 color: [{}],
913 colorscale: 'Viridis',
914 showscale: true
915 }}
916 }};
917
918 var layout = {{
919 title: '{}',
920 xaxis: {{ title: 'Parameter 1' }},
921 yaxis: {{ title: 'Parameter 2' }},
922 showlegend: {}
923 }};
924
925 Plotly.newPlot('trajectory-plot', [trace], layout);
926 </script>
927</body>
928</html>"#,
929 self.config.width,
930 self.config.height,
931 x_coords
932 .iter()
933 .map(|x| x.to_string())
934 .collect::<Vec<_>>()
935 .join(","),
936 y_coords
937 .iter()
938 .map(|y| y.to_string())
939 .collect::<Vec<_>>()
940 .join(","),
941 (0..x_coords.len())
942 .map(|i| if i == 0 {
943 "10"
944 } else if i == x_coords.len() - 1 {
945 "10"
946 } else {
947 "6"
948 })
949 .collect::<Vec<_>>()
950 .join(","),
951 (0..x_coords.len())
952 .map(|i| i.to_string())
953 .collect::<Vec<_>>()
954 .join(","),
955 self.config
956 .title
957 .as_deref()
958 .unwrap_or("Parameter Trajectory"),
959 self.config.show_legend
960 );
961
962 file.write_all(html_content.as_bytes())?;
963 Ok(())
964 }
965
966 fn export_convergence_data(
967 &self,
968 trajectory: &OptimizationTrajectory,
969 output_path: &Path,
970 ) -> ScirsResult<()> {
971 let mut file = File::create(output_path)?;
972
973 let mut header = "iteration,function_value,time".to_string();
975 if !trajectory.gradient_norms.is_empty() {
976 header.push_str(",gradient_norm");
977 }
978 if !trajectory.step_sizes.is_empty() {
979 header.push_str(",step_size");
980 }
981
982 if !trajectory.parameters.is_empty() {
984 for i in 0..trajectory.parameters[0].len() {
985 header.push_str(&format!(",param_{}", i));
986 }
987 }
988
989 for name in trajectory.custom_metrics.keys() {
991 header.push_str(&format!(",{}", name));
992 }
993 header.push('\n');
994
995 file.write_all(header.as_bytes())?;
996
997 for i in 0..trajectory.len() {
999 let mut row = format!(
1000 "{},{},{}",
1001 trajectory.nit[i], trajectory.function_values[i], trajectory.times[i]
1002 );
1003
1004 if i < trajectory.gradient_norms.len() {
1005 row.push_str(&format!(",{}", trajectory.gradient_norms[i]));
1006 } else if !trajectory.gradient_norms.is_empty() {
1007 row.push(',');
1008 }
1009
1010 if i < trajectory.step_sizes.len() {
1011 row.push_str(&format!(",{}", trajectory.step_sizes[i]));
1012 } else if !trajectory.step_sizes.is_empty() {
1013 row.push(',');
1014 }
1015
1016 if i < trajectory.parameters.len() {
1018 for param in trajectory.parameters[i].iter() {
1019 row.push_str(&format!(",{}", param));
1020 }
1021 }
1022
1023 for name in trajectory.custom_metrics.keys() {
1025 if let Some(values) = trajectory.custom_metrics.get(name) {
1026 if i < values.len() {
1027 row.push_str(&format!(",{}", values[i]));
1028 } else {
1029 row.push(',');
1030 }
1031 }
1032 }
1033
1034 row.push('\n');
1035 file.write_all(row.as_bytes())?;
1036 }
1037
1038 Ok(())
1039 }
1040
1041 fn export_trajectory_data(
1042 &self,
1043 trajectory: &OptimizationTrajectory,
1044 output_path: &Path,
1045 ) -> ScirsResult<()> {
1046 self.export_convergence_data(trajectory, output_path)
1047 }
1048
1049 fn plot_convergence_png(
1054 &self,
1055 trajectory: &OptimizationTrajectory,
1056 output_path: &Path,
1057 ) -> ScirsResult<()> {
1058 let width = self.config.width as usize;
1059 let height = self.config.height as usize;
1060 let margin = 60_usize;
1061
1062 let mut pixels = vec![255u8; width * height * 3];
1064
1065 let min_y = trajectory
1066 .function_values
1067 .iter()
1068 .cloned()
1069 .fold(f64::INFINITY, f64::min);
1070 let max_y = trajectory
1071 .function_values
1072 .iter()
1073 .cloned()
1074 .fold(f64::NEG_INFINITY, f64::max);
1075 let y_range = (max_y - min_y).max(f64::EPSILON);
1076 let n = trajectory.function_values.len();
1077
1078 for px in margin..width.saturating_sub(margin) {
1080 let row = height.saturating_sub(margin + 1);
1081 let idx = (row * width + px) * 3;
1082 pixels[idx] = 50;
1083 pixels[idx + 1] = 50;
1084 pixels[idx + 2] = 50;
1085 }
1086 for py in margin..height.saturating_sub(margin) {
1087 let idx = (py * width + margin) * 3;
1088 pixels[idx] = 50;
1089 pixels[idx + 1] = 50;
1090 pixels[idx + 2] = 50;
1091 }
1092
1093 let plot_w = width.saturating_sub(2 * margin);
1095 let plot_h = height.saturating_sub(2 * margin);
1096 if n >= 2 {
1097 for i in 0..n.saturating_sub(1) {
1098 let x0 = margin + i * plot_w / (n - 1);
1099 let x1 = margin + (i + 1) * plot_w / (n - 1);
1100 let v0 = trajectory.function_values[i];
1101 let v1 = trajectory.function_values[i + 1];
1102 let y0 = height
1103 .saturating_sub(margin)
1104 .saturating_sub(((v0 - min_y) / y_range * plot_h as f64) as usize);
1105 let y1 = height
1106 .saturating_sub(margin)
1107 .saturating_sub(((v1 - min_y) / y_range * plot_h as f64) as usize);
1108 png_draw_line(&mut pixels, width, height, x0, y0, x1, y1, 46, 134, 171);
1110 }
1111 }
1112
1113 write_png(output_path, &pixels, width, height)
1114 }
1115
1116 fn plot_trajectory_png(
1118 &self,
1119 trajectory: &OptimizationTrajectory,
1120 output_path: &Path,
1121 ) -> ScirsResult<()> {
1122 let width = self.config.width as usize;
1123 let height = self.config.height as usize;
1124 let margin = 60_usize;
1125
1126 if trajectory.parameters.is_empty() || trajectory.parameters[0].len() != 2 {
1127 return Err(ScirsError::InvalidInput(error_context!(
1128 "Parameter trajectory PNG visualization only supports 2D problems"
1129 )));
1130 }
1131
1132 let xs: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
1133 let ys: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
1134
1135 let min_x = xs.iter().cloned().fold(f64::INFINITY, f64::min);
1136 let max_x = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1137 let min_y = ys.iter().cloned().fold(f64::INFINITY, f64::min);
1138 let max_y = ys.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1139 let x_range = (max_x - min_x).max(f64::EPSILON);
1140 let y_range = (max_y - min_y).max(f64::EPSILON);
1141
1142 let mut pixels = vec![255u8; width * height * 3];
1143
1144 let plot_w = width.saturating_sub(2 * margin);
1145 let plot_h = height.saturating_sub(2 * margin);
1146 let n = xs.len();
1147 for i in 0..n.saturating_sub(1) {
1148 let px0 = margin + ((xs[i] - min_x) / x_range * plot_w as f64) as usize;
1149 let py0 = height
1150 .saturating_sub(margin)
1151 .saturating_sub(((ys[i] - min_y) / y_range * plot_h as f64) as usize);
1152 let px1 = margin + ((xs[i + 1] - min_x) / x_range * plot_w as f64) as usize;
1153 let py1 = height
1154 .saturating_sub(margin)
1155 .saturating_sub(((ys[i + 1] - min_y) / y_range * plot_h as f64) as usize);
1156 png_draw_line(&mut pixels, width, height, px0, py0, px1, py1, 46, 134, 171);
1157 }
1158
1159 write_png(output_path, &pixels, width, height)
1160 }
1161}
1162
1163impl Default for OptimizationVisualizer {
1164 fn default() -> Self {
1165 Self::new()
1166 }
1167}
1168
1169pub mod tracking {
1171 use super::OptimizationTrajectory;
1172 use scirs2_core::ndarray::ArrayView1;
1173 use std::time::Instant;
1174
1175 pub struct TrajectoryTracker {
1177 trajectory: OptimizationTrajectory,
1178 start_time: Instant,
1179 }
1180
1181 impl TrajectoryTracker {
1182 pub fn new() -> Self {
1184 Self {
1185 trajectory: OptimizationTrajectory::new(),
1186 start_time: Instant::now(),
1187 }
1188 }
1189
1190 pub fn record(&mut self, iteration: usize, params: &ArrayView1<f64>, function_value: f64) {
1192 let elapsed = self.start_time.elapsed().as_secs_f64();
1193 self.trajectory
1194 .add_point(iteration, params, function_value, elapsed);
1195 }
1196
1197 pub fn record_gradient_norm(&mut self, grad_norm: f64) {
1199 self.trajectory.add_gradient_norm(grad_norm);
1200 }
1201
1202 pub fn record_step_size(&mut self, step_size: f64) {
1204 self.trajectory.add_step_size(step_size);
1205 }
1206
1207 pub fn record_custom_metric(&mut self, name: &str, value: f64) {
1209 self.trajectory.add_custom_metric(name, value);
1210 }
1211
1212 pub fn trajectory(&self) -> &OptimizationTrajectory {
1214 &self.trajectory
1215 }
1216
1217 pub fn into_trajectory(self) -> OptimizationTrajectory {
1219 self.trajectory
1220 }
1221 }
1222
1223 impl Default for TrajectoryTracker {
1224 fn default() -> Self {
1225 Self::new()
1226 }
1227 }
1228}
1229
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Per-process temp path, so concurrent test runs do not clobber each other's files.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{}_{}", std::process::id(), name))
    }

    #[test]
    fn test_trajectory_creation() {
        let mut trajectory = OptimizationTrajectory::new();
        assert!(trajectory.is_empty());

        let params = array![1.0, 2.0];
        trajectory.add_point(0, &params.view(), 5.0, 0.1);

        assert_eq!(trajectory.len(), 1);
        assert_eq!(trajectory.final_function_value(), Some(5.0));
    }

    #[test]
    fn test_convergence_rate_calculation() {
        let mut trajectory = OptimizationTrajectory::new();

        // Geometric decrease with ratio 0.5.
        let function_values = vec![10.0, 5.0, 2.5, 1.25, 0.625];
        for (i, &f_val) in function_values.iter().enumerate() {
            let params = array![i as f64, i as f64];
            trajectory.add_point(i, &params.view(), f_val, i as f64 * 0.1);
        }

        let rate = trajectory.convergence_rate();
        assert!(rate.is_some());
        assert!((rate.expect("Operation failed") - 0.5).abs() < 0.1);
    }

    #[test]
    fn test_visualization_config() {
        let config = VisualizationConfig {
            format: OutputFormat::Svg,
            width: 1000,
            height: 800,
            title: Some("Test Plot".to_string()),
            show_grid: true,
            log_scale_y: true,
            color_scheme: ColorScheme::Viridis,
            show_legend: false,
            custom_style: None,
        };

        let visualizer = OptimizationVisualizer::with_config(config);
        assert_eq!(visualizer.config.width, 1000);
        assert_eq!(visualizer.config.height, 800);
    }

    #[test]
    fn test_trajectory_tracker() {
        let mut tracker = tracking::TrajectoryTracker::new();

        let params1 = array![0.0, 0.0];
        let params2 = array![1.0, 1.0];

        tracker.record(0, &params1.view(), 10.0);
        tracker.record_gradient_norm(2.5);
        tracker.record_step_size(0.1);

        tracker.record(1, &params2.view(), 5.0);
        tracker.record_gradient_norm(1.5);
        tracker.record_step_size(0.2);

        let trajectory = tracker.trajectory();
        assert_eq!(trajectory.len(), 2);
        assert_eq!(trajectory.gradient_norms.len(), 2);
        assert_eq!(trajectory.step_sizes.len(), 2);
        assert_eq!(trajectory.final_function_value(), Some(5.0));
    }

    #[test]
    fn test_png_convergence_output_valid_file() {
        let mut trajectory = OptimizationTrajectory::new();
        let params = array![0.0f64, 0.0];
        for i in 0..10 {
            let val = 100.0 / (i as f64 + 1.0);
            trajectory.add_point(i, &params.view(), val, i as f64 * 0.01);
        }

        let out_path = temp_path("test_convergence.png");

        let config = VisualizationConfig {
            format: OutputFormat::Png,
            width: 100,
            height: 80,
            title: None,
            show_grid: false,
            log_scale_y: false,
            color_scheme: ColorScheme::Default,
            show_legend: false,
            custom_style: None,
        };
        let vis = OptimizationVisualizer::with_config(config);
        vis.plot_convergence(&trajectory, &out_path)
            .expect("PNG write failed");

        let data = std::fs::read(&out_path).expect("PNG read failed");
        assert!(data.len() > 8, "PNG too small");
        assert_eq!(
            &data[0..8],
            &[137, 80, 78, 71, 13, 10, 26, 10],
            "Invalid PNG signature"
        );

        let _ = std::fs::remove_file(&out_path);
    }

    #[test]
    fn test_png_trajectory_output_valid_file() {
        let mut trajectory = OptimizationTrajectory::new();
        for i in 0..5 {
            let params = array![i as f64 * 0.1, i as f64 * 0.2];
            trajectory.add_point(i, &params.view(), 10.0 - i as f64, i as f64 * 0.01);
        }

        let out_path = temp_path("test_trajectory.png");

        let config = VisualizationConfig {
            format: OutputFormat::Png,
            width: 80,
            height: 60,
            title: None,
            show_grid: false,
            log_scale_y: false,
            color_scheme: ColorScheme::Default,
            show_legend: false,
            custom_style: None,
        };
        let vis = OptimizationVisualizer::with_config(config);
        vis.plot_parameter_trajectory(&trajectory, &out_path)
            .expect("PNG traj write failed");

        let data = std::fs::read(&out_path).expect("PNG traj read failed");
        assert!(data.len() > 8, "PNG trajectory too small");
        assert_eq!(
            &data[0..8],
            &[137, 80, 78, 71, 13, 10, 26, 10],
            "Invalid PNG signature"
        );

        let _ = std::fs::remove_file(&out_path);
    }
}