sklears_cross_decomposition/
interactive_visualization.rs

1//! Interactive Visualization for Cross-Decomposition Methods
2//!
3//! This module provides interactive visualization capabilities for cross-decomposition
4//! algorithms including CCA, PLS, and tensor methods. It supports various plot types,
5//! real-time updates, and web-based interactive dashboards.
6//!
7//! ## Supported Visualizations
8//! - Interactive canonical correlation plots
9//! - Real-time component analysis
10//! - 3D multi-view data visualization
11//! - Network visualization for correlation structures
12//! - Temporal dynamics visualization
13//! - Component loading heatmaps with interactivity
14//!
15//! ## Output Formats
16//! - HTML with JavaScript interactivity
17//! - SVG with embedded interactions
18//! - JSON data for custom visualizations
19//! - Real-time streaming updates via WebSocket
20
21use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
22use sklears_core::types::Float;
23use std::collections::HashMap;
24use std::path::Path;
25
26/// Interactive visualization configuration
27#[derive(Debug, Clone)]
28pub struct InteractiveVisualizationConfig {
29    /// Plot width in pixels
30    pub width: usize,
31    /// Plot height in pixels
32    pub height: usize,
33    /// Color scheme for the visualization
34    pub color_scheme: ColorScheme,
35    /// Whether to enable zoom and pan interactions
36    pub enable_zoom_pan: bool,
37    /// Whether to enable point selection
38    pub enable_selection: bool,
39    /// Whether to show tooltips on hover
40    pub show_tooltips: bool,
41    /// Animation duration in milliseconds
42    pub animation_duration: usize,
43    /// Whether to enable real-time updates
44    pub real_time_updates: bool,
45    /// Update interval in milliseconds (for real-time)
46    pub update_interval: usize,
47}
48
49impl Default for InteractiveVisualizationConfig {
50    fn default() -> Self {
51        Self {
52            width: 800,
53            height: 600,
54            color_scheme: ColorScheme::Viridis,
55            enable_zoom_pan: true,
56            enable_selection: true,
57            show_tooltips: true,
58            animation_duration: 750,
59            real_time_updates: false,
60            update_interval: 100,
61        }
62    }
63}
64
/// Color schemes available to visualizations.
///
/// The default is [`ColorScheme::Viridis`], the same scheme that
/// `InteractiveVisualizationConfig::default()` selects. The type is `Hash`
/// so it can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ColorScheme {
    /// Viridis color scheme (perceptually uniform).
    #[default]
    Viridis,
    /// Plasma color scheme (high contrast).
    Plasma,
    /// Turbo color scheme (rainbow alternative).
    Turbo,
    /// Cool-warm color scheme (diverging).
    CoolWarm,
    /// Custom color scheme.
    Custom,
}
79
/// Kinds of plot an interactive visualization can contain.
///
/// The type is `Hash` in addition to `Eq`, so plots can be grouped or counted
/// by type in a `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlotType {
    /// Scatter plot with canonical coordinates.
    CanonicalScatter,
    /// Component loading heatmap.
    LoadingHeatmap,
    /// Correlation network graph.
    CorrelationNetwork,
    /// 3D scatter plot for multi-view data.
    Scatter3D,
    /// Time series plot for temporal dynamics.
    TimeSeries,
    /// Biplot (scores and loadings together).
    Biplot,
    /// Parallel coordinates plot.
    ParallelCoordinates,
}
98
99/// Interactive plot data container
100#[derive(Debug, Clone)]
101pub struct InteractivePlot {
102    /// Plot type
103    pub plot_type: PlotType,
104    /// Data points for plotting
105    pub data: PlotData,
106    /// Plot configuration
107    pub config: InteractiveVisualizationConfig,
108    /// Metadata for tooltips and interactions
109    pub metadata: HashMap<String, String>,
110    /// Custom JavaScript callbacks
111    pub callbacks: HashMap<String, String>,
112}
113
114/// Plot data structure
115#[derive(Debug, Clone)]
116pub struct PlotData {
117    /// X coordinates
118    pub x: Array1<f64>,
119    /// Y coordinates
120    pub y: Array1<f64>,
121    /// Z coordinates (for 3D plots)
122    pub z: Option<Array1<f64>>,
123    /// Point colors (indices into color scheme)
124    pub colors: Option<Array1<f64>>,
125    /// Point sizes
126    pub sizes: Option<Array1<f64>>,
127    /// Point labels for tooltips
128    pub labels: Option<Vec<String>>,
129    /// Additional data dimensions for parallel coordinates
130    pub additional_dims: Option<Array2<f64>>,
131}
132
133impl PlotData {
134    /// Create new plot data with x and y coordinates
135    pub fn new(x: Array1<f64>, y: Array1<f64>) -> Self {
136        Self {
137            x,
138            y,
139            z: None,
140            colors: None,
141            sizes: None,
142            labels: None,
143            additional_dims: None,
144        }
145    }
146
147    /// Add Z coordinates for 3D plotting
148    pub fn with_z(mut self, z: Array1<f64>) -> Self {
149        self.z = Some(z);
150        self
151    }
152
153    /// Add colors for points
154    pub fn with_colors(mut self, colors: Array1<f64>) -> Self {
155        self.colors = Some(colors);
156        self
157    }
158
159    /// Add sizes for points
160    pub fn with_sizes(mut self, sizes: Array1<f64>) -> Self {
161        self.sizes = Some(sizes);
162        self
163    }
164
165    /// Add labels for tooltips
166    pub fn with_labels(mut self, labels: Vec<String>) -> Self {
167        self.labels = Some(labels);
168        self
169    }
170
171    /// Add additional dimensions for parallel coordinates
172    pub fn with_additional_dims(mut self, dims: Array2<f64>) -> Self {
173        self.additional_dims = Some(dims);
174        self
175    }
176}
177
178/// Interactive visualization engine
179#[derive(Debug)]
180pub struct InteractiveVisualizer {
181    /// Configuration
182    config: InteractiveVisualizationConfig,
183    /// Current plots
184    plots: Vec<InteractivePlot>,
185    /// Output directory for generated files
186    output_dir: String,
187}
188
189impl InteractiveVisualizer {
190    /// Create a new interactive visualizer
191    pub fn new() -> Self {
192        Self {
193            config: InteractiveVisualizationConfig::default(),
194            plots: Vec::new(),
195            output_dir: "visualizations".to_string(),
196        }
197    }
198
199    /// Create visualizer with custom configuration
200    pub fn with_config(config: InteractiveVisualizationConfig) -> Self {
201        Self {
202            config,
203            plots: Vec::new(),
204            output_dir: "visualizations".to_string(),
205        }
206    }
207
208    /// Set output directory
209    pub fn with_output_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
210        self.output_dir = path.as_ref().to_string_lossy().to_string();
211        self
212    }
213
214    /// Add a new interactive plot
215    pub fn add_plot(&mut self, plot: InteractivePlot) {
216        self.plots.push(plot);
217    }
218
219    /// Create canonical correlation scatter plot
220    pub fn canonical_scatter(
221        &mut self,
222        x_canonical: ArrayView1<f64>,
223        y_canonical: ArrayView1<f64>,
224        labels: Option<Vec<String>>,
225    ) -> Result<(), VisualizationError> {
226        let data = PlotData::new(x_canonical.to_owned(), y_canonical.to_owned()).with_labels(
227            labels.unwrap_or_else(|| {
228                (0..x_canonical.len())
229                    .map(|i| format!("Point {}", i))
230                    .collect()
231            }),
232        );
233
234        let plot = InteractivePlot {
235            plot_type: PlotType::CanonicalScatter,
236            data,
237            config: self.config.clone(),
238            metadata: HashMap::new(),
239            callbacks: HashMap::new(),
240        };
241
242        self.add_plot(plot);
243        Ok(())
244    }
245
246    /// Create component loading heatmap
247    pub fn loading_heatmap(
248        &mut self,
249        loadings: ArrayView2<f64>,
250        feature_names: Option<Vec<String>>,
251        component_names: Option<Vec<String>>,
252    ) -> Result<(), VisualizationError> {
253        // Convert 2D loadings to plot coordinates for heatmap
254        let (n_features, n_components) = loadings.dim();
255        let mut x_coords = Vec::new();
256        let mut y_coords = Vec::new();
257        let mut colors = Vec::new();
258        let mut labels = Vec::new();
259
260        for i in 0..n_features {
261            for j in 0..n_components {
262                x_coords.push(j as f64);
263                y_coords.push(i as f64);
264                colors.push(loadings[[i, j]]);
265
266                let feature_name = feature_names
267                    .as_ref()
268                    .map(|names| names[i].clone())
269                    .unwrap_or_else(|| format!("Feature {}", i));
270                let component_name = component_names
271                    .as_ref()
272                    .map(|names| names[j].clone())
273                    .unwrap_or_else(|| format!("Component {}", j));
274
275                labels.push(format!(
276                    "{} -> {}: {:.4}",
277                    feature_name,
278                    component_name,
279                    loadings[[i, j]]
280                ));
281            }
282        }
283
284        let data = PlotData::new(Array1::from_vec(x_coords), Array1::from_vec(y_coords))
285            .with_colors(Array1::from_vec(colors))
286            .with_labels(labels);
287
288        let plot = InteractivePlot {
289            plot_type: PlotType::LoadingHeatmap,
290            data,
291            config: self.config.clone(),
292            metadata: HashMap::new(),
293            callbacks: HashMap::new(),
294        };
295
296        self.add_plot(plot);
297        Ok(())
298    }
299
300    /// Create correlation network visualization
301    pub fn correlation_network(
302        &mut self,
303        correlation_matrix: ArrayView2<f64>,
304        variable_names: Option<Vec<String>>,
305        threshold: f64,
306    ) -> Result<(), VisualizationError> {
307        let n_vars = correlation_matrix.nrows();
308
309        // Create network layout (simple circular for now)
310        let mut x_coords = Vec::new();
311        let mut y_coords = Vec::new();
312        let mut labels = Vec::new();
313
314        for i in 0..n_vars {
315            let angle = 2.0 * std::f64::consts::PI * (i as f64) / (n_vars as f64);
316            x_coords.push(angle.cos());
317            y_coords.push(angle.sin());
318
319            let label = variable_names
320                .as_ref()
321                .map(|names| names[i].clone())
322                .unwrap_or_else(|| format!("Var {}", i));
323            labels.push(label);
324        }
325
326        let data = PlotData::new(Array1::from_vec(x_coords), Array1::from_vec(y_coords))
327            .with_labels(labels);
328
329        let mut plot = InteractivePlot {
330            plot_type: PlotType::CorrelationNetwork,
331            data,
332            config: self.config.clone(),
333            metadata: HashMap::new(),
334            callbacks: HashMap::new(),
335        };
336
337        // Add network metadata
338        plot.metadata
339            .insert("threshold".to_string(), threshold.to_string());
340        plot.metadata.insert(
341            "correlation_data".to_string(),
342            format!("{:?}", correlation_matrix.shape()),
343        );
344
345        self.add_plot(plot);
346        Ok(())
347    }
348
349    /// Create 3D scatter plot for multi-view data
350    pub fn scatter_3d(
351        &mut self,
352        x: ArrayView1<f64>,
353        y: ArrayView1<f64>,
354        z: ArrayView1<f64>,
355        colors: Option<ArrayView1<f64>>,
356        labels: Option<Vec<String>>,
357    ) -> Result<(), VisualizationError> {
358        let mut data = PlotData::new(x.to_owned(), y.to_owned()).with_z(z.to_owned());
359
360        if let Some(color_values) = colors {
361            data = data.with_colors(color_values.to_owned());
362        }
363
364        if let Some(point_labels) = labels {
365            data = data.with_labels(point_labels);
366        }
367
368        let plot = InteractivePlot {
369            plot_type: PlotType::Scatter3D,
370            data,
371            config: self.config.clone(),
372            metadata: HashMap::new(),
373            callbacks: HashMap::new(),
374        };
375
376        self.add_plot(plot);
377        Ok(())
378    }
379
380    /// Generate HTML output with interactive plots
381    pub fn generate_html(&self, filename: &str) -> Result<(), VisualizationError> {
382        let html_content = self.generate_html_content()?;
383
384        // Create output directory if it doesn't exist
385        std::fs::create_dir_all(&self.output_dir)
386            .map_err(|e| VisualizationError::IoError(e.to_string()))?;
387
388        let filepath = format!("{}/{}", self.output_dir, filename);
389        std::fs::write(&filepath, html_content)
390            .map_err(|e| VisualizationError::IoError(e.to_string()))?;
391
392        println!("Interactive visualization saved to: {}", filepath);
393        Ok(())
394    }
395
396    /// Generate the HTML content for interactive plots
397    fn generate_html_content(&self) -> Result<String, VisualizationError> {
398        let mut html = String::new();
399
400        // HTML header with D3.js and other dependencies
401        html.push_str(
402            r#"
403<!DOCTYPE html>
404<html>
405<head>
406    <title>Interactive Cross-Decomposition Visualization</title>
407    <script src="https://d3js.org/d3.v7.min.js"></script>
408    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
409    <style>
410        body { font-family: Arial, sans-serif; margin: 20px; }
411        .plot-container { margin: 20px 0; border: 1px solid #ccc; padding: 10px; }
412        .plot-title { font-size: 18px; font-weight: bold; margin-bottom: 10px; }
413        .plot-description { color: #666; margin-bottom: 15px; }
414    </style>
415</head>
416<body>
417    <h1>Interactive Cross-Decomposition Analysis</h1>
418"#,
419        );
420
421        // Generate plots
422        for (i, plot) in self.plots.iter().enumerate() {
423            html.push_str(&format!(r#"
424    <div class="plot-container">
425        <div class="plot-title">{}</div>
426        <div class="plot-description">Interactive visualization with zoom, pan, and hover tooltips</div>
427        <div id="plot-{}" style="width: {}px; height: {}px;"></div>
428    </div>
429"#,
430                self.plot_type_title(plot.plot_type),
431                i,
432                plot.config.width,
433                plot.config.height
434            ));
435        }
436
437        // JavaScript for interactive plots
438        html.push_str(
439            r#"
440    <script>
441        // Color schemes
442        const colorSchemes = {
443            'Viridis': 'Viridis',
444            'Plasma': 'Plasma',
445            'Turbo': 'Turbo',
446            'CoolWarm': 'RdBu'
447        };
448"#,
449        );
450
451        // Generate JavaScript for each plot
452        for (i, plot) in self.plots.iter().enumerate() {
453            html.push_str(&self.generate_plot_javascript(i, plot)?);
454        }
455
456        html.push_str(
457            r#"
458    </script>
459</body>
460</html>
461"#,
462        );
463
464        Ok(html)
465    }
466
467    /// Generate JavaScript for a specific plot
468    fn generate_plot_javascript(
469        &self,
470        plot_index: usize,
471        plot: &InteractivePlot,
472    ) -> Result<String, VisualizationError> {
473        match plot.plot_type {
474            PlotType::CanonicalScatter => self.generate_scatter_js(plot_index, plot),
475            PlotType::LoadingHeatmap => self.generate_heatmap_js(plot_index, plot),
476            PlotType::CorrelationNetwork => self.generate_network_js(plot_index, plot),
477            PlotType::Scatter3D => self.generate_3d_scatter_js(plot_index, plot),
478            PlotType::TimeSeries => self.generate_timeseries_js(plot_index, plot),
479            PlotType::Biplot => self.generate_biplot_js(plot_index, plot),
480            PlotType::ParallelCoordinates => self.generate_parallel_js(plot_index, plot),
481        }
482    }
483
484    /// Generate JavaScript for scatter plot
485    fn generate_scatter_js(
486        &self,
487        plot_index: usize,
488        plot: &InteractivePlot,
489    ) -> Result<String, VisualizationError> {
490        let x_data: Vec<String> = plot.data.x.iter().map(|v| v.to_string()).collect();
491        let y_data: Vec<String> = plot.data.y.iter().map(|v| v.to_string()).collect();
492        let labels = plot
493            .data
494            .labels
495            .as_ref()
496            .map(|l| {
497                l.iter()
498                    .map(|s| format!("\"{}\"", s))
499                    .collect::<Vec<_>>()
500                    .join(",")
501            })
502            .unwrap_or_else(|| "[]".to_string());
503
504        Ok(format!(
505            r#"
506        // Scatter plot for plot-{}
507        const trace{} = {{
508            x: [{}],
509            y: [{}],
510            mode: 'markers',
511            type: 'scatter',
512            text: [{}],
513            hovertemplate: '%{{text}}<br>X: %{{x:.3f}}<br>Y: %{{y:.3f}}<extra></extra>',
514            marker: {{
515                size: 8,
516                color: 'rgba(31, 119, 180, 0.7)',
517                line: {{
518                    color: 'rgba(31, 119, 180, 1.0)',
519                    width: 1
520                }}
521            }}
522        }};
523
524        const layout{} = {{
525            title: 'Canonical Correlation Scatter Plot',
526            xaxis: {{ title: 'First Canonical Variable' }},
527            yaxis: {{ title: 'Second Canonical Variable' }},
528            hovermode: 'closest',
529            showlegend: false
530        }};
531
532        const config{} = {{
533            displayModeBar: true,
534            modeBarButtonsToAdd: [
535                {{
536                    name: 'Select points',
537                    icon: Plotly.Icons.selectbox,
538                    click: function(gd) {{
539                        console.log('Selection tool activated for plot {}');
540                    }}
541                }}
542            ]
543        }};
544
545        Plotly.newPlot('plot-{}', [trace{}], layout{}, config{});
546"#,
547            plot_index,
548            plot_index,
549            x_data.join(","),
550            y_data.join(","),
551            labels,
552            plot_index,
553            plot_index,
554            plot_index,
555            plot_index,
556            plot_index,
557            plot_index,
558            plot_index
559        ))
560    }
561
562    /// Generate JavaScript for heatmap
563    fn generate_heatmap_js(
564        &self,
565        plot_index: usize,
566        plot: &InteractivePlot,
567    ) -> Result<String, VisualizationError> {
568        // This is a simplified heatmap implementation
569        // In practice, you'd reconstruct the 2D matrix from the plot data
570        Ok(format!(
571            r#"
572        // Heatmap for plot-{}
573        const trace{} = {{
574            type: 'scatter',
575            mode: 'markers',
576            x: [{}],
577            y: [{}],
578            marker: {{
579                size: 20,
580                color: [{}],
581                colorscale: 'Viridis',
582                showscale: true,
583                colorbar: {{
584                    title: 'Loading Value'
585                }}
586            }},
587            text: [{}],
588            hovertemplate: '%{{text}}<extra></extra>'
589        }};
590
591        const layout{} = {{
592            title: 'Component Loading Heatmap',
593            xaxis: {{ title: 'Components' }},
594            yaxis: {{ title: 'Features' }},
595            hovermode: 'closest'
596        }};
597
598        Plotly.newPlot('plot-{}', [trace{}], layout{});
599"#,
600            plot_index,
601            plot_index,
602            plot.data
603                .x
604                .iter()
605                .map(|v| v.to_string())
606                .collect::<Vec<_>>()
607                .join(","),
608            plot.data
609                .y
610                .iter()
611                .map(|v| v.to_string())
612                .collect::<Vec<_>>()
613                .join(","),
614            plot.data
615                .colors
616                .as_ref()
617                .map(|c| c
618                    .iter()
619                    .map(|v| v.to_string())
620                    .collect::<Vec<_>>()
621                    .join(","))
622                .unwrap_or_else(|| "[]".to_string()),
623            plot.data
624                .labels
625                .as_ref()
626                .map(|l| l
627                    .iter()
628                    .map(|s| format!("\"{}\"", s))
629                    .collect::<Vec<_>>()
630                    .join(","))
631                .unwrap_or_else(|| "[]".to_string()),
632            plot_index,
633            plot_index,
634            plot_index,
635            plot_index
636        ))
637    }
638
639    /// Generate JavaScript for network visualization
640    fn generate_network_js(
641        &self,
642        plot_index: usize,
643        plot: &InteractivePlot,
644    ) -> Result<String, VisualizationError> {
645        Ok(format!(
646            r#"
647        // Network plot for plot-{}
648        const trace{} = {{
649            x: [{}],
650            y: [{}],
651            mode: 'markers+text',
652            type: 'scatter',
653            text: [{}],
654            textposition: 'middle center',
655            marker: {{
656                size: 15,
657                color: 'rgba(255, 127, 14, 0.8)',
658                line: {{
659                    color: 'rgba(255, 127, 14, 1.0)',
660                    width: 2
661                }}
662            }}
663        }};
664
665        const layout{} = {{
666            title: 'Correlation Network',
667            xaxis: {{ title: '', showgrid: false, zeroline: false, showticklabels: false }},
668            yaxis: {{ title: '', showgrid: false, zeroline: false, showticklabels: false }},
669            hovermode: 'closest',
670            showlegend: false
671        }};
672
673        Plotly.newPlot('plot-{}', [trace{}], layout{});
674"#,
675            plot_index,
676            plot_index,
677            plot.data
678                .x
679                .iter()
680                .map(|v| v.to_string())
681                .collect::<Vec<_>>()
682                .join(","),
683            plot.data
684                .y
685                .iter()
686                .map(|v| v.to_string())
687                .collect::<Vec<_>>()
688                .join(","),
689            plot.data
690                .labels
691                .as_ref()
692                .map(|l| l
693                    .iter()
694                    .map(|s| format!("\"{}\"", s))
695                    .collect::<Vec<_>>()
696                    .join(","))
697                .unwrap_or_else(|| "[]".to_string()),
698            plot_index,
699            plot_index,
700            plot_index,
701            plot_index
702        ))
703    }
704
705    /// Generate JavaScript for 3D scatter plot
706    fn generate_3d_scatter_js(
707        &self,
708        plot_index: usize,
709        plot: &InteractivePlot,
710    ) -> Result<String, VisualizationError> {
711        let z_data = plot
712            .data
713            .z
714            .as_ref()
715            .map(|z| {
716                z.iter()
717                    .map(|v| v.to_string())
718                    .collect::<Vec<_>>()
719                    .join(",")
720            })
721            .unwrap_or_else(|| "[]".to_string());
722
723        Ok(format!(
724            r#"
725        // 3D scatter plot for plot-{}
726        const trace{} = {{
727            x: [{}],
728            y: [{}],
729            z: [{}],
730            mode: 'markers',
731            type: 'scatter3d',
732            marker: {{
733                size: 5,
734                color: [{}],
735                colorscale: 'Viridis',
736                showscale: true
737            }},
738            text: [{}],
739            hovertemplate: '%{{text}}<br>X: %{{x:.3f}}<br>Y: %{{y:.3f}}<br>Z: %{{z:.3f}}<extra></extra>'
740        }};
741
742        const layout{} = {{
743            title: '3D Multi-View Data Visualization',
744            scene: {{
745                xaxis: {{ title: 'Component 1' }},
746                yaxis: {{ title: 'Component 2' }},
747                zaxis: {{ title: 'Component 3' }}
748            }},
749            hovermode: 'closest'
750        }};
751
752        Plotly.newPlot('plot-{}', [trace{}], layout{});
753"#,
754            plot_index,
755            plot_index,
756            plot.data
757                .x
758                .iter()
759                .map(|v| v.to_string())
760                .collect::<Vec<_>>()
761                .join(","),
762            plot.data
763                .y
764                .iter()
765                .map(|v| v.to_string())
766                .collect::<Vec<_>>()
767                .join(","),
768            z_data,
769            plot.data
770                .colors
771                .as_ref()
772                .map(|c| c
773                    .iter()
774                    .map(|v| v.to_string())
775                    .collect::<Vec<_>>()
776                    .join(","))
777                .unwrap_or_else(|| "[]".to_string()),
778            plot.data
779                .labels
780                .as_ref()
781                .map(|l| l
782                    .iter()
783                    .map(|s| format!("\"{}\"", s))
784                    .collect::<Vec<_>>()
785                    .join(","))
786                .unwrap_or_else(|| "[]".to_string()),
787            plot_index,
788            plot_index,
789            plot_index,
790            plot_index
791        ))
792    }
793
794    /// Generate placeholder JavaScript for other plot types
795    fn generate_timeseries_js(
796        &self,
797        plot_index: usize,
798        _plot: &InteractivePlot,
799    ) -> Result<String, VisualizationError> {
800        Ok(format!("// Time series plot {} - placeholder", plot_index))
801    }
802
803    fn generate_biplot_js(
804        &self,
805        plot_index: usize,
806        _plot: &InteractivePlot,
807    ) -> Result<String, VisualizationError> {
808        Ok(format!("// Biplot {} - placeholder", plot_index))
809    }
810
811    fn generate_parallel_js(
812        &self,
813        plot_index: usize,
814        _plot: &InteractivePlot,
815    ) -> Result<String, VisualizationError> {
816        Ok(format!(
817            "// Parallel coordinates plot {} - placeholder",
818            plot_index
819        ))
820    }
821
822    /// Get title for plot type
823    fn plot_type_title(&self, plot_type: PlotType) -> &'static str {
824        match plot_type {
825            PlotType::CanonicalScatter => "Canonical Correlation Scatter Plot",
826            PlotType::LoadingHeatmap => "Component Loading Heatmap",
827            PlotType::CorrelationNetwork => "Correlation Network Visualization",
828            PlotType::Scatter3D => "3D Multi-View Data Visualization",
829            PlotType::TimeSeries => "Temporal Dynamics Visualization",
830            PlotType::Biplot => "Biplot Visualization",
831            PlotType::ParallelCoordinates => "Parallel Coordinates Plot",
832        }
833    }
834}
835
836impl Default for InteractiveVisualizer {
837    fn default() -> Self {
838        Self::new()
839    }
840}
841
842/// Visualization errors
843#[derive(Debug, thiserror::Error)]
844pub enum VisualizationError {
845    #[error("Dimension mismatch: {0}")]
846    DimensionError(String),
847    #[error("Invalid configuration: {0}")]
848    ConfigError(String),
849    #[error("IO error: {0}")]
850    IoError(String),
851    #[error("Rendering error: {0}")]
852    RenderError(String),
853}
854
/// Unit tests for plot construction and HTML generation.
///
/// File output is deliberately not exercised; HTML is checked through
/// `generate_html_content` only.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh visualizer holds no plots and uses the default dimensions.
    #[test]
    fn test_interactive_visualizer_creation() {
        let visualizer = InteractiveVisualizer::new();
        assert_eq!(visualizer.plots.len(), 0);
        assert_eq!(visualizer.config.width, 800);
        assert_eq!(visualizer.config.height, 600);
    }

    /// `PlotData::new` stores x/y and leaves optional fields unset.
    #[test]
    fn test_plot_data_creation() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let data = PlotData::new(x.clone(), y.clone());

        assert_eq!(data.x, x);
        assert_eq!(data.y, y);
        assert!(data.z.is_none());
        assert!(data.colors.is_none());
    }

    /// Builder methods attach z coordinates and colors without touching x/y.
    #[test]
    fn test_plot_data_with_colors_and_z() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let z = Array1::from_vec(vec![7.0, 8.0, 9.0]);
        let colors = Array1::from_vec(vec![0.1, 0.5, 0.9]);

        let data = PlotData::new(x.clone(), y.clone())
            .with_z(z.clone())
            .with_colors(colors.clone());

        assert_eq!(data.x, x);
        assert_eq!(data.y, y);
        assert_eq!(data.z.unwrap(), z);
        assert_eq!(data.colors.unwrap(), colors);
    }

    #[test]
    fn test_canonical_scatter_plot() -> Result<(), VisualizationError> {
        let mut visualizer = InteractiveVisualizer::new();

        let x_canonical = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let y_canonical = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0]);

        visualizer.canonical_scatter(x_canonical.view(), y_canonical.view(), None)?;

        assert_eq!(visualizer.plots.len(), 1);
        assert_eq!(visualizer.plots[0].plot_type, PlotType::CanonicalScatter);
        assert_eq!(visualizer.plots[0].data.x.len(), 4);

        Ok(())
    }

    #[test]
    fn test_loading_heatmap() -> Result<(), VisualizationError> {
        let mut visualizer = InteractiveVisualizer::new();

        let loadings = Array2::from_shape_vec((3, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
        let feature_names = Some(vec![
            "Feature1".to_string(),
            "Feature2".to_string(),
            "Feature3".to_string(),
        ]);
        let component_names = Some(vec!["Comp1".to_string(), "Comp2".to_string()]);

        visualizer.loading_heatmap(loadings.view(), feature_names, component_names)?;

        assert_eq!(visualizer.plots.len(), 1);
        assert_eq!(visualizer.plots[0].plot_type, PlotType::LoadingHeatmap);
        assert_eq!(visualizer.plots[0].data.x.len(), 6); // 3 features * 2 components

        Ok(())
    }

    #[test]
    fn test_3d_scatter_plot() -> Result<(), VisualizationError> {
        let mut visualizer = InteractiveVisualizer::new();

        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let z = Array1::from_vec(vec![7.0, 8.0, 9.0]);
        let colors = Array1::from_vec(vec![0.1, 0.5, 0.9]);

        visualizer.scatter_3d(x.view(), y.view(), z.view(), Some(colors.view()), None)?;

        assert_eq!(visualizer.plots.len(), 1);
        assert_eq!(visualizer.plots[0].plot_type, PlotType::Scatter3D);
        assert!(visualizer.plots[0].data.z.is_some());
        assert!(visualizer.plots[0].data.colors.is_some());

        Ok(())
    }

    #[test]
    fn test_color_scheme_enum() {
        let scheme = ColorScheme::Viridis;
        assert_eq!(scheme, ColorScheme::Viridis);
        assert_ne!(scheme, ColorScheme::Plasma);
    }

    #[test]
    fn test_visualization_config_default() {
        let config = InteractiveVisualizationConfig::default();
        assert_eq!(config.width, 800);
        assert_eq!(config.height, 600);
        assert_eq!(config.color_scheme, ColorScheme::Viridis);
        assert!(config.enable_zoom_pan);
        assert!(config.show_tooltips);
    }

    #[test]
    fn test_correlation_network() -> Result<(), VisualizationError> {
        let mut visualizer = InteractiveVisualizer::new();

        let correlation_matrix =
            Array2::from_shape_vec((3, 3), vec![1.0, 0.5, 0.3, 0.5, 1.0, 0.7, 0.3, 0.7, 1.0])
                .unwrap();
        let variable_names = Some(vec![
            "Var1".to_string(),
            "Var2".to_string(),
            "Var3".to_string(),
        ]);

        visualizer.correlation_network(correlation_matrix.view(), variable_names, 0.5)?;

        assert_eq!(visualizer.plots.len(), 1);
        assert_eq!(visualizer.plots[0].plot_type, PlotType::CorrelationNetwork);
        assert_eq!(visualizer.plots[0].data.x.len(), 3);
        assert!(visualizer.plots[0].metadata.contains_key("threshold"));

        Ok(())
    }

    /// HTML content generation (not file writing) yields a Plotly page that
    /// references the first plot's container.
    #[test]
    fn test_html_generation() {
        let mut visualizer = InteractiveVisualizer::new();

        // Fail loudly if the plot cannot be added instead of discarding the
        // result and asserting against an empty page.
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0]);
        visualizer
            .canonical_scatter(x.view(), y.view(), None)
            .expect("equal-length inputs should be accepted");

        let html = visualizer
            .generate_html_content()
            .expect("HTML generation should succeed");

        assert!(html.contains("<!DOCTYPE html>"));
        assert!(html.contains("Interactive Cross-Decomposition"));
        assert!(html.contains("Plotly"));
        assert!(html.contains("plot-0"));
    }

    #[test]
    fn test_multiple_plots() -> Result<(), VisualizationError> {
        let mut visualizer = InteractiveVisualizer::new();

        // Add scatter plot
        let x1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y1 = Array1::from_vec(vec![2.0, 4.0, 6.0]);
        visualizer.canonical_scatter(x1.view(), y1.view(), None)?;

        // Add 3D plot
        let x2 = Array1::from_vec(vec![1.0, 2.0]);
        let y2 = Array1::from_vec(vec![3.0, 4.0]);
        let z2 = Array1::from_vec(vec![5.0, 6.0]);
        visualizer.scatter_3d(x2.view(), y2.view(), z2.view(), None, None)?;

        assert_eq!(visualizer.plots.len(), 2);
        assert_eq!(visualizer.plots[0].plot_type, PlotType::CanonicalScatter);
        assert_eq!(visualizer.plots[1].plot_type, PlotType::Scatter3D);

        Ok(())
    }
}
1036}