scirs2_optimize/
visualization.rs

//! Visualization tools for optimization trajectories and analysis
//!
//! This module provides visualization capabilities for optimization
//! processes, including convergence plots, two-dimensional parameter-trajectory
//! plots, HTML summary reports, and CSV export of the recorded data.
6
7use crate::error::{ScirsError, ScirsResult};
8use ndarray::{Array1, ArrayView1}; // Unused import: Array2, ArrayView2
9use scirs2_core::error_context;
10use std::collections::HashMap;
11use std::fs::File;
12use std::io::Write;
13use std::path::Path;
14
15/// Trajectory data collected during optimization
16#[derive(Debug, Clone)]
17pub struct OptimizationTrajectory {
18    /// Parameter values at each iteration
19    pub parameters: Vec<Array1<f64>>,
20    /// Function values at each iteration
21    pub function_values: Vec<f64>,
22    /// Gradient norms at each iteration (if available)
23    pub gradient_norms: Vec<f64>,
24    /// Step sizes at each iteration (if available)
25    pub step_sizes: Vec<f64>,
26    /// Custom metrics at each iteration
27    pub custom_metrics: HashMap<String, Vec<f64>>,
28    /// Iteration numbers
29    pub nit: Vec<usize>,
30    /// Wall clock times (in seconds from start)
31    pub times: Vec<f64>,
32}
33
34impl OptimizationTrajectory {
35    /// Create a new empty trajectory
36    pub fn new() -> Self {
37        Self {
38            parameters: Vec::new(),
39            function_values: Vec::new(),
40            gradient_norms: Vec::new(),
41            step_sizes: Vec::new(),
42            custom_metrics: HashMap::new(),
43            nit: Vec::new(),
44            times: Vec::new(),
45        }
46    }
47
48    /// Add a new point to the trajectory
49    pub fn add_point(
50        &mut self,
51        iteration: usize,
52        params: &ArrayView1<f64>,
53        function_value: f64,
54        time: f64,
55    ) {
56        self.nit.push(iteration);
57        self.parameters.push(params.to_owned());
58        self.function_values.push(function_value);
59        self.times.push(time);
60    }
61
62    /// Add gradient norm information
63    pub fn add_gradient_norm(&mut self, grad_norm: f64) {
64        self.gradient_norms.push(grad_norm);
65    }
66
67    /// Add step size information
68    pub fn add_step_size(&mut self, step_size: f64) {
69        self.step_sizes.push(step_size);
70    }
71
72    /// Add custom metric
73    pub fn add_custom_metric(&mut self, name: &str, value: f64) {
74        self.custom_metrics
75            .entry(name.to_string())
76            .or_insert_with(Vec::new)
77            .push(value);
78    }
79
80    /// Get the number of recorded points
81    pub fn len(&self) -> usize {
82        self.nit.len()
83    }
84
85    /// Check if trajectory is empty
86    pub fn is_empty(&self) -> bool {
87        self.nit.is_empty()
88    }
89
90    /// Get the final parameter values
91    pub fn final_parameters(&self) -> Option<&Array1<f64>> {
92        self.parameters.last()
93    }
94
95    /// Get the final function value
96    pub fn final_function_value(&self) -> Option<f64> {
97        self.function_values.last().copied()
98    }
99
100    /// Calculate convergence rate (linear convergence coefficient)
101    pub fn convergence_rate(&self) -> Option<f64> {
102        if self.function_values.len() < 3 {
103            return None;
104        }
105
106        let n = self.function_values.len();
107        let mut rates = Vec::new();
108
109        for i in 1..(n - 1) {
110            let f_current = self.function_values[i];
111            let f_next = self.function_values[i + 1];
112            let f_prev = self.function_values[i - 1];
113
114            if (f_current - f_next).abs() > 1e-14 && (f_prev - f_current).abs() > 1e-14 {
115                let rate = (f_current - f_next).abs() / (f_prev - f_current).abs();
116                if rate.is_finite() && rate > 0.0 {
117                    rates.push(rate);
118                }
119            }
120        }
121
122        if rates.is_empty() {
123            None
124        } else {
125            Some(rates.iter().sum::<f64>() / rates.len() as f64)
126        }
127    }
128}
129
130impl Default for OptimizationTrajectory {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136/// Configuration for trajectory visualization
137#[derive(Debug, Clone)]
138pub struct VisualizationConfig {
139    /// Output format (svg, png, html)
140    pub format: OutputFormat,
141    /// Width of the plot in pixels
142    pub width: u32,
143    /// Height of the plot in pixels
144    pub height: u32,
145    /// Title for the plot
146    pub title: Option<String>,
147    /// Whether to show grid
148    pub show_grid: bool,
149    /// Whether to use logarithmic scale for y-axis
150    pub log_scale_y: bool,
151    /// Color scheme
152    pub color_scheme: ColorScheme,
153    /// Whether to show legend
154    pub show_legend: bool,
155    /// Custom styling
156    pub custom_style: Option<String>,
157}
158
159impl Default for VisualizationConfig {
160    fn default() -> Self {
161        Self {
162            format: OutputFormat::Svg,
163            width: 800,
164            height: 600,
165            title: None,
166            show_grid: true,
167            log_scale_y: false,
168            color_scheme: ColorScheme::Default,
169            show_legend: true,
170            custom_style: None,
171        }
172    }
173}
174
/// Supported output formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputFormat {
    /// Standalone Scalable Vector Graphics file written directly by this module.
    Svg,
    /// Raster PNG image. Not yet implemented: plotting requests return an error.
    Png,
    /// HTML page that renders an interactive Plotly chart (loaded from a CDN).
    Html,
    /// Raw data output, written as CSV.
    Data,
}
183
/// Color schemes for visualization.
///
/// NOTE(review): the renderers in this module currently use fixed colors and do
/// not consult this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ColorScheme {
    /// Default palette.
    Default,
    /// Viridis palette (presumably the colormap of the same name).
    Viridis,
    /// Plasma palette (presumably the colormap of the same name).
    Plasma,
    /// Palette intended for scientific figures (assumed from the name).
    Scientific,
    /// Single-color palette.
    Monochrome,
}
193
194/// Main visualization interface
195pub struct OptimizationVisualizer {
196    config: VisualizationConfig,
197}
198
199impl OptimizationVisualizer {
200    /// Create a new visualizer with default configuration
201    pub fn new() -> Self {
202        Self {
203            config: VisualizationConfig::default(),
204        }
205    }
206
207    /// Create a new visualizer with custom configuration
208    pub fn with_config(config: VisualizationConfig) -> Self {
209        Self { config }
210    }
211
212    /// Plot convergence curve (function value vs iteration)
213    pub fn plot_convergence(
214        &self,
215        trajectory: &OptimizationTrajectory,
216        output_path: &Path,
217    ) -> ScirsResult<()> {
218        if trajectory.is_empty() {
219            return Err(ScirsError::InvalidInput(error_context!("Empty trajectory")));
220        }
221
222        match self.config.format {
223            OutputFormat::Svg => self.plot_convergence_svg(trajectory, output_path),
224            OutputFormat::Html => self.plot_convergence_html(trajectory, output_path),
225            OutputFormat::Data => self.export_convergence_data(trajectory, output_path),
226            _ => Err(ScirsError::NotImplementedError(error_context!(
227                "PNG output not yet implemented"
228            ))),
229        }
230    }
231
232    /// Plot parameter trajectory (for 2D problems)
233    pub fn plot_parameter_trajectory(
234        &self,
235        trajectory: &OptimizationTrajectory,
236        output_path: &Path,
237    ) -> ScirsResult<()> {
238        if trajectory.is_empty() {
239            return Err(ScirsError::InvalidInput(error_context!("Empty trajectory")));
240        }
241
242        if trajectory.parameters[0].len() != 2 {
243            return Err(ScirsError::InvalidInput(error_context!(
244                "Parameter trajectory visualization only supports 2D problems"
245            )));
246        }
247
248        match self.config.format {
249            OutputFormat::Svg => self.plot_trajectory_svg(trajectory, output_path),
250            OutputFormat::Html => self.plot_trajectory_html(trajectory, output_path),
251            OutputFormat::Data => self.export_trajectory_data(trajectory, output_path),
252            _ => Err(ScirsError::NotImplementedError(error_context!(
253                "PNG output not yet implemented"
254            ))),
255        }
256    }
257
258    /// Create a comprehensive optimization report
259    pub fn create_optimization_report(
260        &self,
261        trajectory: &OptimizationTrajectory,
262        output_dir: &Path,
263    ) -> ScirsResult<()> {
264        std::fs::create_dir_all(output_dir)?;
265
266        // Generate convergence plot
267        let convergence_path = output_dir.join("convergence.svg");
268        self.plot_convergence(trajectory, &convergence_path)?;
269
270        // Generate parameter trajectory if 2D
271        if !trajectory.parameters.is_empty() && trajectory.parameters[0].len() == 2 {
272            let trajectory_path = output_dir.join("trajectory.svg");
273            self.plot_parameter_trajectory(trajectory, &trajectory_path)?;
274        }
275
276        // Generate summary statistics
277        let summary_path = output_dir.join("summary.html");
278        self.generate_summary_report(trajectory, &summary_path)?;
279
280        // Export raw data
281        let data_path = output_dir.join("data.csv");
282        self.export_convergence_data(trajectory, &data_path)?;
283
284        Ok(())
285    }
286
287    /// Generate summary statistics report
288    fn generate_summary_report(
289        &self,
290        trajectory: &OptimizationTrajectory,
291        output_path: &Path,
292    ) -> ScirsResult<()> {
293        let mut file = File::create(output_path)?;
294
295        let html_content = format!(
296            r#"<!DOCTYPE html>
297<html>
298<head>
299    <title>Optimization Summary</title>
300    <style>
301        body {{ font-family: Arial, sans-serif; margin: 20px; }}
302        .metric {{ margin: 10px 0; }}
303        .value {{ font-weight: bold; color: #2E86AB; }}
304        table {{ border-collapse: collapse; width: 100%; }}
305        th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
306        th {{ background-color: #f2f2f2; }}
307    </style>
308</head>
309<body>
310    <h1>Optimization Summary Report</h1>
311    
312    <h2>Basic Statistics</h2>
313    <div class="metric">Total Iterations: <span class="value">{}</span></div>
314    <div class="metric">Final Function Value: <span class="value">{:.6e}</span></div>
315    <div class="metric">Initial Function Value: <span class="value">{:.6e}</span></div>
316    <div class="metric">Function Improvement: <span class="value">{:.6e}</span></div>
317    <div class="metric">Total Runtime: <span class="value">{:.3}s</span></div>
318    {}
319    
320    <h2>Convergence Analysis</h2>
321    <table>
322        <tr><th>Metric</th><th>Value</th></tr>
323        <tr><td>Convergence Rate</td><td>{}</td></tr>
324        <tr><td>Average Iteration Time</td><td>{:.6}s</td></tr>
325        <tr><td>Function Evaluations per Second</td><td>{:.2}</td></tr>
326    </table>
327    
328    {}
329</body>
330</html>"#,
331            trajectory.len(),
332            trajectory.final_function_value().unwrap_or(0.0),
333            trajectory.function_values.first().cloned().unwrap_or(0.0),
334            trajectory.function_values.first().cloned().unwrap_or(0.0)
335                - trajectory.final_function_value().unwrap_or(0.0),
336            trajectory.times.last().cloned().unwrap_or(0.0),
337            if !trajectory.gradient_norms.is_empty() {
338                format!("<div class=\"metric\">Final Gradient Norm: <span class=\"value\">{:.6e}</span></div>",
339                       trajectory.gradient_norms.last().cloned().unwrap_or(0.0))
340            } else {
341                String::new()
342            },
343            trajectory
344                .convergence_rate()
345                .map(|r| format!("{:.6}", r))
346                .unwrap_or_else(|| "N/A".to_string()),
347            if trajectory.len() > 1 && !trajectory.times.is_empty() {
348                trajectory.times.last().cloned().unwrap_or(0.0) / trajectory.len() as f64
349            } else {
350                0.0
351            },
352            if !trajectory.times.is_empty() && trajectory.times.last().cloned().unwrap_or(0.0) > 0.0
353            {
354                trajectory.len() as f64 / trajectory.times.last().cloned().unwrap_or(1.0)
355            } else {
356                0.0
357            },
358            self.generate_custom_metrics_table(trajectory)
359        );
360
361        file.write_all(html_content.as_bytes())?;
362        Ok(())
363    }
364
365    fn generate_custom_metrics_table(&self, trajectory: &OptimizationTrajectory) -> String {
366        if trajectory.custom_metrics.is_empty() {
367            return String::new();
368        }
369
370        let mut table = String::from("<h2>Custom Metrics</h2>\n<table>\n<tr><th>Metric</th><th>Final Value</th><th>Min</th><th>Max</th><th>Mean</th></tr>\n");
371
372        for (name, values) in &trajectory.custom_metrics {
373            if let Some(final_val) = values.last() {
374                let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
375                let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
376                let mean_val = values.iter().sum::<f64>() / values.len() as f64;
377
378                table.push_str(&format!(
379                    "<tr><td>{}</td><td>{:.6e}</td><td>{:.6e}</td><td>{:.6e}</td><td>{:.6e}</td></tr>\n",
380                    name, final_val, min_val, max_val, mean_val
381                ));
382            }
383        }
384        table.push_str("</table>\n");
385        table
386    }
387
388    fn plot_convergence_svg(
389        &self,
390        trajectory: &OptimizationTrajectory,
391        output_path: &Path,
392    ) -> ScirsResult<()> {
393        let mut file = File::create(output_path)?;
394
395        let width = self.config.width;
396        let height = self.config.height;
397        let margin = 60;
398        let plot_width = width - 2 * margin;
399        let plot_height = height - 2 * margin;
400
401        let min_y = if self.config.log_scale_y {
402            trajectory
403                .function_values
404                .iter()
405                .filter(|&&v| v > 0.0)
406                .cloned()
407                .fold(f64::INFINITY, f64::min)
408                .ln()
409        } else {
410            trajectory
411                .function_values
412                .iter()
413                .cloned()
414                .fold(f64::INFINITY, f64::min)
415        };
416
417        let max_y = if self.config.log_scale_y {
418            trajectory
419                .function_values
420                .iter()
421                .filter(|&&v| v > 0.0)
422                .cloned()
423                .fold(f64::NEG_INFINITY, f64::max)
424                .ln()
425        } else {
426            trajectory
427                .function_values
428                .iter()
429                .cloned()
430                .fold(f64::NEG_INFINITY, f64::max)
431        };
432
433        let max_x = trajectory.nit.len() as f64;
434
435        let mut svg_content = format!(
436            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
437    <defs>
438        <style>
439            .axis {{ stroke: #333; stroke-width: 1; }}
440            .grid {{ stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }}
441            .line {{ fill: none; stroke: #2E86AB; stroke-width: 2; }}
442            .text {{ font-family: Arial, sans-serif; font-size: 12px; fill: #333; }}
443            .title {{ font-family: Arial, sans-serif; font-size: 16px; fill: #333; font-weight: bold; }}
444        </style>
445    </defs>
446"#,
447            width, height
448        );
449
450        // Grid lines
451        if self.config.show_grid {
452            for i in 0..=10 {
453                let x = margin as f64 + (i as f64 / 10.0) * plot_width as f64;
454                svg_content.push_str(&format!(
455                    r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
456"#,
457                    x,
458                    margin,
459                    x,
460                    height - margin
461                ));
462            }
463
464            for i in 0..=10 {
465                let y = margin as f64 + (i as f64 / 10.0) * plot_height as f64;
466                svg_content.push_str(&format!(
467                    r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
468"#,
469                    margin,
470                    y,
471                    width - margin,
472                    y
473                ));
474            }
475        }
476
477        // Axes
478        svg_content.push_str(&format!(
479            r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
480    <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
481"#,
482            margin,
483            height - margin,
484            width - margin,
485            height - margin, // x-axis
486            margin,
487            margin,
488            margin,
489            height - margin // y-axis
490        ));
491
492        // Plot line
493        svg_content.push_str("    <polyline points=\"");
494        for (i, &f_val) in trajectory.function_values.iter().enumerate() {
495            let x = margin as f64 + (i as f64 / max_x) * plot_width as f64;
496            let y_val = if self.config.log_scale_y && f_val > 0.0 {
497                f_val.ln()
498            } else {
499                f_val
500            };
501            let y = height as f64
502                - margin as f64
503                - ((y_val - min_y) / (max_y - min_y)) * plot_height as f64;
504            svg_content.push_str(&format!("{},{} ", x, y));
505        }
506        svg_content.push_str("\" class=\"line\" />\n");
507
508        // Title
509        if let Some(ref title) = self.config.title {
510            svg_content.push_str(&format!(
511                r#"    <text x="{}" y="30" text-anchor="middle" class="title">{}</text>
512"#,
513                width / 2,
514                title
515            ));
516        }
517
518        // Labels
519        svg_content.push_str(&format!(
520            r#"    <text x="{}" y="{}" text-anchor="middle" class="text">Iteration</text>
521    <text x="20" y="{}" text-anchor="middle" class="text" transform="rotate(-90 20 {})">Function Value{}</text>
522"#,
523            width / 2, height - 10,
524            height / 2, height / 2,
525            if self.config.log_scale_y { " (log)" } else { "" }
526        ));
527
528        svg_content.push_str("</svg>");
529
530        file.write_all(svg_content.as_bytes())?;
531        Ok(())
532    }
533
534    fn plot_convergence_html(
535        &self,
536        trajectory: &OptimizationTrajectory,
537        output_path: &Path,
538    ) -> ScirsResult<()> {
539        let mut file = File::create(output_path)?;
540
541        let html_content = format!(
542            r#"<!DOCTYPE html>
543<html>
544<head>
545    <title>Optimization Convergence</title>
546    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
547</head>
548<body>
549    <div id="convergence-plot" style="width:{}px;height:{}px;"></div>
550    <script>
551        var trace = {{
552            x: [{}],
553            y: [{}],
554            type: 'scatter',
555            mode: 'lines',
556            name: 'Function Value',
557            line: {{ color: '#2E86AB', width: 2 }}
558        }};
559        
560        var layout = {{
561            title: '{}',
562            xaxis: {{ title: 'Iteration' }},
563            yaxis: {{ 
564                title: 'Function Value',
565                type: '{}'
566            }},
567            showlegend: {}
568        }};
569        
570        Plotly.newPlot('convergence-plot', [trace], layout);
571    </script>
572</body>
573</html>"#,
574            self.config.width,
575            self.config.height,
576            trajectory
577                .nit
578                .iter()
579                .map(|i| i.to_string())
580                .collect::<Vec<_>>()
581                .join(","),
582            trajectory
583                .function_values
584                .iter()
585                .map(|f| f.to_string())
586                .collect::<Vec<_>>()
587                .join(","),
588            self.config
589                .title
590                .as_deref()
591                .unwrap_or("Optimization Convergence"),
592            if self.config.log_scale_y {
593                "log"
594            } else {
595                "linear"
596            },
597            self.config.show_legend
598        );
599
600        file.write_all(html_content.as_bytes())?;
601        Ok(())
602    }
603
604    fn plot_trajectory_svg(
605        &self,
606        trajectory: &OptimizationTrajectory,
607        output_path: &Path,
608    ) -> ScirsResult<()> {
609        let mut file = File::create(output_path)?;
610
611        let width = self.config.width;
612        let height = self.config.height;
613        let margin = 60;
614        let plot_width = width - 2 * margin;
615        let plot_height = height - 2 * margin;
616
617        let x_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
618        let y_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
619
620        let min_x = x_coords.iter().cloned().fold(f64::INFINITY, f64::min);
621        let max_x = x_coords.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
622        let min_y = y_coords.iter().cloned().fold(f64::INFINITY, f64::min);
623        let max_y = y_coords.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
624
625        let mut svg_content = format!(
626            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
627    <defs>
628        <style>
629            .axis {{ stroke: #333; stroke-width: 1; }}
630            .grid {{ stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }}
631            .trajectory {{ fill: none; stroke: #2E86AB; stroke-width: 2; }}
632            .start {{ fill: #4CAF50; stroke: #333; stroke-width: 1; }}
633            .end {{ fill: #F44336; stroke: #333; stroke-width: 1; }}
634            .text {{ font-family: Arial, sans-serif; font-size: 12px; fill: #333; }}
635            .title {{ font-family: Arial, sans-serif; font-size: 16px; fill: #333; font-weight: bold; }}
636        </style>
637    </defs>
638"#,
639            width, height
640        );
641
642        // Grid
643        if self.config.show_grid {
644            for i in 0..=10 {
645                let x = margin as f64 + (i as f64 / 10.0) * plot_width as f64;
646                svg_content.push_str(&format!(
647                    r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
648"#,
649                    x,
650                    margin,
651                    x,
652                    height - margin
653                ));
654            }
655
656            for i in 0..=10 {
657                let y = margin as f64 + (i as f64 / 10.0) * plot_height as f64;
658                svg_content.push_str(&format!(
659                    r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
660"#,
661                    margin,
662                    y,
663                    width - margin,
664                    y
665                ));
666            }
667        }
668
669        // Axes
670        svg_content.push_str(&format!(
671            r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
672    <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
673"#,
674            margin,
675            height - margin,
676            width - margin,
677            height - margin,
678            margin,
679            margin,
680            margin,
681            height - margin
682        ));
683
684        // Trajectory
685        svg_content.push_str("    <polyline points=\"");
686        for (x_val, y_val) in x_coords.iter().zip(y_coords.iter()) {
687            let x = margin as f64 + ((x_val - min_x) / (max_x - min_x)) * plot_width as f64;
688            let y = height as f64
689                - margin as f64
690                - ((y_val - min_y) / (max_y - min_y)) * plot_height as f64;
691            svg_content.push_str(&format!("{},{} ", x, y));
692        }
693        svg_content.push_str("\" class=\"trajectory\" />\n");
694
695        // Start and end points
696        if !x_coords.is_empty() {
697            let start_x =
698                margin as f64 + ((x_coords[0] - min_x) / (max_x - min_x)) * plot_width as f64;
699            let start_y = height as f64
700                - margin as f64
701                - ((y_coords[0] - min_y) / (max_y - min_y)) * plot_height as f64;
702
703            let end_x = margin as f64
704                + ((x_coords.last().unwrap() - min_x) / (max_x - min_x)) * plot_width as f64;
705            let end_y = height as f64
706                - margin as f64
707                - ((y_coords.last().unwrap() - min_y) / (max_y - min_y)) * plot_height as f64;
708
709            svg_content.push_str(&format!(
710                r#"    <circle cx="{}" cy="{}" r="5" class="start" />
711    <circle cx="{}" cy="{}" r="5" class="end" />
712"#,
713                start_x, start_y, end_x, end_y
714            ));
715        }
716
717        // Title
718        if let Some(ref title) = self.config.title {
719            svg_content.push_str(&format!(
720                r#"    <text x="{}" y="30" text-anchor="middle" class="title">{}</text>
721"#,
722                width / 2,
723                title
724            ));
725        }
726
727        svg_content.push_str("</svg>");
728
729        file.write_all(svg_content.as_bytes())?;
730        Ok(())
731    }
732
733    fn plot_trajectory_html(
734        &self,
735        trajectory: &OptimizationTrajectory,
736        output_path: &Path,
737    ) -> ScirsResult<()> {
738        let mut file = File::create(output_path)?;
739
740        let x_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
741        let y_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
742
743        let html_content = format!(
744            r#"<!DOCTYPE html>
745<html>
746<head>
747    <title>Parameter Trajectory</title>
748    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
749</head>
750<body>
751    <div id="trajectory-plot" style="width:{}px;height:{}px;"></div>
752    <script>
753        var trace = {{
754            x: [{}],
755            y: [{}],
756            type: 'scatter',
757            mode: 'lines+markers',
758            name: 'Trajectory',
759            line: {{ color: '#2E86AB', width: 2 }},
760            marker: {{ 
761                size: [{}],
762                color: [{}],
763                colorscale: 'Viridis',
764                showscale: true
765            }}
766        }};
767        
768        var layout = {{
769            title: '{}',
770            xaxis: {{ title: 'Parameter 1' }},
771            yaxis: {{ title: 'Parameter 2' }},
772            showlegend: {}
773        }};
774        
775        Plotly.newPlot('trajectory-plot', [trace], layout);
776    </script>
777</body>
778</html>"#,
779            self.config.width,
780            self.config.height,
781            x_coords
782                .iter()
783                .map(|x| x.to_string())
784                .collect::<Vec<_>>()
785                .join(","),
786            y_coords
787                .iter()
788                .map(|y| y.to_string())
789                .collect::<Vec<_>>()
790                .join(","),
791            (0..x_coords.len())
792                .map(|i| if i == 0 {
793                    "10"
794                } else if i == x_coords.len() - 1 {
795                    "10"
796                } else {
797                    "6"
798                })
799                .collect::<Vec<_>>()
800                .join(","),
801            (0..x_coords.len())
802                .map(|i| i.to_string())
803                .collect::<Vec<_>>()
804                .join(","),
805            self.config
806                .title
807                .as_deref()
808                .unwrap_or("Parameter Trajectory"),
809            self.config.show_legend
810        );
811
812        file.write_all(html_content.as_bytes())?;
813        Ok(())
814    }
815
816    fn export_convergence_data(
817        &self,
818        trajectory: &OptimizationTrajectory,
819        output_path: &Path,
820    ) -> ScirsResult<()> {
821        let mut file = File::create(output_path)?;
822
823        // CSV header
824        let mut header = "iteration,function_value,time".to_string();
825        if !trajectory.gradient_norms.is_empty() {
826            header.push_str(",gradient_norm");
827        }
828        if !trajectory.step_sizes.is_empty() {
829            header.push_str(",step_size");
830        }
831
832        // Add parameter columns
833        if !trajectory.parameters.is_empty() {
834            for i in 0..trajectory.parameters[0].len() {
835                header.push_str(&format!(",param_{}", i));
836            }
837        }
838
839        // Add custom metrics
840        for name in trajectory.custom_metrics.keys() {
841            header.push_str(&format!(",{}", name));
842        }
843        header.push('\n');
844
845        file.write_all(header.as_bytes())?;
846
847        // Data rows
848        for i in 0..trajectory.len() {
849            let mut row = format!(
850                "{},{},{}",
851                trajectory.nit[i], trajectory.function_values[i], trajectory.times[i]
852            );
853
854            if i < trajectory.gradient_norms.len() {
855                row.push_str(&format!(",{}", trajectory.gradient_norms[i]));
856            } else if !trajectory.gradient_norms.is_empty() {
857                row.push_str(",");
858            }
859
860            if i < trajectory.step_sizes.len() {
861                row.push_str(&format!(",{}", trajectory.step_sizes[i]));
862            } else if !trajectory.step_sizes.is_empty() {
863                row.push_str(",");
864            }
865
866            // Parameters
867            if i < trajectory.parameters.len() {
868                for param in trajectory.parameters[i].iter() {
869                    row.push_str(&format!(",{}", param));
870                }
871            }
872
873            // Custom metrics
874            for name in trajectory.custom_metrics.keys() {
875                if let Some(values) = trajectory.custom_metrics.get(name) {
876                    if i < values.len() {
877                        row.push_str(&format!(",{}", values[i]));
878                    } else {
879                        row.push_str(",");
880                    }
881                }
882            }
883
884            row.push('\n');
885            file.write_all(row.as_bytes())?;
886        }
887
888        Ok(())
889    }
890
891    fn export_trajectory_data(
892        &self,
893        trajectory: &OptimizationTrajectory,
894        output_path: &Path,
895    ) -> ScirsResult<()> {
896        self.export_convergence_data(trajectory, output_path)
897    }
898}
899
900impl Default for OptimizationVisualizer {
901    fn default() -> Self {
902        Self::new()
903    }
904}
905
906/// Utility functions for creating trajectory trackers
907pub mod tracking {
908    use super::OptimizationTrajectory;
909    use ndarray::ArrayView1;
910    use std::time::Instant;
911
912    /// A callback-based trajectory tracker for use with optimization algorithms
913    pub struct TrajectoryTracker {
914        trajectory: OptimizationTrajectory,
915        start_time: Instant,
916    }
917
918    impl TrajectoryTracker {
919        /// Create a new trajectory tracker
920        pub fn new() -> Self {
921            Self {
922                trajectory: OptimizationTrajectory::new(),
923                start_time: Instant::now(),
924            }
925        }
926
927        /// Record a new point in the optimization trajectory
928        pub fn record(&mut self, iteration: usize, params: &ArrayView1<f64>, function_value: f64) {
929            let elapsed = self.start_time.elapsed().as_secs_f64();
930            self.trajectory
931                .add_point(iteration, params, function_value, elapsed);
932        }
933
934        /// Record gradient norm
935        pub fn record_gradient_norm(&mut self, grad_norm: f64) {
936            self.trajectory.add_gradient_norm(grad_norm);
937        }
938
939        /// Record step size
940        pub fn record_step_size(&mut self, step_size: f64) {
941            self.trajectory.add_step_size(step_size);
942        }
943
944        /// Record custom metric
945        pub fn record_custom_metric(&mut self, name: &str, value: f64) {
946            self.trajectory.add_custom_metric(name, value);
947        }
948
949        /// Get the recorded trajectory
950        pub fn trajectory(&self) -> &OptimizationTrajectory {
951            &self.trajectory
952        }
953
954        /// Consume the tracker and return the trajectory
955        pub fn into_trajectory(self) -> OptimizationTrajectory {
956            self.trajectory
957        }
958    }
959
960    impl Default for TrajectoryTracker {
961        fn default() -> Self {
962            Self::new()
963        }
964    }
965}
966
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_trajectory_creation() {
        let mut trajectory = OptimizationTrajectory::new();
        assert!(trajectory.is_empty());

        let params = array![1.0, 2.0];
        trajectory.add_point(0, &params.view(), 5.0, 0.1);

        assert_eq!(trajectory.len(), 1);
        assert_eq!(trajectory.final_function_value(), Some(5.0));
    }

    #[test]
    fn test_convergence_rate_calculation() {
        let mut trajectory = OptimizationTrajectory::new();

        // Geometric sequence: each function value is half the previous one.
        // A fixed-size array is enough here; no heap-allocated Vec is needed.
        let function_values = [10.0, 5.0, 2.5, 1.25, 0.625];
        for (i, &f_val) in function_values.iter().enumerate() {
            let params = array![i as f64, i as f64];
            trajectory.add_point(i, &params.view(), f_val, i as f64 * 0.1);
        }

        let rate = trajectory
            .convergence_rate()
            .expect("convergence rate should be available for a decreasing sequence");
        // Should be approximately 0.5 for this geometric sequence
        assert!(
            (rate - 0.5).abs() < 0.1,
            "unexpected convergence rate: {}",
            rate
        );
    }

    #[test]
    fn test_visualization_config() {
        let config = VisualizationConfig {
            format: OutputFormat::Svg,
            width: 1000,
            height: 800,
            title: Some("Test Plot".to_string()),
            show_grid: true,
            log_scale_y: true,
            color_scheme: ColorScheme::Viridis,
            show_legend: false,
            custom_style: None,
        };

        let visualizer = OptimizationVisualizer::with_config(config);
        assert_eq!(visualizer.config.width, 1000);
        assert_eq!(visualizer.config.height, 800);
    }

    #[test]
    fn test_trajectory_tracker() {
        let mut tracker = tracking::TrajectoryTracker::new();

        let params1 = array![0.0, 0.0];
        let params2 = array![1.0, 1.0];

        tracker.record(0, &params1.view(), 10.0);
        tracker.record_gradient_norm(2.5);
        tracker.record_step_size(0.1);

        tracker.record(1, &params2.view(), 5.0);
        tracker.record_gradient_norm(1.5);
        tracker.record_step_size(0.2);

        let trajectory = tracker.trajectory();
        assert_eq!(trajectory.len(), 2);
        assert_eq!(trajectory.gradient_norms.len(), 2);
        assert_eq!(trajectory.step_sizes.len(), 2);
        assert_eq!(trajectory.final_function_value(), Some(5.0));
    }
}