1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
49use scirs2_core::numeric::{Float, FromPrimitive};
50use std::collections::HashMap;
51use std::fmt::Debug;
52
53use serde::{Deserialize, Serialize};
54
55use crate::error::{ClusteringError, Result};
56
57pub mod animation;
59pub mod export;
60pub mod interactive;
61
62pub use animation::{
64    AnimationFrame, ConvergenceInfo, IterativeAnimationConfig, IterativeAnimationRecorder,
65    StreamingConfig, StreamingFrame, StreamingStats, StreamingVisualizer,
66};
67pub use export::{
68    export_animation_to_file, export_scatter_2d_to_file, export_scatter_2d_to_html,
69    export_scatter_2d_to_json, export_scatter_3d_to_file, export_scatter_3d_to_html,
70    export_scatter_3d_to_json, save_visualization_to_file, ExportConfig, ExportFormat,
71};
72pub use interactive::{
73    BoundingBox3D, CameraState, ClusterStats, InteractiveConfig, InteractiveState,
74    InteractiveVisualizer, KeyCode, MouseButton, ViewMode,
75};
76
/// Options controlling how clustering results are turned into plot data.
///
/// See the `Default` implementation for the default values of every field.
#[derive(Debug, Clone)]
pub struct VisualizationConfig {
    /// Palette used to color clusters (passed to `generate_cluster_colors`).
    pub color_scheme: ColorScheme,
    /// Marker size assigned to every point by the scatter-plot builders.
    pub point_size: f32,
    /// Marker opacity. Not read by the scatter-plot builders in this file.
    /// NOTE(review): presumably expected to lie in [0, 1] — confirm with the renderers.
    pub point_opacity: f32,
    /// Whether centroids should be shown. Not read by the scatter-plot builders in this file.
    pub show_centroids: bool,
    /// Whether cluster boundaries should be shown. Not read by the scatter-plot builders in
    /// this file.
    pub show_boundaries: bool,
    /// Boundary shape to use when boundaries are shown.
    pub boundary_type: BoundaryType,
    /// Whether the output is meant for interactive viewing. Not read by the scatter-plot
    /// builders in this file.
    pub interactive: bool,
    /// Optional animation settings.
    pub animation: Option<AnimationConfig>,
    /// Method used to map the data to 2D or 3D plot coordinates.
    pub dimensionality_reduction: DimensionalityReduction,
}
99
/// Palette choice for cluster colors.
///
/// The concrete hex colors of every scheme are listed in `generate_cluster_colors`;
/// clusters beyond the palette length wrap around to the start of the palette.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorScheme {
    /// Ten-color palette starting with `#1f77b4`.
    Default,
    /// Eight-color palette; the name indicates it is meant to stay distinguishable for
    /// color-blind viewers.
    ColorblindFriendly,
    /// Eight saturated colors, including black (`#000000`) and white (`#ffffff`).
    HighContrast,
    /// Ten light, low-saturation colors.
    Pastel,
    /// Nine colors running from `#440154` to `#fee825`.
    Viridis,
    /// Eight colors running from `#0c0887` to `#f0f921`.
    Plasma,
    /// Currently a single color: every cluster gets `#333333`.
    Custom,
}
118
/// Shape used to outline clusters when boundaries are shown.
///
/// Not consumed by the scatter-plot builders in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundaryType {
    /// Convex hull of the cluster's points (the `VisualizationConfig` default).
    ConvexHull,
    /// Ellipse around the cluster.
    Ellipse,
    /// Alpha shape of the cluster's points.
    AlphaShape,
    /// No boundary.
    None,
}
131
/// Method used to map data to 2D or 3D plot coordinates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DimensionalityReduction {
    /// Principal component analysis (see `apply_pca_2d` / `apply_pca_3d`).
    PCA,
    /// t-SNE; currently falls back to PCA.
    TSNE,
    /// UMAP; currently falls back to PCA.
    UMAP,
    /// Multidimensional scaling; currently falls back to PCA.
    MDS,
    /// Keep the first two columns unchanged (3D plots fall back to PCA instead).
    First2D,
    /// Keep the first three columns unchanged (2D plots fall back to PCA instead).
    First3D,
    /// Use the data as-is; it must already have exactly as many columns as the plot has
    /// dimensions.
    None,
}
150
/// Timing settings for animated visualizations.
#[derive(Debug, Clone)]
pub struct AnimationConfig {
    /// Total animation duration; in milliseconds, per the `_ms` suffix.
    pub duration_ms: u32,
    /// Number of frames.
    pub frames: u32,
    /// Easing curve applied over the animation.
    pub easing: EasingFunction,
    /// Whether the animation should loop.
    pub loop_animation: bool,
}
163
/// Easing curve for animations.
///
/// The variants are not interpreted by the code in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EasingFunction {
    Linear,
    EaseIn,
    EaseOut,
    EaseInOut,
    Bounce,
    Elastic,
}
174
175impl Default for VisualizationConfig {
176    fn default() -> Self {
177        Self {
178            color_scheme: ColorScheme::Default,
179            point_size: 5.0,
180            point_opacity: 0.8,
181            show_centroids: true,
182            show_boundaries: false,
183            boundary_type: BoundaryType::ConvexHull,
184            interactive: true,
185            animation: None,
186            dimensionality_reduction: DimensionalityReduction::PCA,
187        }
188    }
189}
190
/// Plot-ready description of a 2D scatter plot, as produced by `create_scatter_plot_2d`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScatterPlot2D {
    /// Point coordinates, one row per sample (two columns when built by
    /// `create_scatter_plot_2d`).
    pub points: Array2<f64>,
    /// Cluster label of each point.
    pub labels: Array1<i32>,
    /// Centroids in plot coordinates, if any were supplied.
    pub centroids: Option<Array2<f64>>,
    /// Hex color of each point.
    pub colors: Vec<String>,
    /// Marker size of each point.
    pub sizes: Vec<f32>,
    /// Optional per-point text labels (`None` when built by `create_scatter_plot_2d`).
    pub point_labels: Option<Vec<String>>,
    /// Padded plot extent as `(x_min, x_max, y_min, y_max)`.
    pub bounds: (f64, f64, f64, f64),
    /// One legend entry per distinct cluster label, sorted by label.
    pub legend: Vec<LegendEntry>,
}
211
/// Plot-ready description of a 3D scatter plot, as produced by `create_scatter_plot_3d`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScatterPlot3D {
    /// Point coordinates, one row per sample (three columns when built by
    /// `create_scatter_plot_3d`).
    pub points: Array2<f64>,
    /// Cluster label of each point.
    pub labels: Array1<i32>,
    /// Centroids in plot coordinates, if any were supplied.
    pub centroids: Option<Array2<f64>>,
    /// Hex color of each point.
    pub colors: Vec<String>,
    /// Marker size of each point.
    pub sizes: Vec<f32>,
    /// Optional per-point text labels (`None` when built by `create_scatter_plot_3d`).
    pub point_labels: Option<Vec<String>>,
    /// Padded plot extent as `(x_min, x_max, y_min, y_max, z_min, z_max)`.
    pub bounds: (f64, f64, f64, f64, f64, f64),
    /// One legend entry per distinct cluster label, sorted by label.
    pub legend: Vec<LegendEntry>,
}
232
/// Legend row describing one cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LegendEntry {
    /// Cluster label this entry describes.
    pub cluster_id: i32,
    /// Color used for the cluster's points.
    pub color: String,
    /// Display text (`"Cluster {id}"` when built by `create_legend`).
    pub label: String,
    /// Number of points carrying this label.
    pub count: usize,
}
245
/// Outline of a single cluster.
///
/// Not constructed by the code in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterBoundary {
    /// Cluster label the boundary belongs to.
    pub cluster_id: i32,
    /// Boundary vertices. NOTE(review): presumably one point per row — confirm with producers.
    pub boundary_points: Array2<f64>,
    /// Kind of boundary as free text.
    /// NOTE(review): presumably mirrors a `BoundaryType` variant — confirm.
    pub boundary_type: String,
    /// Color used to draw the boundary (the rest of this file uses hex strings).
    pub color: String,
}
258
259#[allow(dead_code)]
272pub fn create_scatter_plot_2d<F: Float + FromPrimitive + Debug>(
273    data: ArrayView2<F>,
274    labels: &Array1<i32>,
275    centroids: Option<&Array2<F>>,
276    config: &VisualizationConfig,
277) -> Result<ScatterPlot2D> {
278    let n_samples = data.nrows();
279    let n_features = data.ncols();
280
281    if labels.len() != n_samples {
282        return Err(ClusteringError::InvalidInput(
283            "Number of labels must match number of samples".to_string(),
284        ));
285    }
286
287    let plotdata =
289        if n_features == 2 && config.dimensionality_reduction == DimensionalityReduction::None {
290            data.mapv(|x| x.to_f64().unwrap_or(0.0))
291        } else {
292            apply_dimensionality_reduction_2d(data, config.dimensionality_reduction)?
293        };
294
295    let plot_centroids = if let Some(cents) = centroids {
297        if cents.ncols() == 2 && config.dimensionality_reduction == DimensionalityReduction::None {
298            Some(cents.mapv(|x| x.to_f64().unwrap_or(0.0)))
299        } else {
300            Some(apply_dimensionality_reduction_2d(
301                cents.view(),
302                config.dimensionality_reduction,
303            )?)
304        }
305    } else {
306        None
307    };
308
309    let unique_labels: Vec<i32> = {
311        let mut labels_vec: Vec<i32> = labels.iter().cloned().collect();
312        labels_vec.sort_unstable();
313        labels_vec.dedup();
314        labels_vec
315    };
316
317    let cluster_colors = generate_cluster_colors(&unique_labels, config.color_scheme);
318    let point_colors = labels
319        .iter()
320        .map(|&label| {
321            cluster_colors
322                .get(&label)
323                .cloned()
324                .unwrap_or_else(|| "#000000".to_string())
325        })
326        .collect();
327
328    let sizes = vec![config.point_size; n_samples];
330
331    let bounds = calculate_2d_bounds(&plotdata);
333
334    let legend = create_legend(&unique_labels, &cluster_colors, labels);
336
337    Ok(ScatterPlot2D {
338        points: plotdata,
339        labels: labels.clone(),
340        centroids: plot_centroids,
341        colors: point_colors,
342        sizes,
343        point_labels: None,
344        bounds,
345        legend,
346    })
347}
348
349#[allow(dead_code)]
362pub fn create_scatter_plot_3d<F: Float + FromPrimitive + Debug>(
363    data: ArrayView2<F>,
364    labels: &Array1<i32>,
365    centroids: Option<&Array2<F>>,
366    config: &VisualizationConfig,
367) -> Result<ScatterPlot3D> {
368    let n_samples = data.nrows();
369    let n_features = data.ncols();
370
371    if labels.len() != n_samples {
372        return Err(ClusteringError::InvalidInput(
373            "Number of labels must match number of samples".to_string(),
374        ));
375    }
376
377    let plotdata =
379        if n_features == 3 && config.dimensionality_reduction == DimensionalityReduction::None {
380            data.mapv(|x| x.to_f64().unwrap_or(0.0))
381        } else {
382            apply_dimensionality_reduction_3d(data, config.dimensionality_reduction)?
383        };
384
385    let plot_centroids = if let Some(cents) = centroids {
387        if cents.ncols() == 3 && config.dimensionality_reduction == DimensionalityReduction::None {
388            Some(cents.mapv(|x| x.to_f64().unwrap_or(0.0)))
389        } else {
390            Some(apply_dimensionality_reduction_3d(
391                cents.view(),
392                config.dimensionality_reduction,
393            )?)
394        }
395    } else {
396        None
397    };
398
399    let unique_labels: Vec<i32> = {
401        let mut labels_vec: Vec<i32> = labels.iter().cloned().collect();
402        labels_vec.sort_unstable();
403        labels_vec.dedup();
404        labels_vec
405    };
406
407    let cluster_colors = generate_cluster_colors(&unique_labels, config.color_scheme);
408    let point_colors = labels
409        .iter()
410        .map(|&label| {
411            cluster_colors
412                .get(&label)
413                .cloned()
414                .unwrap_or_else(|| "#000000".to_string())
415        })
416        .collect();
417
418    let sizes = vec![config.point_size; n_samples];
420
421    let bounds = calculate_3d_bounds(&plotdata);
423
424    let legend = create_legend(&unique_labels, &cluster_colors, labels);
426
427    Ok(ScatterPlot3D {
428        points: plotdata,
429        labels: labels.clone(),
430        centroids: plot_centroids,
431        colors: point_colors,
432        sizes,
433        point_labels: None,
434        bounds,
435        legend,
436    })
437}
438
439#[allow(dead_code)]
441fn apply_dimensionality_reduction_2d<F: Float + FromPrimitive + Debug>(
442    data: ArrayView2<F>,
443    method: DimensionalityReduction,
444) -> Result<Array2<f64>> {
445    let data_f64 = data.mapv(|x| x.to_f64().unwrap_or(0.0));
446
447    match method {
448        DimensionalityReduction::PCA => apply_pca_2d(&data_f64),
449        DimensionalityReduction::First2D => {
450            if data_f64.ncols() >= 2 {
451                Ok(data_f64.slice(s![.., 0..2]).to_owned())
452            } else {
453                Err(ClusteringError::InvalidInput(
454                    "Data must have at least 2 dimensions for First2D".to_string(),
455                ))
456            }
457        }
458        DimensionalityReduction::TSNE => apply_tsne_2d(&data_f64),
459        DimensionalityReduction::UMAP => apply_umap_2d(&data_f64),
460        DimensionalityReduction::MDS => apply_mds_2d(&data_f64),
461        DimensionalityReduction::None => {
462            if data_f64.ncols() == 2 {
463                Ok(data_f64)
464            } else {
465                Err(ClusteringError::InvalidInput(
466                    "Data must be 2D when no dimensionality reduction is specified".to_string(),
467                ))
468            }
469        }
470        _ => apply_pca_2d(&data_f64), }
472}
473
474#[allow(dead_code)]
476fn apply_dimensionality_reduction_3d<F: Float + FromPrimitive + Debug>(
477    data: ArrayView2<F>,
478    method: DimensionalityReduction,
479) -> Result<Array2<f64>> {
480    let data_f64 = data.mapv(|x| x.to_f64().unwrap_or(0.0));
481
482    match method {
483        DimensionalityReduction::PCA => apply_pca_3d(&data_f64),
484        DimensionalityReduction::First3D => {
485            if data_f64.ncols() >= 3 {
486                Ok(data_f64.slice(s![.., 0..3]).to_owned())
487            } else {
488                Err(ClusteringError::InvalidInput(
489                    "Data must have at least 3 dimensions for First3D".to_string(),
490                ))
491            }
492        }
493        DimensionalityReduction::None => {
494            if data_f64.ncols() == 3 {
495                Ok(data_f64)
496            } else {
497                Err(ClusteringError::InvalidInput(
498                    "Data must be 3D when no dimensionality reduction is specified".to_string(),
499                ))
500            }
501        }
502        _ => apply_pca_3d(&data_f64), }
504}
505
506#[allow(dead_code)]
508fn apply_pca_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
509    let n_samples = data.nrows();
510    let n_features = data.ncols();
511
512    if n_features < 2 {
513        return Err(ClusteringError::InvalidInput(
514            "Need at least 2 features for PCA".to_string(),
515        ));
516    }
517
518    let mean = data.mean_axis(Axis(0)).unwrap();
520    let centered = data - &mean;
521
522    let cov = centered.t().dot(¢ered) / (n_samples - 1) as f64;
524
525    let n_features = centered.ncols();
528    let eigenvectors_ = Array2::eye(n_features)
529        .slice(s![.., 0..2.min(n_features)])
530        .to_owned();
531
532    let projected = centered.dot(&eigenvectors_);
534
535    Ok(projected)
536}
537
538#[allow(dead_code)]
540fn apply_pca_3d(data: &Array2<f64>) -> Result<Array2<f64>> {
541    let n_samples = data.nrows();
542    let n_features = data.ncols();
543
544    if n_features < 3 {
545        return Err(ClusteringError::InvalidInput(
546            "Need at least 3 features for 3D PCA".to_string(),
547        ));
548    }
549
550    let mean = data.mean_axis(Axis(0)).unwrap();
552    let centered = data - &mean;
553
554    let cov = centered.t().dot(¢ered) / (n_samples - 1) as f64;
556
557    let n_features = centered.ncols();
560    let eigenvectors_ = Array2::eye(n_features)
561        .slice(s![.., 0..3.min(n_features)])
562        .to_owned();
563
564    let projected = centered.dot(&eigenvectors_);
566
567    Ok(projected)
568}
569
570#[allow(dead_code)]
573fn apply_tsne_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
574    apply_pca_2d(data)
576}
577
578#[allow(dead_code)]
579fn apply_umap_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
580    apply_pca_2d(data)
582}
583
584#[allow(dead_code)]
585fn apply_mds_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
586    apply_pca_2d(data)
588}
589
590#[allow(dead_code)]
592fn compute_top_eigenvectors(
593    matrix: &Array2<f64>,
594    num_components: usize,
595) -> Result<(Array2<f64>, Array1<f64>)> {
596    let n = matrix.nrows();
597    let k = num_components.min(n);
598
599    let mut eigenvectors = Array2::zeros((n, k));
600    let mut eigenvalues = Array1::zeros(k);
601
602    for i in 0..k {
604        let mut v = Array1::from_elem(n, 1.0 / (n as f64).sqrt());
605
606        for j in 0..i {
608            let prev_eigenvector = eigenvectors.column(j);
609            let dot_product = v.dot(&prev_eigenvector);
610            v = &v - &(&prev_eigenvector * dot_product);
611        }
612
613        for _ in 0..100 {
615            let new_v = matrix.dot(&v);
616            let norm = (new_v.dot(&new_v)).sqrt();
617            if norm > 1e-10 {
618                v = new_v / norm;
619            }
620        }
621
622        eigenvalues[i] = v.dot(&matrix.dot(&v));
623        for j in 0..n {
624            eigenvectors[[j, i]] = v[j];
625        }
626    }
627
628    Ok((eigenvectors, eigenvalues))
629}
630
631#[allow(dead_code)]
633fn generate_cluster_colors(labels: &[i32], scheme: ColorScheme) -> HashMap<i32, String> {
634    let mut colors = HashMap::new();
635
636    let color_palette = match scheme {
637        ColorScheme::Default => vec![
638            "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
639            "#bcbd22", "#17becf",
640        ],
641        ColorScheme::ColorblindFriendly => vec![
642            "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33", "#a65628", "#f781bf", "#999999",
643        ],
644        ColorScheme::HighContrast => vec![
645            "#000000", "#ffffff", "#ff0000", "#00ff00", "#0000ff", "#ffff00", "#ff00ff", "#00ffff",
646        ],
647        ColorScheme::Pastel => vec![
648            "#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5", "#c49c94", "#f7b6d3", "#c7c7c7",
649            "#dbdb8d", "#9edae5",
650        ],
651        ColorScheme::Viridis => vec![
652            "#440154", "#482777", "#3f4a8a", "#31678e", "#26838f", "#1f9d8a", "#6cce5a", "#b6de2b",
653            "#fee825",
654        ],
655        ColorScheme::Plasma => vec![
656            "#0c0887", "#5302a3", "#8b0aa5", "#b83289", "#db5c68", "#f48849", "#febd2a", "#f0f921",
657        ],
658        ColorScheme::Custom => vec!["#333333"], };
660
661    for (i, &label) in labels.iter().enumerate() {
662        colors.entry(label).or_insert_with(|| {
663            let color_index = i % color_palette.len();
664            color_palette[color_index].to_string()
665        });
666    }
667
668    colors
669}
670
671#[allow(dead_code)]
673fn calculate_2d_bounds(data: &Array2<f64>) -> (f64, f64, f64, f64) {
674    if data.is_empty() {
675        return (0.0, 1.0, 0.0, 1.0);
676    }
677
678    let x_min = data.column(0).iter().fold(f64::INFINITY, |a, &b| a.min(b));
679    let x_max = data
680        .column(0)
681        .iter()
682        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
683    let y_min = data.column(1).iter().fold(f64::INFINITY, |a, &b| a.min(b));
684    let y_max = data
685        .column(1)
686        .iter()
687        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
688
689    let x_range = x_max - x_min;
691    let y_range = y_max - y_min;
692    let padding = 0.05;
693
694    (
695        x_min - x_range * padding,
696        x_max + x_range * padding,
697        y_min - y_range * padding,
698        y_max + y_range * padding,
699    )
700}
701
702#[allow(dead_code)]
704fn calculate_3d_bounds(data: &Array2<f64>) -> (f64, f64, f64, f64, f64, f64) {
705    if data.is_empty() {
706        return (0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
707    }
708
709    let x_min = data.column(0).iter().fold(f64::INFINITY, |a, &b| a.min(b));
710    let x_max = data
711        .column(0)
712        .iter()
713        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
714    let y_min = data.column(1).iter().fold(f64::INFINITY, |a, &b| a.min(b));
715    let y_max = data
716        .column(1)
717        .iter()
718        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
719    let z_min = data.column(2).iter().fold(f64::INFINITY, |a, &b| a.min(b));
720    let z_max = data
721        .column(2)
722        .iter()
723        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
724
725    let x_range = x_max - x_min;
727    let y_range = y_max - y_min;
728    let z_range = z_max - z_min;
729    let padding = 0.05;
730
731    (
732        x_min - x_range * padding,
733        x_max + x_range * padding,
734        y_min - y_range * padding,
735        y_max + y_range * padding,
736        z_min - z_range * padding,
737        z_max + z_range * padding,
738    )
739}
740
741#[allow(dead_code)]
743fn create_legend(
744    labels: &[i32],
745    colors: &HashMap<i32, String>,
746    data_labels: &Array1<i32>,
747) -> Vec<LegendEntry> {
748    let mut legend = Vec::new();
749
750    for &label in labels {
751        let count = data_labels.iter().filter(|&&l| l == label).count();
752        let color = colors
753            .get(&label)
754            .cloned()
755            .unwrap_or_else(|| "#000000".to_string());
756
757        legend.push(LegendEntry {
758            cluster_id: label,
759            color,
760            label: format!("Cluster {}", label),
761            count,
762        });
763    }
764
765    legend.sort_by_key(|entry| entry.cluster_id);
767
768    legend
769}
770
/// Unit tests for the scatter-plot builders and their helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Configuration that plots the raw coordinates, so expected values are easy to state.
    fn raw_config() -> VisualizationConfig {
        VisualizationConfig {
            dimensionality_reduction: DimensionalityReduction::None,
            ..VisualizationConfig::default()
        }
    }

    #[test]
    fn test_create_scatter_plot_2d() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);
        let config = VisualizationConfig::default();

        let plot = create_scatter_plot_2d(data.view(), &labels, None, &config).unwrap();

        assert_eq!(plot.points.nrows(), 4);
        assert_eq!(plot.points.ncols(), 2);
        assert_eq!(plot.labels.len(), 4);
        assert_eq!(plot.colors.len(), 4);
        assert_eq!(plot.legend.len(), 2);
    }

    #[test]
    fn test_create_scatter_plot_3d() {
        let data = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);
        let config = VisualizationConfig::default();

        let plot = create_scatter_plot_3d(data.view(), &labels, None, &config).unwrap();

        assert_eq!(plot.points.nrows(), 4);
        assert_eq!(plot.points.ncols(), 3);
        assert_eq!(plot.labels.len(), 4);
    }

    #[test]
    fn test_dimensionality_reduction() {
        let data = Array2::from_shape_vec((10, 5), (0..50).map(|x| x as f64).collect()).unwrap();

        let result_2d =
            apply_dimensionality_reduction_2d(data.view(), DimensionalityReduction::PCA).unwrap();
        assert_eq!(result_2d.ncols(), 2);

        let result_3d =
            apply_dimensionality_reduction_3d(data.view(), DimensionalityReduction::PCA).unwrap();
        assert_eq!(result_3d.ncols(), 3);
    }

    #[test]
    fn test_color_generation() {
        let labels = vec![0, 1, 2];
        let colors = generate_cluster_colors(&labels, ColorScheme::Default);

        assert_eq!(colors.len(), 3);
        assert!(colors.contains_key(&0));
        assert!(colors.contains_key(&1));
        assert!(colors.contains_key(&2));
    }

    #[test]
    fn test_label_count_mismatch_is_rejected() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let labels = Array1::from_vec(vec![0, 1, 1]);
        let config = VisualizationConfig::default();

        assert!(create_scatter_plot_2d(data.view(), &labels, None, &config).is_err());
        assert!(create_scatter_plot_3d(data.view(), &labels, None, &config).is_err());
    }

    #[test]
    fn test_legend_counts_and_point_colors() {
        let data = Array2::from_shape_vec(
            (5, 2),
            vec![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 0.0, 1.0, 2.0, 2.0],
        )
        .unwrap();
        let labels = Array1::from_vec(vec![0, 1, 1, 0, 1]);

        let plot = create_scatter_plot_2d(data.view(), &labels, None, &raw_config()).unwrap();

        // Without reduction the coordinates are plotted unchanged.
        assert_eq!(plot.points, data);

        let summary: Vec<(i32, usize)> = plot
            .legend
            .iter()
            .map(|entry| (entry.cluster_id, entry.count))
            .collect();
        assert_eq!(summary, vec![(0, 2), (1, 3)]);

        for (color, label) in plot.colors.iter().zip(labels.iter()) {
            let entry = plot
                .legend
                .iter()
                .find(|entry| entry.cluster_id == *label)
                .unwrap();
            assert_eq!(color, &entry.color);
        }
    }

    #[test]
    fn test_column_selection_requires_enough_columns() {
        let one_column = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(apply_dimensionality_reduction_2d(
            one_column.view(),
            DimensionalityReduction::First2D
        )
        .is_err());

        let two_columns =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(apply_dimensionality_reduction_3d(
            two_columns.view(),
            DimensionalityReduction::First3D
        )
        .is_err());
        assert!(
            apply_dimensionality_reduction_3d(two_columns.view(), DimensionalityReduction::None)
                .is_err()
        );

        let three_columns =
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(apply_dimensionality_reduction_2d(
            three_columns.view(),
            DimensionalityReduction::None
        )
        .is_err());

        let first_two = apply_dimensionality_reduction_2d(
            three_columns.view(),
            DimensionalityReduction::First2D,
        )
        .unwrap();
        assert_eq!(
            first_two,
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 4.0, 5.0]).unwrap()
        );
    }
}