Skip to main content

scirs2_optimize/
visualization.rs

1//! Visualization tools for optimization trajectories and analysis
2//!
3//! This module provides comprehensive visualization capabilities for optimization
4//! processes, including trajectory plotting, convergence analysis, and parameter
5//! surface visualization.
6
7use crate::error::{ScirsError, ScirsResult};
8use scirs2_core::error_context;
9use scirs2_core::ndarray::{Array1, ArrayView1}; // Unused import: Array2, ArrayView2
10use std::collections::HashMap;
11use std::fs::File;
12use std::io::Write;
13use std::path::Path;
14
15/// Trajectory data collected during optimization
16#[derive(Debug, Clone)]
17pub struct OptimizationTrajectory {
18    /// Parameter values at each iteration
19    pub parameters: Vec<Array1<f64>>,
20    /// Function values at each iteration
21    pub function_values: Vec<f64>,
22    /// Gradient norms at each iteration (if available)
23    pub gradient_norms: Vec<f64>,
24    /// Step sizes at each iteration (if available)
25    pub step_sizes: Vec<f64>,
26    /// Custom metrics at each iteration
27    pub custom_metrics: HashMap<String, Vec<f64>>,
28    /// Iteration numbers
29    pub nit: Vec<usize>,
30    /// Wall clock times (in seconds from start)
31    pub times: Vec<f64>,
32}
33
34impl OptimizationTrajectory {
35    /// Create a new empty trajectory
36    pub fn new() -> Self {
37        Self {
38            parameters: Vec::new(),
39            function_values: Vec::new(),
40            gradient_norms: Vec::new(),
41            step_sizes: Vec::new(),
42            custom_metrics: HashMap::new(),
43            nit: Vec::new(),
44            times: Vec::new(),
45        }
46    }
47
48    /// Add a new point to the trajectory
49    pub fn add_point(
50        &mut self,
51        iteration: usize,
52        params: &ArrayView1<f64>,
53        function_value: f64,
54        time: f64,
55    ) {
56        self.nit.push(iteration);
57        self.parameters.push(params.to_owned());
58        self.function_values.push(function_value);
59        self.times.push(time);
60    }
61
62    /// Add gradient norm information
63    pub fn add_gradient_norm(&mut self, grad_norm: f64) {
64        self.gradient_norms.push(grad_norm);
65    }
66
67    /// Add step size information
68    pub fn add_step_size(&mut self, step_size: f64) {
69        self.step_sizes.push(step_size);
70    }
71
72    /// Add custom metric
73    pub fn add_custom_metric(&mut self, name: &str, value: f64) {
74        self.custom_metrics
75            .entry(name.to_string())
76            .or_default()
77            .push(value);
78    }
79
80    /// Get the number of recorded points
81    pub fn len(&self) -> usize {
82        self.nit.len()
83    }
84
85    /// Check if trajectory is empty
86    pub fn is_empty(&self) -> bool {
87        self.nit.is_empty()
88    }
89
90    /// Get the final parameter values
91    pub fn final_parameters(&self) -> Option<&Array1<f64>> {
92        self.parameters.last()
93    }
94
95    /// Get the final function value
96    pub fn final_function_value(&self) -> Option<f64> {
97        self.function_values.last().copied()
98    }
99
100    /// Calculate convergence rate (linear convergence coefficient)
101    pub fn convergence_rate(&self) -> Option<f64> {
102        if self.function_values.len() < 3 {
103            return None;
104        }
105
106        let n = self.function_values.len();
107        let mut rates = Vec::new();
108
109        for i in 1..(n - 1) {
110            let f_current = self.function_values[i];
111            let f_next = self.function_values[i + 1];
112            let f_prev = self.function_values[i - 1];
113
114            if (f_current - f_next).abs() > 1e-14 && (f_prev - f_current).abs() > 1e-14 {
115                let rate = (f_current - f_next).abs() / (f_prev - f_current).abs();
116                if rate.is_finite() && rate > 0.0 {
117                    rates.push(rate);
118                }
119            }
120        }
121
122        if rates.is_empty() {
123            None
124        } else {
125            Some(rates.iter().sum::<f64>() / rates.len() as f64)
126        }
127    }
128}
129
130impl Default for OptimizationTrajectory {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136/// Configuration for trajectory visualization
137#[derive(Debug, Clone)]
138pub struct VisualizationConfig {
139    /// Output format (svg, png, html)
140    pub format: OutputFormat,
141    /// Width of the plot in pixels
142    pub width: u32,
143    /// Height of the plot in pixels
144    pub height: u32,
145    /// Title for the plot
146    pub title: Option<String>,
147    /// Whether to show grid
148    pub show_grid: bool,
149    /// Whether to use logarithmic scale for y-axis
150    pub log_scale_y: bool,
151    /// Color scheme
152    pub color_scheme: ColorScheme,
153    /// Whether to show legend
154    pub show_legend: bool,
155    /// Custom styling
156    pub custom_style: Option<String>,
157}
158
159impl Default for VisualizationConfig {
160    fn default() -> Self {
161        Self {
162            format: OutputFormat::Svg,
163            width: 800,
164            height: 600,
165            title: None,
166            show_grid: true,
167            log_scale_y: false,
168            color_scheme: ColorScheme::Default,
169            show_legend: true,
170            custom_style: None,
171        }
172    }
173}
174
/// Output formats understood by [`OptimizationVisualizer`].
///
/// `Eq` and `Hash` are derived alongside `PartialEq`, so formats can be used as
/// `HashMap`/`HashSet` keys; all variants are field-less, so equality is total.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputFormat {
    /// Scalable Vector Graphics document.
    Svg,
    /// PNG raster image.
    Png,
    /// HTML page rendering an interactive Plotly chart.
    Html,
    /// Raw data output.
    Data,
}
183
/// Color schemes selectable through the visualization configuration.
///
/// `Eq` and `Hash` are derived alongside `PartialEq`, so schemes can be used as
/// `HashMap`/`HashSet` keys; all variants are field-less, so equality is total.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ColorScheme {
    /// The default palette.
    Default,
    /// `Viridis` palette.
    Viridis,
    /// `Plasma` palette.
    Plasma,
    /// `Scientific` palette.
    Scientific,
    /// Single-color palette.
    Monochrome,
}
193
194// ─────────────────────────────────────────────────────────────────────────────
195// Pure-Rust minimal PNG encoder (stored/uncompressed DEFLATE blocks, 24-bit RGB)
196// ─────────────────────────────────────────────────────────────────────────────
197
/// CRC-32 checksum (ISO 3309 / PNG, reflected polynomial `0xEDB88320`).
///
/// Each byte's table entry is derived on the fly with eight shift/xor rounds instead of
/// being looked up in a precomputed 256-entry table.
fn crc32(data: &[u8]) -> u32 {
    const POLY: u32 = 0xEDB8_8320;

    // Entry `index` of the classic byte-wise CRC table.
    let table_entry = |index: u32| -> u32 {
        (0..8).fold(index, |acc, _| {
            if acc & 1 == 1 {
                (acc >> 1) ^ POLY
            } else {
                acc >> 1
            }
        })
    };

    let register = data.iter().fold(u32::MAX, |crc, &byte| {
        (crc >> 8) ^ table_entry((crc ^ u32::from(byte)) & 0xFF)
    });
    !register
}
216
/// Adler-32 checksum (RFC 1950), used as the zlib stream trailer.
///
/// Both running sums are reduced modulo 65521 after every byte, so they never overflow.
fn adler32(data: &[u8]) -> u32 {
    const MOD_ADLER: u32 = 65521;

    let (a, b) = data.iter().fold((1u32, 0u32), |(a, b), &byte| {
        let a = (a + u32::from(byte)) % MOD_ADLER;
        (a, (b + a) % MOD_ADLER)
    });
    b << 16 | a
}
226
/// Encodes `v` as 4 big-endian (network order) bytes, as PNG requires.
#[inline]
fn be32(v: u32) -> [u8; 4] {
    u32::to_be_bytes(v)
}
232
233/// Build a PNG chunk: length(4) + type(4) + data + crc(4).
234fn png_chunk(chunk_type: &[u8; 4], data: &[u8]) -> Vec<u8> {
235    let mut chunk = Vec::with_capacity(12 + data.len());
236    chunk.extend_from_slice(&be32(data.len() as u32));
237    chunk.extend_from_slice(chunk_type);
238    chunk.extend_from_slice(data);
239    let crc = crc32(&chunk[4..]);
240    chunk.extend_from_slice(&be32(crc));
241    chunk
242}
243
244/// Encode raw scanlines (with filter byte 0x00 per row) using DEFLATE stored blocks.
245fn deflate_stored(scanlines: &[u8]) -> Vec<u8> {
246    const MAX_BLOCK: usize = 65535;
247    let mut out = Vec::new();
248    // zlib header: CMF=0x78, FLG=0x01 (no dict, check bits make it divisible by 31)
249    out.extend_from_slice(&[0x78, 0x01]);
250
251    let adler = adler32(scanlines);
252    let total = scanlines.len();
253    let mut offset = 0;
254    while offset < total {
255        let end = (offset + MAX_BLOCK).min(total);
256        let block = &scanlines[offset..end];
257        let last = if end == total { 1u8 } else { 0u8 };
258        let len = block.len() as u16;
259        let nlen = !len;
260        out.push(last);
261        out.extend_from_slice(&len.to_le_bytes());
262        out.extend_from_slice(&nlen.to_le_bytes());
263        out.extend_from_slice(block);
264        offset = end;
265    }
266    out.extend_from_slice(&adler.to_be_bytes());
267    out
268}
269
270/// Write a minimal 24-bit RGB PNG to a file path.
271fn write_png(path: &Path, pixels: &[u8], width: usize, height: usize) -> ScirsResult<()> {
272    use crate::error::ScirsError;
273
274    // Build scanlines: each row is 0x00 (filter=None) + RGB bytes
275    let mut scanlines = Vec::with_capacity(height * (1 + width * 3));
276    for row in 0..height {
277        scanlines.push(0x00); // filter byte
278        let start = row * width * 3;
279        scanlines.extend_from_slice(&pixels[start..start + width * 3]);
280    }
281
282    let idat_data = deflate_stored(&scanlines);
283
284    // IHDR: width(4), height(4), bit_depth(1), color_type(2=RGB), compression(0), filter(0), interlace(0)
285    let mut ihdr_data = Vec::with_capacity(13);
286    ihdr_data.extend_from_slice(&be32(width as u32));
287    ihdr_data.extend_from_slice(&be32(height as u32));
288    ihdr_data.extend_from_slice(&[8, 2, 0, 0, 0]);
289
290    let mut png_bytes = Vec::new();
291    // PNG signature
292    png_bytes.extend_from_slice(&[137, 80, 78, 71, 13, 10, 26, 10]);
293    png_bytes.extend(png_chunk(b"IHDR", &ihdr_data));
294    png_bytes.extend(png_chunk(b"IDAT", &idat_data));
295    png_bytes.extend(png_chunk(b"IEND", &[]));
296
297    let mut file = File::create(path)
298        .map_err(|e| ScirsError::IoError(error_context!(format!("PNG create: {e}"))))?;
299    file.write_all(&png_bytes)
300        .map_err(|e| ScirsError::IoError(error_context!(format!("PNG write: {e}"))))?;
301    Ok(())
302}
303
/// Draws a straight segment with Bresenham's integer line algorithm.
///
/// `pixels` is a row-major RGB buffer (3 bytes per pixel) for a `width x height`
/// canvas. Every pixel on the segment from (`x0`, `y0`) to (`x1`, `y1`), endpoints
/// included, is set to (`r`, `g`, `b`). Positions outside the canvas are skipped rather
/// than written.
fn png_draw_line(
    pixels: &mut [u8],
    width: usize,
    height: usize,
    x0: usize,
    y0: usize,
    x1: usize,
    y1: usize,
    r: u8,
    g: u8,
    b: u8,
) {
    let (target_x, target_y) = (x1 as isize, y1 as isize);
    let (mut cursor_x, mut cursor_y) = (x0 as isize, y0 as isize);
    let (canvas_w, canvas_h) = (width as isize, height as isize);

    let step_x: isize = if cursor_x < target_x { 1 } else { -1 };
    let step_y: isize = if cursor_y < target_y { 1 } else { -1 };
    let run = (target_x - cursor_x).abs();
    // Kept negative, as in the all-octant formulation of Bresenham.
    let neg_rise = -(target_y - cursor_y).abs();
    let mut error = run + neg_rise;
    let color = [r, g, b];

    loop {
        let on_canvas = (0..canvas_w).contains(&cursor_x) && (0..canvas_h).contains(&cursor_y);
        if on_canvas {
            let offset = (cursor_y as usize * width + cursor_x as usize) * 3;
            for (channel, &value) in color.iter().enumerate() {
                pixels[offset + channel] = value;
            }
        }
        if (cursor_x, cursor_y) == (target_x, target_y) {
            break;
        }
        let doubled = error * 2;
        if doubled >= neg_rise {
            error += neg_rise;
            cursor_x += step_x;
        }
        if doubled <= run {
            error += run;
            cursor_y += step_y;
        }
    }
}
345
346/// Main visualization interface
347pub struct OptimizationVisualizer {
348    config: VisualizationConfig,
349}
350
351impl OptimizationVisualizer {
352    /// Create a new visualizer with default configuration
353    pub fn new() -> Self {
354        Self {
355            config: VisualizationConfig::default(),
356        }
357    }
358
359    /// Create a new visualizer with custom configuration
360    pub fn with_config(config: VisualizationConfig) -> Self {
361        Self { config }
362    }
363
364    /// Plot convergence curve (function value vs iteration)
365    pub fn plot_convergence(
366        &self,
367        trajectory: &OptimizationTrajectory,
368        output_path: &Path,
369    ) -> ScirsResult<()> {
370        if trajectory.is_empty() {
371            return Err(ScirsError::InvalidInput(error_context!("Empty trajectory")));
372        }
373
374        match self.config.format {
375            OutputFormat::Svg => self.plot_convergence_svg(trajectory, output_path),
376            OutputFormat::Html => self.plot_convergence_html(trajectory, output_path),
377            OutputFormat::Data => self.export_convergence_data(trajectory, output_path),
378            OutputFormat::Png => self.plot_convergence_png(trajectory, output_path),
379        }
380    }
381
382    /// Plot parameter trajectory (for 2D problems)
383    pub fn plot_parameter_trajectory(
384        &self,
385        trajectory: &OptimizationTrajectory,
386        output_path: &Path,
387    ) -> ScirsResult<()> {
388        if trajectory.is_empty() {
389            return Err(ScirsError::InvalidInput(error_context!("Empty trajectory")));
390        }
391
392        if trajectory.parameters[0].len() != 2 {
393            return Err(ScirsError::InvalidInput(error_context!(
394                "Parameter trajectory visualization only supports 2D problems"
395            )));
396        }
397
398        match self.config.format {
399            OutputFormat::Svg => self.plot_trajectory_svg(trajectory, output_path),
400            OutputFormat::Html => self.plot_trajectory_html(trajectory, output_path),
401            OutputFormat::Data => self.export_trajectory_data(trajectory, output_path),
402            OutputFormat::Png => self.plot_trajectory_png(trajectory, output_path),
403        }
404    }
405
406    /// Create a comprehensive optimization report
407    pub fn create_optimization_report(
408        &self,
409        trajectory: &OptimizationTrajectory,
410        output_dir: &Path,
411    ) -> ScirsResult<()> {
412        std::fs::create_dir_all(output_dir)?;
413
414        // Generate convergence plot
415        let convergence_path = output_dir.join("convergence.svg");
416        self.plot_convergence(trajectory, &convergence_path)?;
417
418        // Generate parameter trajectory if 2D
419        if !trajectory.parameters.is_empty() && trajectory.parameters[0].len() == 2 {
420            let trajectory_path = output_dir.join("trajectory.svg");
421            self.plot_parameter_trajectory(trajectory, &trajectory_path)?;
422        }
423
424        // Generate summary statistics
425        let summary_path = output_dir.join("summary.html");
426        self.generate_summary_report(trajectory, &summary_path)?;
427
428        // Export raw data
429        let data_path = output_dir.join("data.csv");
430        self.export_convergence_data(trajectory, &data_path)?;
431
432        Ok(())
433    }
434
435    /// Generate summary statistics report
436    fn generate_summary_report(
437        &self,
438        trajectory: &OptimizationTrajectory,
439        output_path: &Path,
440    ) -> ScirsResult<()> {
441        let mut file = File::create(output_path)?;
442
443        let html_content = format!(
444            r#"<!DOCTYPE html>
445<html>
446<head>
447    <title>Optimization Summary</title>
448    <style>
449        body {{ font-family: Arial, sans-serif; margin: 20px; }}
450        .metric {{ margin: 10px 0; }}
451        .value {{ font-weight: bold; color: #2E86AB; }}
452        table {{ border-collapse: collapse; width: 100%; }}
453        th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
454        th {{ background-color: #f2f2f2; }}
455    </style>
456</head>
457<body>
458    <h1>Optimization Summary Report</h1>
459    
460    <h2>Basic Statistics</h2>
461    <div class="metric">Total Iterations: <span class="value">{}</span></div>
462    <div class="metric">Final Function Value: <span class="value">{:.6e}</span></div>
463    <div class="metric">Initial Function Value: <span class="value">{:.6e}</span></div>
464    <div class="metric">Function Improvement: <span class="value">{:.6e}</span></div>
465    <div class="metric">Total Runtime: <span class="value">{:.3}s</span></div>
466    {}
467    
468    <h2>Convergence Analysis</h2>
469    <table>
470        <tr><th>Metric</th><th>Value</th></tr>
471        <tr><td>Convergence Rate</td><td>{}</td></tr>
472        <tr><td>Average Iteration Time</td><td>{:.6}s</td></tr>
473        <tr><td>Function Evaluations per Second</td><td>{:.2}</td></tr>
474    </table>
475    
476    {}
477</body>
478</html>"#,
479            trajectory.len(),
480            trajectory.final_function_value().unwrap_or(0.0),
481            trajectory.function_values.first().cloned().unwrap_or(0.0),
482            trajectory.function_values.first().cloned().unwrap_or(0.0)
483                - trajectory.final_function_value().unwrap_or(0.0),
484            trajectory.times.last().cloned().unwrap_or(0.0),
485            if !trajectory.gradient_norms.is_empty() {
486                format!("<div class=\"metric\">Final Gradient Norm: <span class=\"value\">{:.6e}</span></div>",
487                       trajectory.gradient_norms.last().cloned().unwrap_or(0.0))
488            } else {
489                String::new()
490            },
491            trajectory
492                .convergence_rate()
493                .map(|r| format!("{:.6}", r))
494                .unwrap_or_else(|| "N/A".to_string()),
495            if trajectory.len() > 1 && !trajectory.times.is_empty() {
496                trajectory.times.last().cloned().unwrap_or(0.0) / trajectory.len() as f64
497            } else {
498                0.0
499            },
500            if !trajectory.times.is_empty() && trajectory.times.last().cloned().unwrap_or(0.0) > 0.0
501            {
502                trajectory.len() as f64 / trajectory.times.last().cloned().unwrap_or(1.0)
503            } else {
504                0.0
505            },
506            self.generate_custom_metrics_table(trajectory)
507        );
508
509        file.write_all(html_content.as_bytes())?;
510        Ok(())
511    }
512
513    fn generate_custom_metrics_table(&self, trajectory: &OptimizationTrajectory) -> String {
514        if trajectory.custom_metrics.is_empty() {
515            return String::new();
516        }
517
518        let mut table = String::from("<h2>Custom Metrics</h2>\n<table>\n<tr><th>Metric</th><th>Final Value</th><th>Min</th><th>Max</th><th>Mean</th></tr>\n");
519
520        for (name, values) in &trajectory.custom_metrics {
521            if let Some(final_val) = values.last() {
522                let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
523                let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
524                let mean_val = values.iter().sum::<f64>() / values.len() as f64;
525
526                table.push_str(&format!(
527                    "<tr><td>{}</td><td>{:.6e}</td><td>{:.6e}</td><td>{:.6e}</td><td>{:.6e}</td></tr>\n",
528                    name, final_val, min_val, max_val, mean_val
529                ));
530            }
531        }
532        table.push_str("</table>\n");
533        table
534    }
535
536    fn plot_convergence_svg(
537        &self,
538        trajectory: &OptimizationTrajectory,
539        output_path: &Path,
540    ) -> ScirsResult<()> {
541        let mut file = File::create(output_path)?;
542
543        let width = self.config.width;
544        let height = self.config.height;
545        let margin = 60;
546        let plot_width = width - 2 * margin;
547        let plot_height = height - 2 * margin;
548
549        let min_y = if self.config.log_scale_y {
550            trajectory
551                .function_values
552                .iter()
553                .filter(|&&v| v > 0.0)
554                .cloned()
555                .fold(f64::INFINITY, f64::min)
556                .ln()
557        } else {
558            trajectory
559                .function_values
560                .iter()
561                .cloned()
562                .fold(f64::INFINITY, f64::min)
563        };
564
565        let max_y = if self.config.log_scale_y {
566            trajectory
567                .function_values
568                .iter()
569                .filter(|&&v| v > 0.0)
570                .cloned()
571                .fold(f64::NEG_INFINITY, f64::max)
572                .ln()
573        } else {
574            trajectory
575                .function_values
576                .iter()
577                .cloned()
578                .fold(f64::NEG_INFINITY, f64::max)
579        };
580
581        let max_x = trajectory.nit.len() as f64;
582
583        let mut svg_content = format!(
584            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
585    <defs>
586        <style>
587            .axis {{ stroke: #333; stroke-width: 1; }}
588            .grid {{ stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }}
589            .line {{ fill: none; stroke: #2E86AB; stroke-width: 2; }}
590            .text {{ font-family: Arial, sans-serif; font-size: 12px; fill: #333; }}
591            .title {{ font-family: Arial, sans-serif; font-size: 16px; fill: #333; font-weight: bold; }}
592        </style>
593    </defs>
594"#,
595            width, height
596        );
597
598        // Grid lines
599        if self.config.show_grid {
600            for i in 0..=10 {
601                let x = margin as f64 + (i as f64 / 10.0) * plot_width as f64;
602                svg_content.push_str(&format!(
603                    r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
604"#,
605                    x,
606                    margin,
607                    x,
608                    height - margin
609                ));
610            }
611
612            for i in 0..=10 {
613                let y = margin as f64 + (i as f64 / 10.0) * plot_height as f64;
614                svg_content.push_str(&format!(
615                    r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
616"#,
617                    margin,
618                    y,
619                    width - margin,
620                    y
621                ));
622            }
623        }
624
625        // Axes
626        svg_content.push_str(&format!(
627            r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
628    <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
629"#,
630            margin,
631            height - margin,
632            width - margin,
633            height - margin, // x-axis
634            margin,
635            margin,
636            margin,
637            height - margin // y-axis
638        ));
639
640        // Plot line
641        svg_content.push_str("    <polyline points=\"");
642        for (i, &f_val) in trajectory.function_values.iter().enumerate() {
643            let x = margin as f64 + (i as f64 / max_x) * plot_width as f64;
644            let y_val = if self.config.log_scale_y && f_val > 0.0 {
645                f_val.ln()
646            } else {
647                f_val
648            };
649            let y = height as f64
650                - margin as f64
651                - ((y_val - min_y) / (max_y - min_y)) * plot_height as f64;
652            svg_content.push_str(&format!("{},{} ", x, y));
653        }
654        svg_content.push_str("\" class=\"line\" />\n");
655
656        // Title
657        if let Some(ref title) = self.config.title {
658            svg_content.push_str(&format!(
659                r#"    <text x="{}" y="30" text-anchor="middle" class="title">{}</text>
660"#,
661                width / 2,
662                title
663            ));
664        }
665
666        // Labels
667        svg_content.push_str(&format!(
668            r#"    <text x="{}" y="{}" text-anchor="middle" class="text">Iteration</text>
669    <text x="20" y="{}" text-anchor="middle" class="text" transform="rotate(-90 20 {})">Function Value{}</text>
670"#,
671            width / 2, height - 10,
672            height / 2, height / 2,
673            if self.config.log_scale_y { " (log)" } else { "" }
674        ));
675
676        svg_content.push_str("</svg>");
677
678        file.write_all(svg_content.as_bytes())?;
679        Ok(())
680    }
681
682    fn plot_convergence_html(
683        &self,
684        trajectory: &OptimizationTrajectory,
685        output_path: &Path,
686    ) -> ScirsResult<()> {
687        let mut file = File::create(output_path)?;
688
689        let html_content = format!(
690            r#"<!DOCTYPE html>
691<html>
692<head>
693    <title>Optimization Convergence</title>
694    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
695</head>
696<body>
697    <div id="convergence-plot" style="width:{}px;height:{}px;"></div>
698    <script>
699        var trace = {{
700            x: [{}],
701            y: [{}],
702            type: 'scatter',
703            mode: 'lines',
704            name: 'Function Value',
705            line: {{ color: '#2E86AB', width: 2 }}
706        }};
707        
708        var layout = {{
709            title: '{}',
710            xaxis: {{ title: 'Iteration' }},
711            yaxis: {{ 
712                title: 'Function Value',
713                type: '{}'
714            }},
715            showlegend: {}
716        }};
717        
718        Plotly.newPlot('convergence-plot', [trace], layout);
719    </script>
720</body>
721</html>"#,
722            self.config.width,
723            self.config.height,
724            trajectory
725                .nit
726                .iter()
727                .map(|i| i.to_string())
728                .collect::<Vec<_>>()
729                .join(","),
730            trajectory
731                .function_values
732                .iter()
733                .map(|f| f.to_string())
734                .collect::<Vec<_>>()
735                .join(","),
736            self.config
737                .title
738                .as_deref()
739                .unwrap_or("Optimization Convergence"),
740            if self.config.log_scale_y {
741                "log"
742            } else {
743                "linear"
744            },
745            self.config.show_legend
746        );
747
748        file.write_all(html_content.as_bytes())?;
749        Ok(())
750    }
751
752    fn plot_trajectory_svg(
753        &self,
754        trajectory: &OptimizationTrajectory,
755        output_path: &Path,
756    ) -> ScirsResult<()> {
757        let mut file = File::create(output_path)?;
758
759        let width = self.config.width;
760        let height = self.config.height;
761        let margin = 60;
762        let plot_width = width - 2 * margin;
763        let plot_height = height - 2 * margin;
764
765        let x_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
766        let y_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
767
768        let min_x = x_coords.iter().cloned().fold(f64::INFINITY, f64::min);
769        let max_x = x_coords.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
770        let min_y = y_coords.iter().cloned().fold(f64::INFINITY, f64::min);
771        let max_y = y_coords.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
772
773        let mut svg_content = format!(
774            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
775    <defs>
776        <style>
777            .axis {{ stroke: #333; stroke-width: 1; }}
778            .grid {{ stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }}
779            .trajectory {{ fill: none; stroke: #2E86AB; stroke-width: 2; }}
780            .start {{ fill: #4CAF50; stroke: #333; stroke-width: 1; }}
781            .end {{ fill: #F44336; stroke: #333; stroke-width: 1; }}
782            .text {{ font-family: Arial, sans-serif; font-size: 12px; fill: #333; }}
783            .title {{ font-family: Arial, sans-serif; font-size: 16px; fill: #333; font-weight: bold; }}
784        </style>
785    </defs>
786"#,
787            width, height
788        );
789
790        // Grid
791        if self.config.show_grid {
792            for i in 0..=10 {
793                let x = margin as f64 + (i as f64 / 10.0) * plot_width as f64;
794                svg_content.push_str(&format!(
795                    r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
796"#,
797                    x,
798                    margin,
799                    x,
800                    height - margin
801                ));
802            }
803
804            for i in 0..=10 {
805                let y = margin as f64 + (i as f64 / 10.0) * plot_height as f64;
806                svg_content.push_str(&format!(
807                    r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="grid" />
808"#,
809                    margin,
810                    y,
811                    width - margin,
812                    y
813                ));
814            }
815        }
816
817        // Axes
818        svg_content.push_str(&format!(
819            r#"    <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
820    <line x1="{}" y1="{}" x2="{}" y2="{}" class="axis" />
821"#,
822            margin,
823            height - margin,
824            width - margin,
825            height - margin,
826            margin,
827            margin,
828            margin,
829            height - margin
830        ));
831
832        // Trajectory
833        svg_content.push_str("    <polyline points=\"");
834        for (x_val, y_val) in x_coords.iter().zip(y_coords.iter()) {
835            let x = margin as f64 + ((x_val - min_x) / (max_x - min_x)) * plot_width as f64;
836            let y = height as f64
837                - margin as f64
838                - ((y_val - min_y) / (max_y - min_y)) * plot_height as f64;
839            svg_content.push_str(&format!("{},{} ", x, y));
840        }
841        svg_content.push_str("\" class=\"trajectory\" />\n");
842
843        // Start and end points
844        if !x_coords.is_empty() {
845            let start_x =
846                margin as f64 + ((x_coords[0] - min_x) / (max_x - min_x)) * plot_width as f64;
847            let start_y = height as f64
848                - margin as f64
849                - ((y_coords[0] - min_y) / (max_y - min_y)) * plot_height as f64;
850
851            let end_x = margin as f64
852                + ((x_coords.last().expect("Operation failed") - min_x) / (max_x - min_x))
853                    * plot_width as f64;
854            let end_y = height as f64
855                - margin as f64
856                - ((y_coords.last().expect("Operation failed") - min_y) / (max_y - min_y))
857                    * plot_height as f64;
858
859            svg_content.push_str(&format!(
860                r#"    <circle cx="{}" cy="{}" r="5" class="start" />
861    <circle cx="{}" cy="{}" r="5" class="end" />
862"#,
863                start_x, start_y, end_x, end_y
864            ));
865        }
866
867        // Title
868        if let Some(ref title) = self.config.title {
869            svg_content.push_str(&format!(
870                r#"    <text x="{}" y="30" text-anchor="middle" class="title">{}</text>
871"#,
872                width / 2,
873                title
874            ));
875        }
876
877        svg_content.push_str("</svg>");
878
879        file.write_all(svg_content.as_bytes())?;
880        Ok(())
881    }
882
883    fn plot_trajectory_html(
884        &self,
885        trajectory: &OptimizationTrajectory,
886        output_path: &Path,
887    ) -> ScirsResult<()> {
888        let mut file = File::create(output_path)?;
889
890        let x_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
891        let y_coords: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
892
893        let html_content = format!(
894            r#"<!DOCTYPE html>
895<html>
896<head>
897    <title>Parameter Trajectory</title>
898    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
899</head>
900<body>
901    <div id="trajectory-plot" style="width:{}px;height:{}px;"></div>
902    <script>
903        var trace = {{
904            x: [{}],
905            y: [{}],
906            type: 'scatter',
907            mode: 'lines+markers',
908            name: 'Trajectory',
909            line: {{ color: '#2E86AB', width: 2 }},
910            marker: {{ 
911                size: [{}],
912                color: [{}],
913                colorscale: 'Viridis',
914                showscale: true
915            }}
916        }};
917        
918        var layout = {{
919            title: '{}',
920            xaxis: {{ title: 'Parameter 1' }},
921            yaxis: {{ title: 'Parameter 2' }},
922            showlegend: {}
923        }};
924        
925        Plotly.newPlot('trajectory-plot', [trace], layout);
926    </script>
927</body>
928</html>"#,
929            self.config.width,
930            self.config.height,
931            x_coords
932                .iter()
933                .map(|x| x.to_string())
934                .collect::<Vec<_>>()
935                .join(","),
936            y_coords
937                .iter()
938                .map(|y| y.to_string())
939                .collect::<Vec<_>>()
940                .join(","),
941            (0..x_coords.len())
942                .map(|i| if i == 0 {
943                    "10"
944                } else if i == x_coords.len() - 1 {
945                    "10"
946                } else {
947                    "6"
948                })
949                .collect::<Vec<_>>()
950                .join(","),
951            (0..x_coords.len())
952                .map(|i| i.to_string())
953                .collect::<Vec<_>>()
954                .join(","),
955            self.config
956                .title
957                .as_deref()
958                .unwrap_or("Parameter Trajectory"),
959            self.config.show_legend
960        );
961
962        file.write_all(html_content.as_bytes())?;
963        Ok(())
964    }
965
966    fn export_convergence_data(
967        &self,
968        trajectory: &OptimizationTrajectory,
969        output_path: &Path,
970    ) -> ScirsResult<()> {
971        let mut file = File::create(output_path)?;
972
973        // CSV header
974        let mut header = "iteration,function_value,time".to_string();
975        if !trajectory.gradient_norms.is_empty() {
976            header.push_str(",gradient_norm");
977        }
978        if !trajectory.step_sizes.is_empty() {
979            header.push_str(",step_size");
980        }
981
982        // Add parameter columns
983        if !trajectory.parameters.is_empty() {
984            for i in 0..trajectory.parameters[0].len() {
985                header.push_str(&format!(",param_{}", i));
986            }
987        }
988
989        // Add custom metrics
990        for name in trajectory.custom_metrics.keys() {
991            header.push_str(&format!(",{}", name));
992        }
993        header.push('\n');
994
995        file.write_all(header.as_bytes())?;
996
997        // Data rows
998        for i in 0..trajectory.len() {
999            let mut row = format!(
1000                "{},{},{}",
1001                trajectory.nit[i], trajectory.function_values[i], trajectory.times[i]
1002            );
1003
1004            if i < trajectory.gradient_norms.len() {
1005                row.push_str(&format!(",{}", trajectory.gradient_norms[i]));
1006            } else if !trajectory.gradient_norms.is_empty() {
1007                row.push(',');
1008            }
1009
1010            if i < trajectory.step_sizes.len() {
1011                row.push_str(&format!(",{}", trajectory.step_sizes[i]));
1012            } else if !trajectory.step_sizes.is_empty() {
1013                row.push(',');
1014            }
1015
1016            // Parameters
1017            if i < trajectory.parameters.len() {
1018                for param in trajectory.parameters[i].iter() {
1019                    row.push_str(&format!(",{}", param));
1020                }
1021            }
1022
1023            // Custom metrics
1024            for name in trajectory.custom_metrics.keys() {
1025                if let Some(values) = trajectory.custom_metrics.get(name) {
1026                    if i < values.len() {
1027                        row.push_str(&format!(",{}", values[i]));
1028                    } else {
1029                        row.push(',');
1030                    }
1031                }
1032            }
1033
1034            row.push('\n');
1035            file.write_all(row.as_bytes())?;
1036        }
1037
1038        Ok(())
1039    }
1040
1041    fn export_trajectory_data(
1042        &self,
1043        trajectory: &OptimizationTrajectory,
1044        output_path: &Path,
1045    ) -> ScirsResult<()> {
1046        self.export_convergence_data(trajectory, output_path)
1047    }
1048
1049    /// Render the convergence trajectory as a PNG file.
1050    ///
1051    /// Writes a minimal PNG (24-bit RGB, no compression / stored DEFLATE blocks)
1052    /// containing a line plot of function values vs iteration count.
1053    fn plot_convergence_png(
1054        &self,
1055        trajectory: &OptimizationTrajectory,
1056        output_path: &Path,
1057    ) -> ScirsResult<()> {
1058        let width = self.config.width as usize;
1059        let height = self.config.height as usize;
1060        let margin = 60_usize;
1061
1062        // Build RGB pixel buffer (white background)
1063        let mut pixels = vec![255u8; width * height * 3];
1064
1065        let min_y = trajectory
1066            .function_values
1067            .iter()
1068            .cloned()
1069            .fold(f64::INFINITY, f64::min);
1070        let max_y = trajectory
1071            .function_values
1072            .iter()
1073            .cloned()
1074            .fold(f64::NEG_INFINITY, f64::max);
1075        let y_range = (max_y - min_y).max(f64::EPSILON);
1076        let n = trajectory.function_values.len();
1077
1078        // Draw axes (dark gray)
1079        for px in margin..width.saturating_sub(margin) {
1080            let row = height.saturating_sub(margin + 1);
1081            let idx = (row * width + px) * 3;
1082            pixels[idx] = 50;
1083            pixels[idx + 1] = 50;
1084            pixels[idx + 2] = 50;
1085        }
1086        for py in margin..height.saturating_sub(margin) {
1087            let idx = (py * width + margin) * 3;
1088            pixels[idx] = 50;
1089            pixels[idx + 1] = 50;
1090            pixels[idx + 2] = 50;
1091        }
1092
1093        // Draw line (blue: #2E86AB = R46, G134, B171)
1094        let plot_w = width.saturating_sub(2 * margin);
1095        let plot_h = height.saturating_sub(2 * margin);
1096        if n >= 2 {
1097            for i in 0..n.saturating_sub(1) {
1098                let x0 = margin + i * plot_w / (n - 1);
1099                let x1 = margin + (i + 1) * plot_w / (n - 1);
1100                let v0 = trajectory.function_values[i];
1101                let v1 = trajectory.function_values[i + 1];
1102                let y0 = height
1103                    .saturating_sub(margin)
1104                    .saturating_sub(((v0 - min_y) / y_range * plot_h as f64) as usize);
1105                let y1 = height
1106                    .saturating_sub(margin)
1107                    .saturating_sub(((v1 - min_y) / y_range * plot_h as f64) as usize);
1108                // Bresenham line segment
1109                png_draw_line(&mut pixels, width, height, x0, y0, x1, y1, 46, 134, 171);
1110            }
1111        }
1112
1113        write_png(output_path, &pixels, width, height)
1114    }
1115
1116    /// Render the parameter trajectory as a PNG file (2D only).
1117    fn plot_trajectory_png(
1118        &self,
1119        trajectory: &OptimizationTrajectory,
1120        output_path: &Path,
1121    ) -> ScirsResult<()> {
1122        let width = self.config.width as usize;
1123        let height = self.config.height as usize;
1124        let margin = 60_usize;
1125
1126        if trajectory.parameters.is_empty() || trajectory.parameters[0].len() != 2 {
1127            return Err(ScirsError::InvalidInput(error_context!(
1128                "Parameter trajectory PNG visualization only supports 2D problems"
1129            )));
1130        }
1131
1132        let xs: Vec<f64> = trajectory.parameters.iter().map(|p| p[0]).collect();
1133        let ys: Vec<f64> = trajectory.parameters.iter().map(|p| p[1]).collect();
1134
1135        let min_x = xs.iter().cloned().fold(f64::INFINITY, f64::min);
1136        let max_x = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1137        let min_y = ys.iter().cloned().fold(f64::INFINITY, f64::min);
1138        let max_y = ys.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1139        let x_range = (max_x - min_x).max(f64::EPSILON);
1140        let y_range = (max_y - min_y).max(f64::EPSILON);
1141
1142        let mut pixels = vec![255u8; width * height * 3];
1143
1144        let plot_w = width.saturating_sub(2 * margin);
1145        let plot_h = height.saturating_sub(2 * margin);
1146        let n = xs.len();
1147        for i in 0..n.saturating_sub(1) {
1148            let px0 = margin + ((xs[i] - min_x) / x_range * plot_w as f64) as usize;
1149            let py0 = height
1150                .saturating_sub(margin)
1151                .saturating_sub(((ys[i] - min_y) / y_range * plot_h as f64) as usize);
1152            let px1 = margin + ((xs[i + 1] - min_x) / x_range * plot_w as f64) as usize;
1153            let py1 = height
1154                .saturating_sub(margin)
1155                .saturating_sub(((ys[i + 1] - min_y) / y_range * plot_h as f64) as usize);
1156            png_draw_line(&mut pixels, width, height, px0, py0, px1, py1, 46, 134, 171);
1157        }
1158
1159        write_png(output_path, &pixels, width, height)
1160    }
1161}
1162
1163impl Default for OptimizationVisualizer {
1164    fn default() -> Self {
1165        Self::new()
1166    }
1167}
1168
1169/// Utility functions for creating trajectory trackers
1170pub mod tracking {
1171    use super::OptimizationTrajectory;
1172    use scirs2_core::ndarray::ArrayView1;
1173    use std::time::Instant;
1174
1175    /// A callback-based trajectory tracker for use with optimization algorithms
1176    pub struct TrajectoryTracker {
1177        trajectory: OptimizationTrajectory,
1178        start_time: Instant,
1179    }
1180
1181    impl TrajectoryTracker {
1182        /// Create a new trajectory tracker
1183        pub fn new() -> Self {
1184            Self {
1185                trajectory: OptimizationTrajectory::new(),
1186                start_time: Instant::now(),
1187            }
1188        }
1189
1190        /// Record a new point in the optimization trajectory
1191        pub fn record(&mut self, iteration: usize, params: &ArrayView1<f64>, function_value: f64) {
1192            let elapsed = self.start_time.elapsed().as_secs_f64();
1193            self.trajectory
1194                .add_point(iteration, params, function_value, elapsed);
1195        }
1196
1197        /// Record gradient norm
1198        pub fn record_gradient_norm(&mut self, grad_norm: f64) {
1199            self.trajectory.add_gradient_norm(grad_norm);
1200        }
1201
1202        /// Record step size
1203        pub fn record_step_size(&mut self, step_size: f64) {
1204            self.trajectory.add_step_size(step_size);
1205        }
1206
1207        /// Record custom metric
1208        pub fn record_custom_metric(&mut self, name: &str, value: f64) {
1209            self.trajectory.add_custom_metric(name, value);
1210        }
1211
1212        /// Get the recorded trajectory
1213        pub fn trajectory(&self) -> &OptimizationTrajectory {
1214            &self.trajectory
1215        }
1216
1217        /// Consume the tracker and return the trajectory
1218        pub fn into_trajectory(self) -> OptimizationTrajectory {
1219            self.trajectory
1220        }
1221    }
1222
1223    impl Default for TrajectoryTracker {
1224        fn default() -> Self {
1225            Self::new()
1226        }
1227    }
1228}
1229
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// The 8-byte signature every PNG file starts with.
    const PNG_SIGNATURE: [u8; 8] = [137, 80, 78, 71, 13, 10, 26, 10];

    /// A path in the system temp directory that is deleted when dropped.
    ///
    /// The file name includes the process id, so concurrent test processes never
    /// write to the same file. Cleanup also happens when an assertion panics.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        fn new(stem: &str, extension: &str) -> Self {
            let name = format!(
                "scirs2_optimize_{}_{}.{}",
                stem,
                std::process::id(),
                extension
            );
            Self(std::env::temp_dir().join(name))
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort: the file may not exist if the test failed before writing it.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_trajectory_creation() {
        let mut trajectory = OptimizationTrajectory::new();
        assert!(trajectory.is_empty());

        let params = array![1.0, 2.0];
        trajectory.add_point(0, &params.view(), 5.0, 0.1);

        assert_eq!(trajectory.len(), 1);
        assert_eq!(trajectory.final_function_value(), Some(5.0));
    }

    #[test]
    fn test_convergence_rate_calculation() {
        let mut trajectory = OptimizationTrajectory::new();

        // Add points with known convergence pattern
        let function_values = vec![10.0, 5.0, 2.5, 1.25, 0.625];
        for (i, &f_val) in function_values.iter().enumerate() {
            let params = array![i as f64, i as f64];
            trajectory.add_point(i, &params.view(), f_val, i as f64 * 0.1);
        }

        let rate = trajectory.convergence_rate();
        assert!(rate.is_some());
        // Should be approximately 0.5 for this geometric sequence
        assert!((rate.expect("Operation failed") - 0.5).abs() < 0.1);
    }

    #[test]
    fn test_visualization_config() {
        let config = VisualizationConfig {
            format: OutputFormat::Svg,
            width: 1000,
            height: 800,
            title: Some("Test Plot".to_string()),
            show_grid: true,
            log_scale_y: true,
            color_scheme: ColorScheme::Viridis,
            show_legend: false,
            custom_style: None,
        };

        let visualizer = OptimizationVisualizer::with_config(config);
        assert_eq!(visualizer.config.width, 1000);
        assert_eq!(visualizer.config.height, 800);
    }

    #[test]
    fn test_trajectory_tracker() {
        let mut tracker = tracking::TrajectoryTracker::new();

        let params1 = array![0.0, 0.0];
        let params2 = array![1.0, 1.0];

        tracker.record(0, &params1.view(), 10.0);
        tracker.record_gradient_norm(2.5);
        tracker.record_step_size(0.1);

        tracker.record(1, &params2.view(), 5.0);
        tracker.record_gradient_norm(1.5);
        tracker.record_step_size(0.2);

        let trajectory = tracker.trajectory();
        assert_eq!(trajectory.len(), 2);
        assert_eq!(trajectory.gradient_norms.len(), 2);
        assert_eq!(trajectory.step_sizes.len(), 2);
        assert_eq!(trajectory.final_function_value(), Some(5.0));
    }

    #[test]
    fn test_csv_export_header_and_rows() {
        // Only points (no gradients, steps or custom metrics): the CSV must have
        // the base columns plus one column per parameter component.
        let mut trajectory = OptimizationTrajectory::new();
        trajectory.add_point(0, &array![1.0, 2.0].view(), 5.0, 0.1);
        trajectory.add_point(1, &array![0.5, 1.5].view(), 2.5, 0.2);

        let out = TempFile::new("test_convergence_export", "csv");
        let vis = OptimizationVisualizer::new();
        vis.export_convergence_data(&trajectory, out.path())
            .expect("CSV write failed");

        let text = std::fs::read_to_string(out.path()).expect("CSV read failed");
        let lines: Vec<&str> = text.lines().collect();
        assert_eq!(
            lines,
            vec![
                "iteration,function_value,time,param_0,param_1",
                "0,5,0.1,1,2",
                "1,2.5,0.2,0.5,1.5",
            ]
        );
    }

    #[test]
    fn test_png_convergence_output_valid_file() {
        // Build a simple convergence trajectory, render it as PNG and check that
        // the output starts with the PNG signature.
        let mut trajectory = OptimizationTrajectory::new();
        let params = array![0.0f64, 0.0];
        for i in 0..10 {
            let val = 100.0 / (i as f64 + 1.0);
            trajectory.add_point(i, &params.view(), val, i as f64 * 0.01);
        }

        let out = TempFile::new("test_convergence", "png");

        let config = VisualizationConfig {
            format: OutputFormat::Png,
            width: 100,
            height: 80,
            title: None,
            show_grid: false,
            log_scale_y: false,
            color_scheme: ColorScheme::Default,
            show_legend: false,
            custom_style: None,
        };
        let vis = OptimizationVisualizer::with_config(config);
        vis.plot_convergence(&trajectory, out.path())
            .expect("PNG write failed");

        let data = std::fs::read(out.path()).expect("PNG read failed");
        assert!(data.len() > 8, "PNG too small");
        assert_eq!(&data[0..8], &PNG_SIGNATURE, "Invalid PNG signature");
    }

    #[test]
    fn test_png_trajectory_output_valid_file() {
        // Build a 2D trajectory, render it as PNG and check the PNG signature.
        let mut trajectory = OptimizationTrajectory::new();
        for i in 0..5 {
            let params = array![i as f64 * 0.1, i as f64 * 0.2];
            trajectory.add_point(i, &params.view(), 10.0 - i as f64, i as f64 * 0.01);
        }

        let out = TempFile::new("test_trajectory", "png");

        let config = VisualizationConfig {
            format: OutputFormat::Png,
            width: 80,
            height: 60,
            title: None,
            show_grid: false,
            log_scale_y: false,
            color_scheme: ColorScheme::Default,
            show_legend: false,
            custom_style: None,
        };
        let vis = OptimizationVisualizer::with_config(config);
        vis.plot_parameter_trajectory(&trajectory, out.path())
            .expect("PNG traj write failed");

        let data = std::fs::read(out.path()).expect("PNG traj read failed");
        assert!(data.len() > 8, "PNG trajectory too small");
        assert_eq!(&data[0..8], &PNG_SIGNATURE, "Invalid PNG signature");
    }
}