1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
49use scirs2_core::numeric::{Float, FromPrimitive};
50use std::collections::HashMap;
51use std::fmt::Debug;
52
53use serde::{Deserialize, Serialize};
54
55use crate::error::{ClusteringError, Result};
56
57pub mod animation;
59pub mod export;
60pub mod interactive;
61
62pub use animation::{
64 AnimationFrame, ConvergenceInfo, IterativeAnimationConfig, IterativeAnimationRecorder,
65 StreamingConfig, StreamingFrame, StreamingStats, StreamingVisualizer,
66};
67pub use export::{
68 export_animation_to_file, export_scatter_2d_to_file, export_scatter_2d_to_html,
69 export_scatter_2d_to_json, export_scatter_3d_to_file, export_scatter_3d_to_html,
70 export_scatter_3d_to_json, save_visualization_to_file, ExportConfig, ExportFormat,
71};
72pub use interactive::{
73 BoundingBox3D, CameraState, ClusterStats, InteractiveConfig, InteractiveState,
74 InteractiveVisualizer, KeyCode, MouseButton, ViewMode,
75};
76
/// Options that control how clustering results are turned into plot data.
///
/// `VisualizationConfig::default()` selects the default palette, point size 5.0,
/// opacity 0.8, centroids shown, boundaries hidden (convex hull when enabled),
/// interactivity on, no animation and a PCA projection.
#[derive(Debug, Clone)]
pub struct VisualizationConfig {
    /// Palette from which one colour per cluster is drawn.
    pub color_scheme: ColorScheme,
    /// Size assigned to every point by the scatter-plot builders.
    pub point_size: f32,
    /// Point opacity (default 0.8); presumably expected in `[0, 1]` — TODO confirm with renderers.
    pub point_opacity: f32,
    /// Whether centroids should be drawn.
    /// NOTE(review): not read by `create_scatter_plot_2d`/`_3d`; presumably used by renderers.
    pub show_centroids: bool,
    /// Whether cluster boundaries should be drawn.
    /// NOTE(review): not read by the scatter-plot builders in this module.
    pub show_boundaries: bool,
    /// Kind of boundary to draw when `show_boundaries` is set.
    pub boundary_type: BoundaryType,
    /// Whether the output is meant for an interactive view.
    /// NOTE(review): not read by the scatter-plot builders in this module.
    pub interactive: bool,
    /// Optional animation settings; `None` disables animation.
    pub animation: Option<AnimationConfig>,
    /// Method used to map the data to the 2 or 3 plotted dimensions.
    pub dimensionality_reduction: DimensionalityReduction,
}
99
/// Palette used to colour clusters (see `generate_cluster_colors`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorScheme {
    /// Ten-colour categorical palette.
    Default,
    /// Eight-colour palette intended to stay distinguishable for colour-blind viewers.
    ColorblindFriendly,
    /// Eight saturated primaries, including black and white.
    HighContrast,
    /// Ten light, low-saturation colours.
    Pastel,
    /// Nine samples along the Viridis colour map.
    Viridis,
    /// Eight samples along the Plasma colour map.
    Plasma,
    /// Placeholder: `generate_cluster_colors` currently maps every cluster to `#333333`.
    Custom,
}
118
/// Shape used to outline a cluster when boundaries are enabled.
///
/// NOTE(review): no code in this module computes boundaries from this value;
/// presumably consumed by renderers/exporters — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundaryType {
    /// Convex hull of the cluster's points (the default).
    ConvexHull,
    /// Ellipse around the cluster.
    Ellipse,
    /// Alpha shape (concave hull) around the cluster.
    AlphaShape,
    /// No boundary.
    None,
}
131
/// Projection used to map the data onto the 2 or 3 plotted dimensions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DimensionalityReduction {
    /// Principal component analysis (the default).
    PCA,
    /// t-SNE. NOTE: the 2-D helper currently delegates to PCA; 3-D maps straight to PCA.
    TSNE,
    /// UMAP. NOTE: the 2-D helper currently delegates to PCA; 3-D maps straight to PCA.
    UMAP,
    /// Multidimensional scaling. NOTE: currently delegates to PCA (2-D) / maps to PCA (3-D).
    MDS,
    /// Keep the first two columns; in a 3-D plot this falls back to PCA.
    First2D,
    /// Keep the first three columns; in a 2-D plot this falls back to PCA.
    First3D,
    /// Use the data unchanged; it must already have the plotted number of columns.
    None,
}
150
/// Animation settings attached to a `VisualizationConfig`.
///
/// NOTE(review): none of these fields is read in this module; the meanings
/// below follow the field names and should be confirmed against the consumer.
#[derive(Debug, Clone)]
pub struct AnimationConfig {
    /// Presumably the total animation length in milliseconds.
    pub duration_ms: u32,
    /// Presumably the number of frames to render.
    pub frames: u32,
    /// Easing curve applied between frames.
    pub easing: EasingFunction,
    /// Whether the animation restarts after the last frame.
    pub loop_animation: bool,
}
163
/// Easing curve for animations.
///
/// NOTE(review): no easing math is implemented in this module; the variant
/// names follow the usual easing-curve vocabulary — confirm in the animation code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EasingFunction {
    Linear,
    EaseIn,
    EaseOut,
    EaseInOut,
    Bounce,
    Elastic,
}
174
175impl Default for VisualizationConfig {
176 fn default() -> Self {
177 Self {
178 color_scheme: ColorScheme::Default,
179 point_size: 5.0,
180 point_opacity: 0.8,
181 show_centroids: true,
182 show_boundaries: false,
183 boundary_type: BoundaryType::ConvexHull,
184 interactive: true,
185 animation: None,
186 dimensionality_reduction: DimensionalityReduction::PCA,
187 }
188 }
189}
190
/// Renderer-agnostic description of a 2-D cluster scatter plot, as produced by
/// `create_scatter_plot_2d`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScatterPlot2D {
    /// Plotted coordinates, one row per sample.
    pub points: Array2<f64>,
    /// Cluster label of each sample (a copy of the input labels).
    pub labels: Array1<i32>,
    /// Centroids mapped into the plot space, if centroids were supplied.
    pub centroids: Option<Array2<f64>>,
    /// Hex colour string per point (`"#000000"` if its label had no colour).
    pub colors: Vec<String>,
    /// Size per point; the builder fills every entry with `config.point_size`.
    pub sizes: Vec<f32>,
    /// Optional per-point text labels; the builder leaves this `None`.
    pub point_labels: Option<Vec<String>>,
    /// Padded axis bounds as `(x_min, x_max, y_min, y_max)`.
    pub bounds: (f64, f64, f64, f64),
    /// One entry per distinct label, sorted by cluster id.
    pub legend: Vec<LegendEntry>,
}
211
/// Renderer-agnostic description of a 3-D cluster scatter plot, as produced by
/// `create_scatter_plot_3d`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScatterPlot3D {
    /// Plotted coordinates, one row per sample.
    pub points: Array2<f64>,
    /// Cluster label of each sample (a copy of the input labels).
    pub labels: Array1<i32>,
    /// Centroids mapped into the plot space, if centroids were supplied.
    pub centroids: Option<Array2<f64>>,
    /// Hex colour string per point (`"#000000"` if its label had no colour).
    pub colors: Vec<String>,
    /// Size per point; the builder fills every entry with `config.point_size`.
    pub sizes: Vec<f32>,
    /// Optional per-point text labels; the builder leaves this `None`.
    pub point_labels: Option<Vec<String>>,
    /// Padded axis bounds as `(x_min, x_max, y_min, y_max, z_min, z_max)`.
    pub bounds: (f64, f64, f64, f64, f64, f64),
    /// One entry per distinct label, sorted by cluster id.
    pub legend: Vec<LegendEntry>,
}
232
/// One row of a plot legend, describing a single cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LegendEntry {
    /// Cluster label this entry describes.
    pub cluster_id: i32,
    /// Hex colour string used for the cluster's points.
    pub color: String,
    /// Display text; `create_legend` uses `"Cluster {id}"`.
    pub label: String,
    /// Number of samples carrying this label.
    pub count: usize,
}
245
/// Outline of a single cluster for rendering.
///
/// NOTE(review): not constructed anywhere in this module; field meanings
/// below are inferred from the names — confirm with the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterBoundary {
    /// Cluster label the boundary belongs to.
    pub cluster_id: i32,
    /// Presumably the outline vertices, one row per point.
    pub boundary_points: Array2<f64>,
    /// Presumably a textual name of the boundary kind (cf. `BoundaryType`).
    pub boundary_type: String,
    /// Hex colour string for the outline.
    pub color: String,
}
258
259#[allow(dead_code)]
272pub fn create_scatter_plot_2d<F: Float + FromPrimitive + Debug>(
273 data: ArrayView2<F>,
274 labels: &Array1<i32>,
275 centroids: Option<&Array2<F>>,
276 config: &VisualizationConfig,
277) -> Result<ScatterPlot2D> {
278 let n_samples = data.nrows();
279 let n_features = data.ncols();
280
281 if labels.len() != n_samples {
282 return Err(ClusteringError::InvalidInput(
283 "Number of labels must match number of samples".to_string(),
284 ));
285 }
286
287 let plotdata =
289 if n_features == 2 && config.dimensionality_reduction == DimensionalityReduction::None {
290 data.mapv(|x| x.to_f64().unwrap_or(0.0))
291 } else {
292 apply_dimensionality_reduction_2d(data, config.dimensionality_reduction)?
293 };
294
295 let plot_centroids = if let Some(cents) = centroids {
297 if cents.ncols() == 2 && config.dimensionality_reduction == DimensionalityReduction::None {
298 Some(cents.mapv(|x| x.to_f64().unwrap_or(0.0)))
299 } else {
300 Some(apply_dimensionality_reduction_2d(
301 cents.view(),
302 config.dimensionality_reduction,
303 )?)
304 }
305 } else {
306 None
307 };
308
309 let unique_labels: Vec<i32> = {
311 let mut labels_vec: Vec<i32> = labels.iter().cloned().collect();
312 labels_vec.sort_unstable();
313 labels_vec.dedup();
314 labels_vec
315 };
316
317 let cluster_colors = generate_cluster_colors(&unique_labels, config.color_scheme);
318 let point_colors = labels
319 .iter()
320 .map(|&label| {
321 cluster_colors
322 .get(&label)
323 .cloned()
324 .unwrap_or_else(|| "#000000".to_string())
325 })
326 .collect();
327
328 let sizes = vec![config.point_size; n_samples];
330
331 let bounds = calculate_2d_bounds(&plotdata);
333
334 let legend = create_legend(&unique_labels, &cluster_colors, labels);
336
337 Ok(ScatterPlot2D {
338 points: plotdata,
339 labels: labels.clone(),
340 centroids: plot_centroids,
341 colors: point_colors,
342 sizes,
343 point_labels: None,
344 bounds,
345 legend,
346 })
347}
348
349#[allow(dead_code)]
362pub fn create_scatter_plot_3d<F: Float + FromPrimitive + Debug>(
363 data: ArrayView2<F>,
364 labels: &Array1<i32>,
365 centroids: Option<&Array2<F>>,
366 config: &VisualizationConfig,
367) -> Result<ScatterPlot3D> {
368 let n_samples = data.nrows();
369 let n_features = data.ncols();
370
371 if labels.len() != n_samples {
372 return Err(ClusteringError::InvalidInput(
373 "Number of labels must match number of samples".to_string(),
374 ));
375 }
376
377 let plotdata =
379 if n_features == 3 && config.dimensionality_reduction == DimensionalityReduction::None {
380 data.mapv(|x| x.to_f64().unwrap_or(0.0))
381 } else {
382 apply_dimensionality_reduction_3d(data, config.dimensionality_reduction)?
383 };
384
385 let plot_centroids = if let Some(cents) = centroids {
387 if cents.ncols() == 3 && config.dimensionality_reduction == DimensionalityReduction::None {
388 Some(cents.mapv(|x| x.to_f64().unwrap_or(0.0)))
389 } else {
390 Some(apply_dimensionality_reduction_3d(
391 cents.view(),
392 config.dimensionality_reduction,
393 )?)
394 }
395 } else {
396 None
397 };
398
399 let unique_labels: Vec<i32> = {
401 let mut labels_vec: Vec<i32> = labels.iter().cloned().collect();
402 labels_vec.sort_unstable();
403 labels_vec.dedup();
404 labels_vec
405 };
406
407 let cluster_colors = generate_cluster_colors(&unique_labels, config.color_scheme);
408 let point_colors = labels
409 .iter()
410 .map(|&label| {
411 cluster_colors
412 .get(&label)
413 .cloned()
414 .unwrap_or_else(|| "#000000".to_string())
415 })
416 .collect();
417
418 let sizes = vec![config.point_size; n_samples];
420
421 let bounds = calculate_3d_bounds(&plotdata);
423
424 let legend = create_legend(&unique_labels, &cluster_colors, labels);
426
427 Ok(ScatterPlot3D {
428 points: plotdata,
429 labels: labels.clone(),
430 centroids: plot_centroids,
431 colors: point_colors,
432 sizes,
433 point_labels: None,
434 bounds,
435 legend,
436 })
437}
438
439#[allow(dead_code)]
441fn apply_dimensionality_reduction_2d<F: Float + FromPrimitive + Debug>(
442 data: ArrayView2<F>,
443 method: DimensionalityReduction,
444) -> Result<Array2<f64>> {
445 let data_f64 = data.mapv(|x| x.to_f64().unwrap_or(0.0));
446
447 match method {
448 DimensionalityReduction::PCA => apply_pca_2d(&data_f64),
449 DimensionalityReduction::First2D => {
450 if data_f64.ncols() >= 2 {
451 Ok(data_f64.slice(s![.., 0..2]).to_owned())
452 } else {
453 Err(ClusteringError::InvalidInput(
454 "Data must have at least 2 dimensions for First2D".to_string(),
455 ))
456 }
457 }
458 DimensionalityReduction::TSNE => apply_tsne_2d(&data_f64),
459 DimensionalityReduction::UMAP => apply_umap_2d(&data_f64),
460 DimensionalityReduction::MDS => apply_mds_2d(&data_f64),
461 DimensionalityReduction::None => {
462 if data_f64.ncols() == 2 {
463 Ok(data_f64)
464 } else {
465 Err(ClusteringError::InvalidInput(
466 "Data must be 2D when no dimensionality reduction is specified".to_string(),
467 ))
468 }
469 }
470 _ => apply_pca_2d(&data_f64), }
472}
473
474#[allow(dead_code)]
476fn apply_dimensionality_reduction_3d<F: Float + FromPrimitive + Debug>(
477 data: ArrayView2<F>,
478 method: DimensionalityReduction,
479) -> Result<Array2<f64>> {
480 let data_f64 = data.mapv(|x| x.to_f64().unwrap_or(0.0));
481
482 match method {
483 DimensionalityReduction::PCA => apply_pca_3d(&data_f64),
484 DimensionalityReduction::First3D => {
485 if data_f64.ncols() >= 3 {
486 Ok(data_f64.slice(s![.., 0..3]).to_owned())
487 } else {
488 Err(ClusteringError::InvalidInput(
489 "Data must have at least 3 dimensions for First3D".to_string(),
490 ))
491 }
492 }
493 DimensionalityReduction::None => {
494 if data_f64.ncols() == 3 {
495 Ok(data_f64)
496 } else {
497 Err(ClusteringError::InvalidInput(
498 "Data must be 3D when no dimensionality reduction is specified".to_string(),
499 ))
500 }
501 }
502 _ => apply_pca_3d(&data_f64), }
504}
505
506#[allow(dead_code)]
508fn apply_pca_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
509 let n_samples = data.nrows();
510 let n_features = data.ncols();
511
512 if n_features < 2 {
513 return Err(ClusteringError::InvalidInput(
514 "Need at least 2 features for PCA".to_string(),
515 ));
516 }
517
518 let mean = data.mean_axis(Axis(0)).unwrap();
520 let centered = data - &mean;
521
522 let cov = centered.t().dot(¢ered) / (n_samples - 1) as f64;
524
525 let n_features = centered.ncols();
528 let eigenvectors_ = Array2::eye(n_features)
529 .slice(s![.., 0..2.min(n_features)])
530 .to_owned();
531
532 let projected = centered.dot(&eigenvectors_);
534
535 Ok(projected)
536}
537
538#[allow(dead_code)]
540fn apply_pca_3d(data: &Array2<f64>) -> Result<Array2<f64>> {
541 let n_samples = data.nrows();
542 let n_features = data.ncols();
543
544 if n_features < 3 {
545 return Err(ClusteringError::InvalidInput(
546 "Need at least 3 features for 3D PCA".to_string(),
547 ));
548 }
549
550 let mean = data.mean_axis(Axis(0)).unwrap();
552 let centered = data - &mean;
553
554 let cov = centered.t().dot(¢ered) / (n_samples - 1) as f64;
556
557 let n_features = centered.ncols();
560 let eigenvectors_ = Array2::eye(n_features)
561 .slice(s![.., 0..3.min(n_features)])
562 .to_owned();
563
564 let projected = centered.dot(&eigenvectors_);
566
567 Ok(projected)
568}
569
570#[allow(dead_code)]
573fn apply_tsne_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
574 apply_pca_2d(data)
576}
577
578#[allow(dead_code)]
579fn apply_umap_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
580 apply_pca_2d(data)
582}
583
584#[allow(dead_code)]
585fn apply_mds_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
586 apply_pca_2d(data)
588}
589
590#[allow(dead_code)]
592fn compute_top_eigenvectors(
593 matrix: &Array2<f64>,
594 num_components: usize,
595) -> Result<(Array2<f64>, Array1<f64>)> {
596 let n = matrix.nrows();
597 let k = num_components.min(n);
598
599 let mut eigenvectors = Array2::zeros((n, k));
600 let mut eigenvalues = Array1::zeros(k);
601
602 for i in 0..k {
604 let mut v = Array1::from_elem(n, 1.0 / (n as f64).sqrt());
605
606 for j in 0..i {
608 let prev_eigenvector = eigenvectors.column(j);
609 let dot_product = v.dot(&prev_eigenvector);
610 v = &v - &(&prev_eigenvector * dot_product);
611 }
612
613 for _ in 0..100 {
615 let new_v = matrix.dot(&v);
616 let norm = (new_v.dot(&new_v)).sqrt();
617 if norm > 1e-10 {
618 v = new_v / norm;
619 }
620 }
621
622 eigenvalues[i] = v.dot(&matrix.dot(&v));
623 for j in 0..n {
624 eigenvectors[[j, i]] = v[j];
625 }
626 }
627
628 Ok((eigenvectors, eigenvalues))
629}
630
631#[allow(dead_code)]
633fn generate_cluster_colors(labels: &[i32], scheme: ColorScheme) -> HashMap<i32, String> {
634 let mut colors = HashMap::new();
635
636 let color_palette = match scheme {
637 ColorScheme::Default => vec![
638 "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
639 "#bcbd22", "#17becf",
640 ],
641 ColorScheme::ColorblindFriendly => vec![
642 "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33", "#a65628", "#f781bf", "#999999",
643 ],
644 ColorScheme::HighContrast => vec![
645 "#000000", "#ffffff", "#ff0000", "#00ff00", "#0000ff", "#ffff00", "#ff00ff", "#00ffff",
646 ],
647 ColorScheme::Pastel => vec![
648 "#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5", "#c49c94", "#f7b6d3", "#c7c7c7",
649 "#dbdb8d", "#9edae5",
650 ],
651 ColorScheme::Viridis => vec![
652 "#440154", "#482777", "#3f4a8a", "#31678e", "#26838f", "#1f9d8a", "#6cce5a", "#b6de2b",
653 "#fee825",
654 ],
655 ColorScheme::Plasma => vec![
656 "#0c0887", "#5302a3", "#8b0aa5", "#b83289", "#db5c68", "#f48849", "#febd2a", "#f0f921",
657 ],
658 ColorScheme::Custom => vec!["#333333"], };
660
661 for (i, &label) in labels.iter().enumerate() {
662 colors.entry(label).or_insert_with(|| {
663 let color_index = i % color_palette.len();
664 color_palette[color_index].to_string()
665 });
666 }
667
668 colors
669}
670
671#[allow(dead_code)]
673fn calculate_2d_bounds(data: &Array2<f64>) -> (f64, f64, f64, f64) {
674 if data.is_empty() {
675 return (0.0, 1.0, 0.0, 1.0);
676 }
677
678 let x_min = data.column(0).iter().fold(f64::INFINITY, |a, &b| a.min(b));
679 let x_max = data
680 .column(0)
681 .iter()
682 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
683 let y_min = data.column(1).iter().fold(f64::INFINITY, |a, &b| a.min(b));
684 let y_max = data
685 .column(1)
686 .iter()
687 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
688
689 let x_range = x_max - x_min;
691 let y_range = y_max - y_min;
692 let padding = 0.05;
693
694 (
695 x_min - x_range * padding,
696 x_max + x_range * padding,
697 y_min - y_range * padding,
698 y_max + y_range * padding,
699 )
700}
701
702#[allow(dead_code)]
704fn calculate_3d_bounds(data: &Array2<f64>) -> (f64, f64, f64, f64, f64, f64) {
705 if data.is_empty() {
706 return (0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
707 }
708
709 let x_min = data.column(0).iter().fold(f64::INFINITY, |a, &b| a.min(b));
710 let x_max = data
711 .column(0)
712 .iter()
713 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
714 let y_min = data.column(1).iter().fold(f64::INFINITY, |a, &b| a.min(b));
715 let y_max = data
716 .column(1)
717 .iter()
718 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
719 let z_min = data.column(2).iter().fold(f64::INFINITY, |a, &b| a.min(b));
720 let z_max = data
721 .column(2)
722 .iter()
723 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
724
725 let x_range = x_max - x_min;
727 let y_range = y_max - y_min;
728 let z_range = z_max - z_min;
729 let padding = 0.05;
730
731 (
732 x_min - x_range * padding,
733 x_max + x_range * padding,
734 y_min - y_range * padding,
735 y_max + y_range * padding,
736 z_min - z_range * padding,
737 z_max + z_range * padding,
738 )
739}
740
741#[allow(dead_code)]
743fn create_legend(
744 labels: &[i32],
745 colors: &HashMap<i32, String>,
746 data_labels: &Array1<i32>,
747) -> Vec<LegendEntry> {
748 let mut legend = Vec::new();
749
750 for &label in labels {
751 let count = data_labels.iter().filter(|&&l| l == label).count();
752 let color = colors
753 .get(&label)
754 .cloned()
755 .unwrap_or_else(|| "#000000".to_string());
756
757 legend.push(LegendEntry {
758 cluster_id: label,
759 color,
760 label: format!("Cluster {}", label),
761 count,
762 });
763 }
764
765 legend.sort_by_key(|entry| entry.cluster_id);
767
768 legend
769}
770
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a row-major `rows x cols` matrix from `values`.
    fn matrix(rows: usize, cols: usize, values: Vec<f64>) -> Array2<f64> {
        Array2::from_shape_vec((rows, cols), values).expect("shape matches value count")
    }

    #[test]
    fn test_create_scatter_plot_2d() {
        let samples = matrix(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let cluster_ids = Array1::from_vec(vec![0, 0, 1, 1]);
        let settings = VisualizationConfig::default();

        let plot = create_scatter_plot_2d(samples.view(), &cluster_ids, None, &settings).unwrap();

        assert_eq!(plot.points.dim(), (4, 2));
        assert_eq!(plot.labels.len(), 4);
        assert_eq!(plot.colors.len(), 4);
        assert_eq!(plot.legend.len(), 2);
    }

    #[test]
    fn test_create_scatter_plot_3d() {
        let samples = matrix(
            4,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
        );
        let cluster_ids = Array1::from_vec(vec![0, 0, 1, 1]);
        let settings = VisualizationConfig::default();

        let plot = create_scatter_plot_3d(samples.view(), &cluster_ids, None, &settings).unwrap();

        assert_eq!(plot.points.dim(), (4, 3));
        assert_eq!(plot.labels.len(), 4);
    }

    #[test]
    fn test_dimensionality_reduction() {
        let samples = matrix(10, 5, (0..50).map(|x| x as f64).collect());

        let planar =
            apply_dimensionality_reduction_2d(samples.view(), DimensionalityReduction::PCA)
                .unwrap();
        assert_eq!(planar.ncols(), 2);

        let spatial =
            apply_dimensionality_reduction_3d(samples.view(), DimensionalityReduction::PCA)
                .unwrap();
        assert_eq!(spatial.ncols(), 3);
    }

    #[test]
    fn test_color_generation() {
        let colors = generate_cluster_colors(&[0, 1, 2], ColorScheme::Default);

        assert_eq!(colors.len(), 3);
        assert!((0..3).all(|id| colors.contains_key(&id)));
    }
}