scirs2_cluster/visualization/
mod.rs

1//! Enhanced visualization capabilities for clustering results
2//!
3//! This module provides comprehensive visualization tools for clustering algorithms,
4//! including 2D/3D scatter plots, animations, interactive visualizations, real-time
5//! streaming displays, and various export formats for research and presentation use.
6//!
7//! # Features
8//!
9//! * **Static Visualizations**: 2D and 3D scatter plots with customizable styling
10//! * **Interactive 3D**: Real-time manipulation, camera controls, VR/AR support
11//! * **Animations**: Algorithm convergence animations, real-time streaming
12//! * **Export Capabilities**: Multiple formats (PNG, SVG, HTML, JSON, video, etc.)
13//! * **Dimensionality Reduction**: PCA, t-SNE, UMAP integration for high-dimensional data
14//! * **Real-time Streaming**: Live data visualization with adaptive boundaries
15//!
16//! # Examples
17//!
18//! ## Basic 2D Visualization
19//! ```
20//! use scirs2_core::ndarray::Array2;
21//! use scirs2_cluster::visualization::{create_scatter_plot_2d, VisualizationConfig};
22//!
23//! let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
24//! let labels = scirs2_core::ndarray::Array1::from_vec(vec![0, 0, 1, 1]);
25//! let config = VisualizationConfig::default();
26//!
27//! let plot = create_scatter_plot_2d(data.view(), &labels, None, &config).unwrap();
28//! ```
29//!
30//! ## 3D Interactive Visualization
31//! ```
32//! use scirs2_cluster::visualization::interactive::{InteractiveVisualizer, InteractiveConfig};
33//!
34//! let config = InteractiveConfig::default();
35//! let mut visualizer = InteractiveVisualizer::new(config);
36//! // Set up interactive controls and update with data
37//! ```
38//!
39//! ## Animation Recording
40//! ```
41//! use scirs2_cluster::visualization::animation::{IterativeAnimationRecorder, IterativeAnimationConfig};
42//!
43//! let config = IterativeAnimationConfig::default();
44//! let mut recorder = IterativeAnimationRecorder::new(config);
45//! // Record frames during algorithm iterations
46//! ```
47
48use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
49use scirs2_core::numeric::{Float, FromPrimitive};
50use std::collections::HashMap;
51use std::fmt::Debug;
52
53use serde::{Deserialize, Serialize};
54
55use crate::error::{ClusteringError, Result};
56
57// Sub-modules
58pub mod animation;
59pub mod export;
60pub mod interactive;
61
62// Re-export main types from sub-modules
63pub use animation::{
64    AnimationFrame, ConvergenceInfo, IterativeAnimationConfig, IterativeAnimationRecorder,
65    StreamingConfig, StreamingFrame, StreamingStats, StreamingVisualizer,
66};
67pub use export::{
68    export_animation_to_file, export_scatter_2d_to_file, export_scatter_2d_to_html,
69    export_scatter_2d_to_json, export_scatter_3d_to_file, export_scatter_3d_to_html,
70    export_scatter_3d_to_json, save_visualization_to_file, ExportConfig, ExportFormat,
71};
72pub use interactive::{
73    BoundingBox3D, CameraState, ClusterStats, InteractiveConfig, InteractiveState,
74    InteractiveVisualizer, KeyCode, MouseButton, ViewMode,
75};
76
77/// Configuration for clustering visualizations
78#[derive(Debug, Clone)]
79pub struct VisualizationConfig {
80    /// Color scheme for clusters
81    pub color_scheme: ColorScheme,
82    /// Point size for scatter plots
83    pub point_size: f32,
84    /// Point opacity (0.0 to 1.0)
85    pub point_opacity: f32,
86    /// Show cluster centroids
87    pub show_centroids: bool,
88    /// Show cluster boundaries (convex hull or ellipse)
89    pub show_boundaries: bool,
90    /// Boundary type
91    pub boundary_type: BoundaryType,
92    /// Enable interactive features
93    pub interactive: bool,
94    /// Animation settings
95    pub animation: Option<AnimationConfig>,
96    /// Dimensionality reduction method for high-dimensional data
97    pub dimensionality_reduction: DimensionalityReduction,
98}
99
/// Named color palettes available for distinguishing clusters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorScheme {
    /// Bright general-purpose palette.
    Default,
    /// Palette chosen to remain distinguishable for colorblind viewers.
    ColorblindFriendly,
    /// Strongly contrasting colors.
    HighContrast,
    /// Soft pastel tones.
    Pastel,
    /// Colors sampled from the Viridis colormap.
    Viridis,
    /// Colors sampled from the Plasma colormap.
    Plasma,
    /// User-defined colors.
    Custom,
}
118
/// Ways of outlining the region occupied by a cluster.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundaryType {
    /// Convex hull enclosing the cluster's points.
    ConvexHull,
    /// Ellipse derived from the cluster's covariance.
    Ellipse,
    /// Alpha shape, allowing non-convex outlines.
    AlphaShape,
    /// Do not draw an outline.
    None,
}
131
/// Strategies for mapping feature vectors onto two or three plot axes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DimensionalityReduction {
    /// Principal Component Analysis.
    PCA,
    /// t-Distributed Stochastic Neighbor Embedding.
    TSNE,
    /// Uniform Manifold Approximation and Projection.
    UMAP,
    /// Multidimensional Scaling.
    MDS,
    /// Keep only the first two feature columns.
    First2D,
    /// Keep only the first three feature columns.
    First3D,
    /// Apply no reduction; the data must already have the plot's dimensionality.
    None,
}
150
151/// Animation configuration for visualizations
152#[derive(Debug, Clone)]
153pub struct AnimationConfig {
154    /// Animation duration in milliseconds
155    pub duration_ms: u32,
156    /// Number of animation frames
157    pub frames: u32,
158    /// Easing function
159    pub easing: EasingFunction,
160    /// Whether to loop the animation
161    pub loop_animation: bool,
162}
163
/// Easing curves selectable for animations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EasingFunction {
    /// Linear easing.
    Linear,
    /// Ease-in easing.
    EaseIn,
    /// Ease-out easing.
    EaseOut,
    /// Ease-in-out easing.
    EaseInOut,
    /// Bounce easing.
    Bounce,
    /// Elastic easing.
    Elastic,
}
174
175impl Default for VisualizationConfig {
176    fn default() -> Self {
177        Self {
178            color_scheme: ColorScheme::Default,
179            point_size: 5.0,
180            point_opacity: 0.8,
181            show_centroids: true,
182            show_boundaries: false,
183            boundary_type: BoundaryType::ConvexHull,
184            interactive: true,
185            animation: None,
186            dimensionality_reduction: DimensionalityReduction::PCA,
187        }
188    }
189}
190
191/// 2D scatter plot visualization data
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct ScatterPlot2D {
194    /// Point coordinates
195    pub points: Array2<f64>,
196    /// Cluster labels for each point
197    pub labels: Array1<i32>,
198    /// Cluster centroids (if available)
199    pub centroids: Option<Array2<f64>>,
200    /// Point colors (hex format)
201    pub colors: Vec<String>,
202    /// Point sizes
203    pub sizes: Vec<f32>,
204    /// Point labels (optional)
205    pub point_labels: Option<Vec<String>>,
206    /// Plot boundaries (min_x, max_x, min_y, max_y)
207    pub bounds: (f64, f64, f64, f64),
208    /// Legend information
209    pub legend: Vec<LegendEntry>,
210}
211
212/// 3D scatter plot visualization data
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct ScatterPlot3D {
215    /// Point coordinates (x, y, z)
216    pub points: Array2<f64>,
217    /// Cluster labels for each point
218    pub labels: Array1<i32>,
219    /// Cluster centroids (if available)
220    pub centroids: Option<Array2<f64>>,
221    /// Point colors (hex format)
222    pub colors: Vec<String>,
223    /// Point sizes
224    pub sizes: Vec<f32>,
225    /// Point labels (optional)
226    pub point_labels: Option<Vec<String>>,
227    /// Plot boundaries (min_x, max_x, min_y, max_y, min_z, max_z)
228    pub bounds: (f64, f64, f64, f64, f64, f64),
229    /// Legend information
230    pub legend: Vec<LegendEntry>,
231}
232
233/// Legend entry for visualizations
234#[derive(Debug, Clone, Serialize, Deserialize)]
235pub struct LegendEntry {
236    /// Cluster ID
237    pub cluster_id: i32,
238    /// Color hex code
239    pub color: String,
240    /// Cluster label/name
241    pub label: String,
242    /// Number of points in cluster
243    pub count: usize,
244}
245
246/// Cluster boundary representation
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct ClusterBoundary {
249    /// Cluster ID
250    pub cluster_id: i32,
251    /// Boundary points
252    pub boundary_points: Array2<f64>,
253    /// Boundary type
254    pub boundary_type: String,
255    /// Color for the boundary
256    pub color: String,
257}
258
259/// Create 2D scatter plot visualization
260///
261/// # Arguments
262///
263/// * `data` - Input data matrix (samples x features)
264/// * `labels` - Cluster labels for each sample
265/// * `centroids` - Optional cluster centroids
266/// * `config` - Visualization configuration
267///
268/// # Returns
269///
270/// * `Result<ScatterPlot2D>` - 2D scatter plot data
271#[allow(dead_code)]
272pub fn create_scatter_plot_2d<F: Float + FromPrimitive + Debug>(
273    data: ArrayView2<F>,
274    labels: &Array1<i32>,
275    centroids: Option<&Array2<F>>,
276    config: &VisualizationConfig,
277) -> Result<ScatterPlot2D> {
278    let n_samples = data.nrows();
279    let n_features = data.ncols();
280
281    if labels.len() != n_samples {
282        return Err(ClusteringError::InvalidInput(
283            "Number of labels must match number of samples".to_string(),
284        ));
285    }
286
287    // Reduce dimensionality if needed
288    let plotdata =
289        if n_features == 2 && config.dimensionality_reduction == DimensionalityReduction::None {
290            data.mapv(|x| x.to_f64().unwrap_or(0.0))
291        } else {
292            apply_dimensionality_reduction_2d(data, config.dimensionality_reduction)?
293        };
294
295    // Convert centroids if provided
296    let plot_centroids = if let Some(cents) = centroids {
297        if cents.ncols() == 2 && config.dimensionality_reduction == DimensionalityReduction::None {
298            Some(cents.mapv(|x| x.to_f64().unwrap_or(0.0)))
299        } else {
300            Some(apply_dimensionality_reduction_2d(
301                cents.view(),
302                config.dimensionality_reduction,
303            )?)
304        }
305    } else {
306        None
307    };
308
309    // Generate colors for clusters
310    let unique_labels: Vec<i32> = {
311        let mut labels_vec: Vec<i32> = labels.iter().cloned().collect();
312        labels_vec.sort_unstable();
313        labels_vec.dedup();
314        labels_vec
315    };
316
317    let cluster_colors = generate_cluster_colors(&unique_labels, config.color_scheme);
318    let point_colors = labels
319        .iter()
320        .map(|&label| {
321            cluster_colors
322                .get(&label)
323                .cloned()
324                .unwrap_or_else(|| "#000000".to_string())
325        })
326        .collect();
327
328    // Generate point sizes
329    let sizes = vec![config.point_size; n_samples];
330
331    // Calculate plot bounds
332    let bounds = calculate_2d_bounds(&plotdata);
333
334    // Create legend
335    let legend = create_legend(&unique_labels, &cluster_colors, labels);
336
337    Ok(ScatterPlot2D {
338        points: plotdata,
339        labels: labels.clone(),
340        centroids: plot_centroids,
341        colors: point_colors,
342        sizes,
343        point_labels: None,
344        bounds,
345        legend,
346    })
347}
348
349/// Create 3D scatter plot visualization
350///
351/// # Arguments
352///
353/// * `data` - Input data matrix (samples x features)
354/// * `labels` - Cluster labels for each sample
355/// * `centroids` - Optional cluster centroids
356/// * `config` - Visualization configuration
357///
358/// # Returns
359///
360/// * `Result<ScatterPlot3D>` - 3D scatter plot data
361#[allow(dead_code)]
362pub fn create_scatter_plot_3d<F: Float + FromPrimitive + Debug>(
363    data: ArrayView2<F>,
364    labels: &Array1<i32>,
365    centroids: Option<&Array2<F>>,
366    config: &VisualizationConfig,
367) -> Result<ScatterPlot3D> {
368    let n_samples = data.nrows();
369    let n_features = data.ncols();
370
371    if labels.len() != n_samples {
372        return Err(ClusteringError::InvalidInput(
373            "Number of labels must match number of samples".to_string(),
374        ));
375    }
376
377    // Reduce dimensionality if needed
378    let plotdata =
379        if n_features == 3 && config.dimensionality_reduction == DimensionalityReduction::None {
380            data.mapv(|x| x.to_f64().unwrap_or(0.0))
381        } else {
382            apply_dimensionality_reduction_3d(data, config.dimensionality_reduction)?
383        };
384
385    // Convert centroids if provided
386    let plot_centroids = if let Some(cents) = centroids {
387        if cents.ncols() == 3 && config.dimensionality_reduction == DimensionalityReduction::None {
388            Some(cents.mapv(|x| x.to_f64().unwrap_or(0.0)))
389        } else {
390            Some(apply_dimensionality_reduction_3d(
391                cents.view(),
392                config.dimensionality_reduction,
393            )?)
394        }
395    } else {
396        None
397    };
398
399    // Generate colors for clusters
400    let unique_labels: Vec<i32> = {
401        let mut labels_vec: Vec<i32> = labels.iter().cloned().collect();
402        labels_vec.sort_unstable();
403        labels_vec.dedup();
404        labels_vec
405    };
406
407    let cluster_colors = generate_cluster_colors(&unique_labels, config.color_scheme);
408    let point_colors = labels
409        .iter()
410        .map(|&label| {
411            cluster_colors
412                .get(&label)
413                .cloned()
414                .unwrap_or_else(|| "#000000".to_string())
415        })
416        .collect();
417
418    // Generate point sizes
419    let sizes = vec![config.point_size; n_samples];
420
421    // Calculate plot bounds
422    let bounds = calculate_3d_bounds(&plotdata);
423
424    // Create legend
425    let legend = create_legend(&unique_labels, &cluster_colors, labels);
426
427    Ok(ScatterPlot3D {
428        points: plotdata,
429        labels: labels.clone(),
430        centroids: plot_centroids,
431        colors: point_colors,
432        sizes,
433        point_labels: None,
434        bounds,
435        legend,
436    })
437}
438
439/// Apply dimensionality reduction for 2D visualization
440#[allow(dead_code)]
441fn apply_dimensionality_reduction_2d<F: Float + FromPrimitive + Debug>(
442    data: ArrayView2<F>,
443    method: DimensionalityReduction,
444) -> Result<Array2<f64>> {
445    let data_f64 = data.mapv(|x| x.to_f64().unwrap_or(0.0));
446
447    match method {
448        DimensionalityReduction::PCA => apply_pca_2d(&data_f64),
449        DimensionalityReduction::First2D => {
450            if data_f64.ncols() >= 2 {
451                Ok(data_f64.slice(s![.., 0..2]).to_owned())
452            } else {
453                Err(ClusteringError::InvalidInput(
454                    "Data must have at least 2 dimensions for First2D".to_string(),
455                ))
456            }
457        }
458        DimensionalityReduction::TSNE => apply_tsne_2d(&data_f64),
459        DimensionalityReduction::UMAP => apply_umap_2d(&data_f64),
460        DimensionalityReduction::MDS => apply_mds_2d(&data_f64),
461        DimensionalityReduction::None => {
462            if data_f64.ncols() == 2 {
463                Ok(data_f64)
464            } else {
465                Err(ClusteringError::InvalidInput(
466                    "Data must be 2D when no dimensionality reduction is specified".to_string(),
467                ))
468            }
469        }
470        _ => apply_pca_2d(&data_f64), // Default to PCA
471    }
472}
473
474/// Apply dimensionality reduction for 3D visualization
475#[allow(dead_code)]
476fn apply_dimensionality_reduction_3d<F: Float + FromPrimitive + Debug>(
477    data: ArrayView2<F>,
478    method: DimensionalityReduction,
479) -> Result<Array2<f64>> {
480    let data_f64 = data.mapv(|x| x.to_f64().unwrap_or(0.0));
481
482    match method {
483        DimensionalityReduction::PCA => apply_pca_3d(&data_f64),
484        DimensionalityReduction::First3D => {
485            if data_f64.ncols() >= 3 {
486                Ok(data_f64.slice(s![.., 0..3]).to_owned())
487            } else {
488                Err(ClusteringError::InvalidInput(
489                    "Data must have at least 3 dimensions for First3D".to_string(),
490                ))
491            }
492        }
493        DimensionalityReduction::None => {
494            if data_f64.ncols() == 3 {
495                Ok(data_f64)
496            } else {
497                Err(ClusteringError::InvalidInput(
498                    "Data must be 3D when no dimensionality reduction is specified".to_string(),
499                ))
500            }
501        }
502        _ => apply_pca_3d(&data_f64), // Default to PCA
503    }
504}
505
506/// Apply PCA for 2D visualization
507#[allow(dead_code)]
508fn apply_pca_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
509    let n_samples = data.nrows();
510    let n_features = data.ncols();
511
512    if n_features < 2 {
513        return Err(ClusteringError::InvalidInput(
514            "Need at least 2 features for PCA".to_string(),
515        ));
516    }
517
518    // Center the data
519    let mean = data.mean_axis(Axis(0)).unwrap();
520    let centered = data - &mean;
521
522    // Compute covariance matrix
523    let cov = centered.t().dot(&centered) / (n_samples - 1) as f64;
524
525    // Simplified PCA projection (stub implementation)
526    // In a real implementation, this would compute eigenvectors of covariance matrix
527    let n_features = centered.ncols();
528    let eigenvectors_ = Array2::eye(n_features)
529        .slice(s![.., 0..2.min(n_features)])
530        .to_owned();
531
532    // Project data onto first 2 principal components
533    let projected = centered.dot(&eigenvectors_);
534
535    Ok(projected)
536}
537
538/// Apply PCA for 3D visualization
539#[allow(dead_code)]
540fn apply_pca_3d(data: &Array2<f64>) -> Result<Array2<f64>> {
541    let n_samples = data.nrows();
542    let n_features = data.ncols();
543
544    if n_features < 3 {
545        return Err(ClusteringError::InvalidInput(
546            "Need at least 3 features for 3D PCA".to_string(),
547        ));
548    }
549
550    // Center the data
551    let mean = data.mean_axis(Axis(0)).unwrap();
552    let centered = data - &mean;
553
554    // Compute covariance matrix
555    let cov = centered.t().dot(&centered) / (n_samples - 1) as f64;
556
557    // Simplified PCA projection (stub implementation)
558    // In a real implementation, this would compute eigenvectors of covariance matrix
559    let n_features = centered.ncols();
560    let eigenvectors_ = Array2::eye(n_features)
561        .slice(s![.., 0..3.min(n_features)])
562        .to_owned();
563
564    // Project data onto first 3 principal components
565    let projected = centered.dot(&eigenvectors_);
566
567    Ok(projected)
568}
569
570/// Simplified implementation of other dimensionality reduction methods
571/// These would ideally use proper implementations from specialized libraries
572#[allow(dead_code)]
573fn apply_tsne_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
574    // For now, fall back to PCA
575    apply_pca_2d(data)
576}
577
578#[allow(dead_code)]
579fn apply_umap_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
580    // For now, fall back to PCA
581    apply_pca_2d(data)
582}
583
584#[allow(dead_code)]
585fn apply_mds_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
586    // For now, fall back to PCA
587    apply_pca_2d(data)
588}
589
590/// Compute top eigenvectors using power iteration
591#[allow(dead_code)]
592fn compute_top_eigenvectors(
593    matrix: &Array2<f64>,
594    num_components: usize,
595) -> Result<(Array2<f64>, Array1<f64>)> {
596    let n = matrix.nrows();
597    let k = num_components.min(n);
598
599    let mut eigenvectors = Array2::zeros((n, k));
600    let mut eigenvalues = Array1::zeros(k);
601
602    // Simple power iteration for dominant eigenvector
603    for i in 0..k {
604        let mut v = Array1::from_elem(n, 1.0 / (n as f64).sqrt());
605
606        // Orthogonalize against previous eigenvectors
607        for j in 0..i {
608            let prev_eigenvector = eigenvectors.column(j);
609            let dot_product = v.dot(&prev_eigenvector);
610            v = &v - &(&prev_eigenvector * dot_product);
611        }
612
613        // Power iteration
614        for _ in 0..100 {
615            let new_v = matrix.dot(&v);
616            let norm = (new_v.dot(&new_v)).sqrt();
617            if norm > 1e-10 {
618                v = new_v / norm;
619            }
620        }
621
622        eigenvalues[i] = v.dot(&matrix.dot(&v));
623        for j in 0..n {
624            eigenvectors[[j, i]] = v[j];
625        }
626    }
627
628    Ok((eigenvectors, eigenvalues))
629}
630
631/// Generate cluster colors based on color scheme
632#[allow(dead_code)]
633fn generate_cluster_colors(labels: &[i32], scheme: ColorScheme) -> HashMap<i32, String> {
634    let mut colors = HashMap::new();
635
636    let color_palette = match scheme {
637        ColorScheme::Default => vec![
638            "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
639            "#bcbd22", "#17becf",
640        ],
641        ColorScheme::ColorblindFriendly => vec![
642            "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33", "#a65628", "#f781bf", "#999999",
643        ],
644        ColorScheme::HighContrast => vec![
645            "#000000", "#ffffff", "#ff0000", "#00ff00", "#0000ff", "#ffff00", "#ff00ff", "#00ffff",
646        ],
647        ColorScheme::Pastel => vec![
648            "#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5", "#c49c94", "#f7b6d3", "#c7c7c7",
649            "#dbdb8d", "#9edae5",
650        ],
651        ColorScheme::Viridis => vec![
652            "#440154", "#482777", "#3f4a8a", "#31678e", "#26838f", "#1f9d8a", "#6cce5a", "#b6de2b",
653            "#fee825",
654        ],
655        ColorScheme::Plasma => vec![
656            "#0c0887", "#5302a3", "#8b0aa5", "#b83289", "#db5c68", "#f48849", "#febd2a", "#f0f921",
657        ],
658        ColorScheme::Custom => vec!["#333333"], // Placeholder
659    };
660
661    for (i, &label) in labels.iter().enumerate() {
662        colors.entry(label).or_insert_with(|| {
663            let color_index = i % color_palette.len();
664            color_palette[color_index].to_string()
665        });
666    }
667
668    colors
669}
670
671/// Calculate 2D plot bounds
672#[allow(dead_code)]
673fn calculate_2d_bounds(data: &Array2<f64>) -> (f64, f64, f64, f64) {
674    if data.is_empty() {
675        return (0.0, 1.0, 0.0, 1.0);
676    }
677
678    let x_min = data.column(0).iter().fold(f64::INFINITY, |a, &b| a.min(b));
679    let x_max = data
680        .column(0)
681        .iter()
682        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
683    let y_min = data.column(1).iter().fold(f64::INFINITY, |a, &b| a.min(b));
684    let y_max = data
685        .column(1)
686        .iter()
687        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
688
689    // Add padding
690    let x_range = x_max - x_min;
691    let y_range = y_max - y_min;
692    let padding = 0.05;
693
694    (
695        x_min - x_range * padding,
696        x_max + x_range * padding,
697        y_min - y_range * padding,
698        y_max + y_range * padding,
699    )
700}
701
702/// Calculate 3D plot bounds
703#[allow(dead_code)]
704fn calculate_3d_bounds(data: &Array2<f64>) -> (f64, f64, f64, f64, f64, f64) {
705    if data.is_empty() {
706        return (0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
707    }
708
709    let x_min = data.column(0).iter().fold(f64::INFINITY, |a, &b| a.min(b));
710    let x_max = data
711        .column(0)
712        .iter()
713        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
714    let y_min = data.column(1).iter().fold(f64::INFINITY, |a, &b| a.min(b));
715    let y_max = data
716        .column(1)
717        .iter()
718        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
719    let z_min = data.column(2).iter().fold(f64::INFINITY, |a, &b| a.min(b));
720    let z_max = data
721        .column(2)
722        .iter()
723        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
724
725    // Add padding
726    let x_range = x_max - x_min;
727    let y_range = y_max - y_min;
728    let z_range = z_max - z_min;
729    let padding = 0.05;
730
731    (
732        x_min - x_range * padding,
733        x_max + x_range * padding,
734        y_min - y_range * padding,
735        y_max + y_range * padding,
736        z_min - z_range * padding,
737        z_max + z_range * padding,
738    )
739}
740
741/// Create legend entries
742#[allow(dead_code)]
743fn create_legend(
744    labels: &[i32],
745    colors: &HashMap<i32, String>,
746    data_labels: &Array1<i32>,
747) -> Vec<LegendEntry> {
748    let mut legend = Vec::new();
749
750    for &label in labels {
751        let count = data_labels.iter().filter(|&&l| l == label).count();
752        let color = colors
753            .get(&label)
754            .cloned()
755            .unwrap_or_else(|| "#000000".to_string());
756
757        legend.push(LegendEntry {
758            cluster_id: label,
759            color,
760            label: format!("Cluster {}", label),
761            count,
762        });
763    }
764
765    // Sort by cluster ID
766    legend.sort_by_key(|entry| entry.cluster_id);
767
768    legend
769}
770
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// A 2D plot built with the default configuration keeps one row, label
    /// and color per sample and one legend entry per cluster.
    #[test]
    fn test_create_scatter_plot_2d() {
        let samples =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let cluster_labels = Array1::from_vec(vec![0, 0, 1, 1]);

        let plot = create_scatter_plot_2d(
            samples.view(),
            &cluster_labels,
            None,
            &VisualizationConfig::default(),
        )
        .unwrap();

        assert_eq!(plot.points.nrows(), 4);
        assert_eq!(plot.points.ncols(), 2);
        assert_eq!(plot.labels.len(), 4);
        assert_eq!(plot.colors.len(), 4);
        assert_eq!(plot.legend.len(), 2);
    }

    /// A 3D plot built with the default configuration keeps one row and label
    /// per sample and has three coordinate columns.
    #[test]
    fn test_create_scatter_plot_3d() {
        let values = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let samples = Array2::from_shape_vec((4, 3), values).unwrap();
        let cluster_labels = Array1::from_vec(vec![0, 0, 1, 1]);

        let plot = create_scatter_plot_3d(
            samples.view(),
            &cluster_labels,
            None,
            &VisualizationConfig::default(),
        )
        .unwrap();

        assert_eq!(plot.points.nrows(), 4);
        assert_eq!(plot.points.ncols(), 3);
        assert_eq!(plot.labels.len(), 4);
    }

    /// PCA reduction of 5-feature data produces the requested column count.
    #[test]
    fn test_dimensionality_reduction() {
        let samples =
            Array2::from_shape_vec((10, 5), (0..50).map(|x| x as f64).collect()).unwrap();

        let reduced_2d =
            apply_dimensionality_reduction_2d(samples.view(), DimensionalityReduction::PCA)
                .unwrap();
        assert_eq!(reduced_2d.ncols(), 2);

        let reduced_3d =
            apply_dimensionality_reduction_3d(samples.view(), DimensionalityReduction::PCA)
                .unwrap();
        assert_eq!(reduced_3d.ncols(), 3);
    }

    /// Every distinct label receives a color.
    #[test]
    fn test_color_generation() {
        let cluster_ids = vec![0, 1, 2];
        let colors = generate_cluster_colors(&cluster_ids, ColorScheme::Default);

        assert_eq!(colors.len(), 3);
        for id in &cluster_ids {
            assert!(colors.contains_key(id));
        }
    }
}