1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
22use sklears_core::types::Float;
23use std::collections::HashMap;
24use std::path::Path;
25
26#[derive(Debug, Clone)]
28pub struct InteractiveVisualizationConfig {
29 pub width: usize,
31 pub height: usize,
33 pub color_scheme: ColorScheme,
35 pub enable_zoom_pan: bool,
37 pub enable_selection: bool,
39 pub show_tooltips: bool,
41 pub animation_duration: usize,
43 pub real_time_updates: bool,
45 pub update_interval: usize,
47}
48
49impl Default for InteractiveVisualizationConfig {
50 fn default() -> Self {
51 Self {
52 width: 800,
53 height: 600,
54 color_scheme: ColorScheme::Viridis,
55 enable_zoom_pan: true,
56 enable_selection: true,
57 show_tooltips: true,
58 animation_duration: 750,
59 real_time_updates: false,
60 update_interval: 100,
61 }
62 }
63}
64
/// Named colormaps a plot may request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ColorScheme {
    /// Viridis colormap; the default in `InteractiveVisualizationConfig::default`.
    Viridis,
    /// Plasma colormap.
    Plasma,
    /// Turbo colormap.
    Turbo,
    /// Cool-to-warm colormap.
    CoolWarm,
    /// User-defined colors; this variant carries no color data itself.
    Custom,
}
79
/// The kind of chart an [`InteractivePlot`] is rendered as.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlotType {
    /// 2-D scatter of two canonical variates.
    CanonicalScatter,
    /// Marker grid of a loading matrix, colored by loading value.
    LoadingHeatmap,
    /// Variables laid out as nodes on a circle.
    CorrelationNetwork,
    /// 3-D scatter plot.
    Scatter3D,
    /// Declared, but currently rendered only as a placeholder comment.
    TimeSeries,
    /// Declared, but currently rendered only as a placeholder comment.
    Biplot,
    /// Declared, but currently rendered only as a placeholder comment.
    ParallelCoordinates,
}
98
99#[derive(Debug, Clone)]
101pub struct InteractivePlot {
102 pub plot_type: PlotType,
104 pub data: PlotData,
106 pub config: InteractiveVisualizationConfig,
108 pub metadata: HashMap<String, String>,
110 pub callbacks: HashMap<String, String>,
112}
113
114#[derive(Debug, Clone)]
116pub struct PlotData {
117 pub x: Array1<f64>,
119 pub y: Array1<f64>,
121 pub z: Option<Array1<f64>>,
123 pub colors: Option<Array1<f64>>,
125 pub sizes: Option<Array1<f64>>,
127 pub labels: Option<Vec<String>>,
129 pub additional_dims: Option<Array2<f64>>,
131}
132
133impl PlotData {
134 pub fn new(x: Array1<f64>, y: Array1<f64>) -> Self {
136 Self {
137 x,
138 y,
139 z: None,
140 colors: None,
141 sizes: None,
142 labels: None,
143 additional_dims: None,
144 }
145 }
146
147 pub fn with_z(mut self, z: Array1<f64>) -> Self {
149 self.z = Some(z);
150 self
151 }
152
153 pub fn with_colors(mut self, colors: Array1<f64>) -> Self {
155 self.colors = Some(colors);
156 self
157 }
158
159 pub fn with_sizes(mut self, sizes: Array1<f64>) -> Self {
161 self.sizes = Some(sizes);
162 self
163 }
164
165 pub fn with_labels(mut self, labels: Vec<String>) -> Self {
167 self.labels = Some(labels);
168 self
169 }
170
171 pub fn with_additional_dims(mut self, dims: Array2<f64>) -> Self {
173 self.additional_dims = Some(dims);
174 self
175 }
176}
177
178#[derive(Debug)]
180pub struct InteractiveVisualizer {
181 config: InteractiveVisualizationConfig,
183 plots: Vec<InteractivePlot>,
185 output_dir: String,
187}
188
189impl InteractiveVisualizer {
190 pub fn new() -> Self {
192 Self {
193 config: InteractiveVisualizationConfig::default(),
194 plots: Vec::new(),
195 output_dir: "visualizations".to_string(),
196 }
197 }
198
199 pub fn with_config(config: InteractiveVisualizationConfig) -> Self {
201 Self {
202 config,
203 plots: Vec::new(),
204 output_dir: "visualizations".to_string(),
205 }
206 }
207
208 pub fn with_output_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
210 self.output_dir = path.as_ref().to_string_lossy().to_string();
211 self
212 }
213
214 pub fn add_plot(&mut self, plot: InteractivePlot) {
216 self.plots.push(plot);
217 }
218
219 pub fn canonical_scatter(
221 &mut self,
222 x_canonical: ArrayView1<f64>,
223 y_canonical: ArrayView1<f64>,
224 labels: Option<Vec<String>>,
225 ) -> Result<(), VisualizationError> {
226 let data = PlotData::new(x_canonical.to_owned(), y_canonical.to_owned()).with_labels(
227 labels.unwrap_or_else(|| {
228 (0..x_canonical.len())
229 .map(|i| format!("Point {}", i))
230 .collect()
231 }),
232 );
233
234 let plot = InteractivePlot {
235 plot_type: PlotType::CanonicalScatter,
236 data,
237 config: self.config.clone(),
238 metadata: HashMap::new(),
239 callbacks: HashMap::new(),
240 };
241
242 self.add_plot(plot);
243 Ok(())
244 }
245
246 pub fn loading_heatmap(
248 &mut self,
249 loadings: ArrayView2<f64>,
250 feature_names: Option<Vec<String>>,
251 component_names: Option<Vec<String>>,
252 ) -> Result<(), VisualizationError> {
253 let (n_features, n_components) = loadings.dim();
255 let mut x_coords = Vec::new();
256 let mut y_coords = Vec::new();
257 let mut colors = Vec::new();
258 let mut labels = Vec::new();
259
260 for i in 0..n_features {
261 for j in 0..n_components {
262 x_coords.push(j as f64);
263 y_coords.push(i as f64);
264 colors.push(loadings[[i, j]]);
265
266 let feature_name = feature_names
267 .as_ref()
268 .map(|names| names[i].clone())
269 .unwrap_or_else(|| format!("Feature {}", i));
270 let component_name = component_names
271 .as_ref()
272 .map(|names| names[j].clone())
273 .unwrap_or_else(|| format!("Component {}", j));
274
275 labels.push(format!(
276 "{} -> {}: {:.4}",
277 feature_name,
278 component_name,
279 loadings[[i, j]]
280 ));
281 }
282 }
283
284 let data = PlotData::new(Array1::from_vec(x_coords), Array1::from_vec(y_coords))
285 .with_colors(Array1::from_vec(colors))
286 .with_labels(labels);
287
288 let plot = InteractivePlot {
289 plot_type: PlotType::LoadingHeatmap,
290 data,
291 config: self.config.clone(),
292 metadata: HashMap::new(),
293 callbacks: HashMap::new(),
294 };
295
296 self.add_plot(plot);
297 Ok(())
298 }
299
300 pub fn correlation_network(
302 &mut self,
303 correlation_matrix: ArrayView2<f64>,
304 variable_names: Option<Vec<String>>,
305 threshold: f64,
306 ) -> Result<(), VisualizationError> {
307 let n_vars = correlation_matrix.nrows();
308
309 let mut x_coords = Vec::new();
311 let mut y_coords = Vec::new();
312 let mut labels = Vec::new();
313
314 for i in 0..n_vars {
315 let angle = 2.0 * std::f64::consts::PI * (i as f64) / (n_vars as f64);
316 x_coords.push(angle.cos());
317 y_coords.push(angle.sin());
318
319 let label = variable_names
320 .as_ref()
321 .map(|names| names[i].clone())
322 .unwrap_or_else(|| format!("Var {}", i));
323 labels.push(label);
324 }
325
326 let data = PlotData::new(Array1::from_vec(x_coords), Array1::from_vec(y_coords))
327 .with_labels(labels);
328
329 let mut plot = InteractivePlot {
330 plot_type: PlotType::CorrelationNetwork,
331 data,
332 config: self.config.clone(),
333 metadata: HashMap::new(),
334 callbacks: HashMap::new(),
335 };
336
337 plot.metadata
339 .insert("threshold".to_string(), threshold.to_string());
340 plot.metadata.insert(
341 "correlation_data".to_string(),
342 format!("{:?}", correlation_matrix.shape()),
343 );
344
345 self.add_plot(plot);
346 Ok(())
347 }
348
349 pub fn scatter_3d(
351 &mut self,
352 x: ArrayView1<f64>,
353 y: ArrayView1<f64>,
354 z: ArrayView1<f64>,
355 colors: Option<ArrayView1<f64>>,
356 labels: Option<Vec<String>>,
357 ) -> Result<(), VisualizationError> {
358 let mut data = PlotData::new(x.to_owned(), y.to_owned()).with_z(z.to_owned());
359
360 if let Some(color_values) = colors {
361 data = data.with_colors(color_values.to_owned());
362 }
363
364 if let Some(point_labels) = labels {
365 data = data.with_labels(point_labels);
366 }
367
368 let plot = InteractivePlot {
369 plot_type: PlotType::Scatter3D,
370 data,
371 config: self.config.clone(),
372 metadata: HashMap::new(),
373 callbacks: HashMap::new(),
374 };
375
376 self.add_plot(plot);
377 Ok(())
378 }
379
380 pub fn generate_html(&self, filename: &str) -> Result<(), VisualizationError> {
382 let html_content = self.generate_html_content()?;
383
384 std::fs::create_dir_all(&self.output_dir)
386 .map_err(|e| VisualizationError::IoError(e.to_string()))?;
387
388 let filepath = format!("{}/{}", self.output_dir, filename);
389 std::fs::write(&filepath, html_content)
390 .map_err(|e| VisualizationError::IoError(e.to_string()))?;
391
392 println!("Interactive visualization saved to: {}", filepath);
393 Ok(())
394 }
395
396 fn generate_html_content(&self) -> Result<String, VisualizationError> {
398 let mut html = String::new();
399
400 html.push_str(
402 r#"
403<!DOCTYPE html>
404<html>
405<head>
406 <title>Interactive Cross-Decomposition Visualization</title>
407 <script src="https://d3js.org/d3.v7.min.js"></script>
408 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
409 <style>
410 body { font-family: Arial, sans-serif; margin: 20px; }
411 .plot-container { margin: 20px 0; border: 1px solid #ccc; padding: 10px; }
412 .plot-title { font-size: 18px; font-weight: bold; margin-bottom: 10px; }
413 .plot-description { color: #666; margin-bottom: 15px; }
414 </style>
415</head>
416<body>
417 <h1>Interactive Cross-Decomposition Analysis</h1>
418"#,
419 );
420
421 for (i, plot) in self.plots.iter().enumerate() {
423 html.push_str(&format!(r#"
424 <div class="plot-container">
425 <div class="plot-title">{}</div>
426 <div class="plot-description">Interactive visualization with zoom, pan, and hover tooltips</div>
427 <div id="plot-{}" style="width: {}px; height: {}px;"></div>
428 </div>
429"#,
430 self.plot_type_title(plot.plot_type),
431 i,
432 plot.config.width,
433 plot.config.height
434 ));
435 }
436
437 html.push_str(
439 r#"
440 <script>
441 // Color schemes
442 const colorSchemes = {
443 'Viridis': 'Viridis',
444 'Plasma': 'Plasma',
445 'Turbo': 'Turbo',
446 'CoolWarm': 'RdBu'
447 };
448"#,
449 );
450
451 for (i, plot) in self.plots.iter().enumerate() {
453 html.push_str(&self.generate_plot_javascript(i, plot)?);
454 }
455
456 html.push_str(
457 r#"
458 </script>
459</body>
460</html>
461"#,
462 );
463
464 Ok(html)
465 }
466
467 fn generate_plot_javascript(
469 &self,
470 plot_index: usize,
471 plot: &InteractivePlot,
472 ) -> Result<String, VisualizationError> {
473 match plot.plot_type {
474 PlotType::CanonicalScatter => self.generate_scatter_js(plot_index, plot),
475 PlotType::LoadingHeatmap => self.generate_heatmap_js(plot_index, plot),
476 PlotType::CorrelationNetwork => self.generate_network_js(plot_index, plot),
477 PlotType::Scatter3D => self.generate_3d_scatter_js(plot_index, plot),
478 PlotType::TimeSeries => self.generate_timeseries_js(plot_index, plot),
479 PlotType::Biplot => self.generate_biplot_js(plot_index, plot),
480 PlotType::ParallelCoordinates => self.generate_parallel_js(plot_index, plot),
481 }
482 }
483
484 fn generate_scatter_js(
486 &self,
487 plot_index: usize,
488 plot: &InteractivePlot,
489 ) -> Result<String, VisualizationError> {
490 let x_data: Vec<String> = plot.data.x.iter().map(|v| v.to_string()).collect();
491 let y_data: Vec<String> = plot.data.y.iter().map(|v| v.to_string()).collect();
492 let labels = plot
493 .data
494 .labels
495 .as_ref()
496 .map(|l| {
497 l.iter()
498 .map(|s| format!("\"{}\"", s))
499 .collect::<Vec<_>>()
500 .join(",")
501 })
502 .unwrap_or_else(|| "[]".to_string());
503
504 Ok(format!(
505 r#"
506 // Scatter plot for plot-{}
507 const trace{} = {{
508 x: [{}],
509 y: [{}],
510 mode: 'markers',
511 type: 'scatter',
512 text: [{}],
513 hovertemplate: '%{{text}}<br>X: %{{x:.3f}}<br>Y: %{{y:.3f}}<extra></extra>',
514 marker: {{
515 size: 8,
516 color: 'rgba(31, 119, 180, 0.7)',
517 line: {{
518 color: 'rgba(31, 119, 180, 1.0)',
519 width: 1
520 }}
521 }}
522 }};
523
524 const layout{} = {{
525 title: 'Canonical Correlation Scatter Plot',
526 xaxis: {{ title: 'First Canonical Variable' }},
527 yaxis: {{ title: 'Second Canonical Variable' }},
528 hovermode: 'closest',
529 showlegend: false
530 }};
531
532 const config{} = {{
533 displayModeBar: true,
534 modeBarButtonsToAdd: [
535 {{
536 name: 'Select points',
537 icon: Plotly.Icons.selectbox,
538 click: function(gd) {{
539 console.log('Selection tool activated for plot {}');
540 }}
541 }}
542 ]
543 }};
544
545 Plotly.newPlot('plot-{}', [trace{}], layout{}, config{});
546"#,
547 plot_index,
548 plot_index,
549 x_data.join(","),
550 y_data.join(","),
551 labels,
552 plot_index,
553 plot_index,
554 plot_index,
555 plot_index,
556 plot_index,
557 plot_index,
558 plot_index
559 ))
560 }
561
562 fn generate_heatmap_js(
564 &self,
565 plot_index: usize,
566 plot: &InteractivePlot,
567 ) -> Result<String, VisualizationError> {
568 Ok(format!(
571 r#"
572 // Heatmap for plot-{}
573 const trace{} = {{
574 type: 'scatter',
575 mode: 'markers',
576 x: [{}],
577 y: [{}],
578 marker: {{
579 size: 20,
580 color: [{}],
581 colorscale: 'Viridis',
582 showscale: true,
583 colorbar: {{
584 title: 'Loading Value'
585 }}
586 }},
587 text: [{}],
588 hovertemplate: '%{{text}}<extra></extra>'
589 }};
590
591 const layout{} = {{
592 title: 'Component Loading Heatmap',
593 xaxis: {{ title: 'Components' }},
594 yaxis: {{ title: 'Features' }},
595 hovermode: 'closest'
596 }};
597
598 Plotly.newPlot('plot-{}', [trace{}], layout{});
599"#,
600 plot_index,
601 plot_index,
602 plot.data
603 .x
604 .iter()
605 .map(|v| v.to_string())
606 .collect::<Vec<_>>()
607 .join(","),
608 plot.data
609 .y
610 .iter()
611 .map(|v| v.to_string())
612 .collect::<Vec<_>>()
613 .join(","),
614 plot.data
615 .colors
616 .as_ref()
617 .map(|c| c
618 .iter()
619 .map(|v| v.to_string())
620 .collect::<Vec<_>>()
621 .join(","))
622 .unwrap_or_else(|| "[]".to_string()),
623 plot.data
624 .labels
625 .as_ref()
626 .map(|l| l
627 .iter()
628 .map(|s| format!("\"{}\"", s))
629 .collect::<Vec<_>>()
630 .join(","))
631 .unwrap_or_else(|| "[]".to_string()),
632 plot_index,
633 plot_index,
634 plot_index,
635 plot_index
636 ))
637 }
638
639 fn generate_network_js(
641 &self,
642 plot_index: usize,
643 plot: &InteractivePlot,
644 ) -> Result<String, VisualizationError> {
645 Ok(format!(
646 r#"
647 // Network plot for plot-{}
648 const trace{} = {{
649 x: [{}],
650 y: [{}],
651 mode: 'markers+text',
652 type: 'scatter',
653 text: [{}],
654 textposition: 'middle center',
655 marker: {{
656 size: 15,
657 color: 'rgba(255, 127, 14, 0.8)',
658 line: {{
659 color: 'rgba(255, 127, 14, 1.0)',
660 width: 2
661 }}
662 }}
663 }};
664
665 const layout{} = {{
666 title: 'Correlation Network',
667 xaxis: {{ title: '', showgrid: false, zeroline: false, showticklabels: false }},
668 yaxis: {{ title: '', showgrid: false, zeroline: false, showticklabels: false }},
669 hovermode: 'closest',
670 showlegend: false
671 }};
672
673 Plotly.newPlot('plot-{}', [trace{}], layout{});
674"#,
675 plot_index,
676 plot_index,
677 plot.data
678 .x
679 .iter()
680 .map(|v| v.to_string())
681 .collect::<Vec<_>>()
682 .join(","),
683 plot.data
684 .y
685 .iter()
686 .map(|v| v.to_string())
687 .collect::<Vec<_>>()
688 .join(","),
689 plot.data
690 .labels
691 .as_ref()
692 .map(|l| l
693 .iter()
694 .map(|s| format!("\"{}\"", s))
695 .collect::<Vec<_>>()
696 .join(","))
697 .unwrap_or_else(|| "[]".to_string()),
698 plot_index,
699 plot_index,
700 plot_index,
701 plot_index
702 ))
703 }
704
705 fn generate_3d_scatter_js(
707 &self,
708 plot_index: usize,
709 plot: &InteractivePlot,
710 ) -> Result<String, VisualizationError> {
711 let z_data = plot
712 .data
713 .z
714 .as_ref()
715 .map(|z| {
716 z.iter()
717 .map(|v| v.to_string())
718 .collect::<Vec<_>>()
719 .join(",")
720 })
721 .unwrap_or_else(|| "[]".to_string());
722
723 Ok(format!(
724 r#"
725 // 3D scatter plot for plot-{}
726 const trace{} = {{
727 x: [{}],
728 y: [{}],
729 z: [{}],
730 mode: 'markers',
731 type: 'scatter3d',
732 marker: {{
733 size: 5,
734 color: [{}],
735 colorscale: 'Viridis',
736 showscale: true
737 }},
738 text: [{}],
739 hovertemplate: '%{{text}}<br>X: %{{x:.3f}}<br>Y: %{{y:.3f}}<br>Z: %{{z:.3f}}<extra></extra>'
740 }};
741
742 const layout{} = {{
743 title: '3D Multi-View Data Visualization',
744 scene: {{
745 xaxis: {{ title: 'Component 1' }},
746 yaxis: {{ title: 'Component 2' }},
747 zaxis: {{ title: 'Component 3' }}
748 }},
749 hovermode: 'closest'
750 }};
751
752 Plotly.newPlot('plot-{}', [trace{}], layout{});
753"#,
754 plot_index,
755 plot_index,
756 plot.data
757 .x
758 .iter()
759 .map(|v| v.to_string())
760 .collect::<Vec<_>>()
761 .join(","),
762 plot.data
763 .y
764 .iter()
765 .map(|v| v.to_string())
766 .collect::<Vec<_>>()
767 .join(","),
768 z_data,
769 plot.data
770 .colors
771 .as_ref()
772 .map(|c| c
773 .iter()
774 .map(|v| v.to_string())
775 .collect::<Vec<_>>()
776 .join(","))
777 .unwrap_or_else(|| "[]".to_string()),
778 plot.data
779 .labels
780 .as_ref()
781 .map(|l| l
782 .iter()
783 .map(|s| format!("\"{}\"", s))
784 .collect::<Vec<_>>()
785 .join(","))
786 .unwrap_or_else(|| "[]".to_string()),
787 plot_index,
788 plot_index,
789 plot_index,
790 plot_index
791 ))
792 }
793
794 fn generate_timeseries_js(
796 &self,
797 plot_index: usize,
798 _plot: &InteractivePlot,
799 ) -> Result<String, VisualizationError> {
800 Ok(format!("// Time series plot {} - placeholder", plot_index))
801 }
802
803 fn generate_biplot_js(
804 &self,
805 plot_index: usize,
806 _plot: &InteractivePlot,
807 ) -> Result<String, VisualizationError> {
808 Ok(format!("// Biplot {} - placeholder", plot_index))
809 }
810
811 fn generate_parallel_js(
812 &self,
813 plot_index: usize,
814 _plot: &InteractivePlot,
815 ) -> Result<String, VisualizationError> {
816 Ok(format!(
817 "// Parallel coordinates plot {} - placeholder",
818 plot_index
819 ))
820 }
821
822 fn plot_type_title(&self, plot_type: PlotType) -> &'static str {
824 match plot_type {
825 PlotType::CanonicalScatter => "Canonical Correlation Scatter Plot",
826 PlotType::LoadingHeatmap => "Component Loading Heatmap",
827 PlotType::CorrelationNetwork => "Correlation Network Visualization",
828 PlotType::Scatter3D => "3D Multi-View Data Visualization",
829 PlotType::TimeSeries => "Temporal Dynamics Visualization",
830 PlotType::Biplot => "Biplot Visualization",
831 PlotType::ParallelCoordinates => "Parallel Coordinates Plot",
832 }
833 }
834}
835
836impl Default for InteractiveVisualizer {
837 fn default() -> Self {
838 Self::new()
839 }
840}
841
842#[derive(Debug, thiserror::Error)]
844pub enum VisualizationError {
845 #[error("Dimension mismatch: {0}")]
846 DimensionError(String),
847 #[error("Invalid configuration: {0}")]
848 ConfigError(String),
849 #[error("IO error: {0}")]
850 IoError(String),
851 #[error("Rendering error: {0}")]
852 RenderError(String),
853}
854
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_interactive_visualizer_creation() {
        let visualizer = InteractiveVisualizer::new();
        assert_eq!(visualizer.plots.len(), 0);
        assert_eq!(visualizer.config.width, 800);
        assert_eq!(visualizer.config.height, 600);
    }

    #[test]
    fn test_plot_data_creation() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let data = PlotData::new(x.clone(), y.clone());

        assert_eq!(data.x, x);
        assert_eq!(data.y, y);
        assert!(data.z.is_none());
        assert!(data.colors.is_none());
    }

    #[test]
    fn test_plot_data_with_colors_and_z() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let z = Array1::from_vec(vec![7.0, 8.0, 9.0]);
        let colors = Array1::from_vec(vec![0.1, 0.5, 0.9]);

        let data = PlotData::new(x.clone(), y.clone())
            .with_z(z.clone())
            .with_colors(colors.clone());

        assert_eq!(data.x, x);
        assert_eq!(data.y, y);
        assert_eq!(data.z.unwrap(), z);
        assert_eq!(data.colors.unwrap(), colors);
    }

    #[test]
    fn test_canonical_scatter_plot() -> Result<(), VisualizationError> {
        let mut visualizer = InteractiveVisualizer::new();

        let x_canonical = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let y_canonical = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0]);

        visualizer.canonical_scatter(x_canonical.view(), y_canonical.view(), None)?;

        assert_eq!(visualizer.plots.len(), 1);
        assert_eq!(visualizer.plots[0].plot_type, PlotType::CanonicalScatter);
        assert_eq!(visualizer.plots[0].data.x.len(), 4);

        // Default labels are generated when none are supplied.
        let labels = visualizer.plots[0]
            .data
            .labels
            .as_ref()
            .expect("canonical_scatter always sets labels");
        assert_eq!(labels.len(), 4);
        assert_eq!(labels[0], "Point 0");

        Ok(())
    }

    #[test]
    fn test_loading_heatmap() -> Result<(), VisualizationError> {
        let mut visualizer = InteractiveVisualizer::new();

        let loadings = Array2::from_shape_vec((3, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
        let feature_names = Some(vec![
            "Feature1".to_string(),
            "Feature2".to_string(),
            "Feature3".to_string(),
        ]);
        let component_names = Some(vec!["Comp1".to_string(), "Comp2".to_string()]);

        visualizer.loading_heatmap(loadings.view(), feature_names, component_names)?;

        assert_eq!(visualizer.plots.len(), 1);
        assert_eq!(visualizer.plots[0].plot_type, PlotType::LoadingHeatmap);
        assert_eq!(visualizer.plots[0].data.x.len(), 6);

        // Cells are visited row-major: (Feature1, Comp1), then (Feature1, Comp2), ...
        let labels = visualizer.plots[0]
            .data
            .labels
            .as_ref()
            .expect("loading_heatmap always sets labels");
        assert_eq!(labels[0], "Feature1 -> Comp1: 0.1000");
        assert_eq!(labels[1], "Feature1 -> Comp2: 0.2000");

        Ok(())
    }

    #[test]
    fn test_3d_scatter_plot() -> Result<(), VisualizationError> {
        let mut visualizer = InteractiveVisualizer::new();

        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let z = Array1::from_vec(vec![7.0, 8.0, 9.0]);
        let colors = Array1::from_vec(vec![0.1, 0.5, 0.9]);

        visualizer.scatter_3d(x.view(), y.view(), z.view(), Some(colors.view()), None)?;

        assert_eq!(visualizer.plots.len(), 1);
        assert_eq!(visualizer.plots[0].plot_type, PlotType::Scatter3D);
        assert!(visualizer.plots[0].data.z.is_some());
        assert!(visualizer.plots[0].data.colors.is_some());

        Ok(())
    }

    #[test]
    fn test_color_scheme_enum() {
        let scheme = ColorScheme::Viridis;
        assert_eq!(scheme, ColorScheme::Viridis);
        assert_ne!(scheme, ColorScheme::Plasma);
    }

    #[test]
    fn test_visualization_config_default() {
        let config = InteractiveVisualizationConfig::default();
        assert_eq!(config.width, 800);
        assert_eq!(config.height, 600);
        assert_eq!(config.color_scheme, ColorScheme::Viridis);
        assert!(config.enable_zoom_pan);
        assert!(config.show_tooltips);
    }

    #[test]
    fn test_correlation_network() -> Result<(), VisualizationError> {
        let mut visualizer = InteractiveVisualizer::new();

        let correlation_matrix =
            Array2::from_shape_vec((3, 3), vec![1.0, 0.5, 0.3, 0.5, 1.0, 0.7, 0.3, 0.7, 1.0])
                .unwrap();
        let variable_names = Some(vec![
            "Var1".to_string(),
            "Var2".to_string(),
            "Var3".to_string(),
        ]);

        visualizer.correlation_network(correlation_matrix.view(), variable_names, 0.5)?;

        assert_eq!(visualizer.plots.len(), 1);
        assert_eq!(visualizer.plots[0].plot_type, PlotType::CorrelationNetwork);
        assert_eq!(visualizer.plots[0].data.x.len(), 3);
        assert!(visualizer.plots[0].metadata.contains_key("threshold"));

        // Nodes lie on the unit circle, the first one at angle 0.
        assert_abs_diff_eq!(visualizer.plots[0].data.x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(visualizer.plots[0].data.y[0], 0.0, epsilon = 1e-12);

        Ok(())
    }

    #[test]
    fn test_html_generation() {
        let mut visualizer = InteractiveVisualizer::new();

        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0]);
        visualizer
            .canonical_scatter(x.view(), y.view(), None)
            .expect("equal-length inputs are valid");

        let html = visualizer
            .generate_html_content()
            .expect("HTML generation should succeed");

        assert!(html.contains("<!DOCTYPE html>"));
        assert!(html.contains("Interactive Cross-Decomposition"));
        assert!(html.contains("Plotly"));
        assert!(html.contains("plot-0"));
        assert!(html.contains("\"Point 0\""));
    }

    #[test]
    fn test_generate_html_writes_file() -> Result<(), VisualizationError> {
        let dir = std::env::temp_dir().join(format!(
            "interactive_viz_test_{}",
            std::process::id()
        ));
        let mut visualizer = InteractiveVisualizer::new().with_output_dir(&dir);

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![3.0, 4.0]);
        visualizer.canonical_scatter(x.view(), y.view(), None)?;
        visualizer.generate_html("plots.html")?;

        let written = std::fs::read_to_string(dir.join("plots.html"))
            .expect("generate_html should have written the file");
        // Best-effort cleanup; a leftover temp dir must not fail the test.
        let _ = std::fs::remove_dir_all(&dir);

        assert!(written.contains("<!DOCTYPE html>"));
        assert!(written.contains("plot-0"));
        Ok(())
    }

    #[test]
    fn test_multiple_plots() -> Result<(), VisualizationError> {
        let mut visualizer = InteractiveVisualizer::new();

        let x1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y1 = Array1::from_vec(vec![2.0, 4.0, 6.0]);
        visualizer.canonical_scatter(x1.view(), y1.view(), None)?;

        let x2 = Array1::from_vec(vec![1.0, 2.0]);
        let y2 = Array1::from_vec(vec![3.0, 4.0]);
        let z2 = Array1::from_vec(vec![5.0, 6.0]);
        visualizer.scatter_3d(x2.view(), y2.view(), z2.view(), None, None)?;

        assert_eq!(visualizer.plots.len(), 2);
        assert_eq!(visualizer.plots[0].plot_type, PlotType::CanonicalScatter);
        assert_eq!(visualizer.plots[1].plot_type, PlotType::Scatter3D);

        Ok(())
    }
}