Skip to main content

scirs2_transform/reduction/
umap.rs

1//! Uniform Manifold Approximation and Projection (UMAP) for dimensionality reduction
2//!
3//! UMAP is a non-linear dimensionality reduction technique that can be used for
4//! visualization similarly to t-SNE, but also for general non-linear dimension reduction.
5//!
6//! ## Algorithm Overview
7//!
8//! 1. **k-NN graph construction**: Find k nearest neighbors for each point
9//! 2. **Fuzzy simplicial set**: Compute membership strengths using smooth kNN distance
10//! 3. **Spectral initialization**: Initialize embedding using Laplacian eigenvectors
11//! 4. **SGD optimization**: Optimize layout with negative sampling
12//!
13//! ## Features
14//!
15//! - Proper smooth k-NN distance computation with binary search for sigma
16//! - Fuzzy simplicial set union with local connectivity constraint
17//! - Spectral initialization via normalized Laplacian eigenvectors
18//! - SGD layout optimization with edge sampling and negative sampling schedule
19//! - Out-of-sample extension via inverse distance weighting
20
21use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
22use scirs2_core::numeric::{Float, NumCast};
23use scirs2_core::random::Rng;
24use scirs2_core::validation::{check_positive, checkshape};
25use scirs2_linalg::eigh;
26use std::collections::BinaryHeap;
27
28use crate::error::{Result, TransformError};
29
/// UMAP (Uniform Manifold Approximation and Projection) dimensionality reduction
///
/// UMAP constructs a high dimensional graph representation of the data then optimizes
/// a low dimensional graph to be as structurally similar as possible.
///
/// # Example
///
/// ```rust,no_run
/// use scirs2_transform::UMAP;
/// use scirs2_core::ndarray::Array2;
///
/// let data = Array2::<f64>::zeros((50, 10));
/// let mut umap = UMAP::new(15, 2, 0.1, 1.0, 200);
/// let embedding = umap.fit_transform(&data).expect("should succeed");
/// assert_eq!(embedding.shape(), &[50, 2]);
/// ```
#[derive(Debug, Clone)]
pub struct UMAP {
    /// Number of neighbors to consider for local structure
    n_neighbors: usize,
    /// Number of components (dimensions) in the low dimensional space
    n_components: usize,
    /// Controls how tightly UMAP packs points together (minimum distance)
    min_dist: f64,
    /// Controls the effective scale of local vs global structure
    spread: f64,
    /// Learning rate for optimization
    learning_rate: f64,
    /// Number of epochs for optimization
    n_epochs: usize,
    /// Random seed for reproducibility
    /// NOTE(review): stored by `with_random_state` but the RNGs in this file are
    /// created via `scirs2_core::random::rng()` without it — confirm seeding is wired through.
    random_state: Option<u64>,
    /// Training data for out-of-sample extension
    training_data: Option<Array2<f64>>,
    /// Training k-NN graph for out-of-sample extension
    training_graph: Option<Array2<f64>>,
    /// Metric to use for distance computation ("euclidean", "manhattan", or "cosine";
    /// anything else falls back to euclidean)
    metric: String,
    /// The low dimensional embedding (populated by `fit`)
    embedding: Option<Array2<f64>>,
    /// Negative sampling rate (number of negative samples per positive edge)
    negative_sample_rate: usize,
    /// Whether to use spectral initialization
    spectral_init: bool,
    /// `a` parameter of the smooth membership curve 1 / (1 + a * d^(2b))
    a: f64,
    /// `b` parameter of the smooth membership curve
    b: f64,
    /// Local connectivity parameter (must be >= 1)
    local_connectivity: f64,
    /// Set operation mix ratio (0.0 = pure intersection, 1.0 = pure union)
    set_op_mix_ratio: f64,
}
82
83impl UMAP {
84    /// Creates a new UMAP instance
85    ///
86    /// # Arguments
87    /// * `n_neighbors` - Number of neighbors to consider for local structure (typically 5-50)
88    /// * `n_components` - Number of dimensions in the low dimensional space (typically 2 or 3)
89    /// * `min_dist` - Minimum distance between points in low dimensional space (typically 0.001-0.5)
90    /// * `learning_rate` - Learning rate for SGD optimization (typically 1.0)
91    /// * `n_epochs` - Number of epochs for optimization (typically 200-500)
92    pub fn new(
93        n_neighbors: usize,
94        n_components: usize,
95        min_dist: f64,
96        learning_rate: f64,
97        n_epochs: usize,
98    ) -> Self {
99        let spread = 1.0;
100        let (a, b) = Self::find_ab_params(spread, min_dist);
101
102        UMAP {
103            n_neighbors,
104            n_components,
105            min_dist,
106            spread,
107            learning_rate,
108            n_epochs,
109            random_state: None,
110            metric: "euclidean".to_string(),
111            embedding: None,
112            training_data: None,
113            training_graph: None,
114            negative_sample_rate: 5,
115            spectral_init: true,
116            a,
117            b,
118            local_connectivity: 1.0,
119            set_op_mix_ratio: 1.0,
120        }
121    }
122
123    /// Sets the random state for reproducibility
124    pub fn with_random_state(mut self, seed: u64) -> Self {
125        self.random_state = Some(seed);
126        self
127    }
128
129    /// Sets the distance metric
130    pub fn with_metric(mut self, metric: &str) -> Self {
131        self.metric = metric.to_string();
132        self
133    }
134
135    /// Sets the negative sampling rate
136    pub fn with_negative_sample_rate(mut self, rate: usize) -> Self {
137        self.negative_sample_rate = rate;
138        self
139    }
140
141    /// Enable or disable spectral initialization
142    pub fn with_spectral_init(mut self, use_spectral: bool) -> Self {
143        self.spectral_init = use_spectral;
144        self
145    }
146
147    /// Sets the local connectivity parameter
148    pub fn with_local_connectivity(mut self, local_connectivity: f64) -> Self {
149        self.local_connectivity = local_connectivity.max(1.0);
150        self
151    }
152
153    /// Sets the set operation mix ratio
154    pub fn with_set_op_mix_ratio(mut self, ratio: f64) -> Self {
155        self.set_op_mix_ratio = ratio.clamp(0.0, 1.0);
156        self
157    }
158
159    /// Sets the spread parameter
160    pub fn with_spread(mut self, spread: f64) -> Self {
161        self.spread = spread;
162        let (a, b) = Self::find_ab_params(spread, self.min_dist);
163        self.a = a;
164        self.b = b;
165        self
166    }
167
168    /// Find a and b parameters to approximate the membership function
169    ///
170    /// We want: 1 / (1 + a * d^(2b)) to approximate
171    ///   1.0 if d <= min_dist
172    ///   exp(-(d - min_dist) / spread) if d > min_dist
173    fn find_ab_params(spread: f64, min_dist: f64) -> (f64, f64) {
174        if min_dist <= 0.0 || spread <= 0.0 {
175            return (1.0, 1.0);
176        }
177
178        let mut a = 1.0;
179        let mut b = 1.0;
180
181        // Use curve fitting approach: sample the target curve and fit a, b
182        // Target: phi(d) = 1 if d <= min_dist, else exp(-(d - min_dist) / spread)
183        // Model: psi(d) = 1 / (1 + a * d^(2b))
184        // Minimize sum of (phi(d_i) - psi(d_i))^2
185
186        // Initial guess based on analytical approximation
187        if min_dist < spread {
188            b = min_dist.ln().abs() / (1.0 - min_dist).ln().abs().max(1e-10);
189            b = b.clamp(0.1, 10.0);
190        }
191
192        // Newton's method refinement
193        for _ in 0..100 {
194            let mut residual_a = 0.0;
195            let mut residual_b = 0.0;
196            let mut jacobian_aa = 0.0;
197            let mut jacobian_bb = 0.0;
198
199            let n_samples = 50;
200            for k in 0..n_samples {
201                let d = min_dist + (3.0 * spread) * (k as f64 / n_samples as f64);
202                if d < 1e-10 {
203                    continue;
204                }
205
206                let target = if d <= min_dist {
207                    1.0
208                } else {
209                    (-(d - min_dist) / spread).exp()
210                };
211
212                let d2b = d.powf(2.0 * b);
213                let denom = 1.0 + a * d2b;
214                let model = 1.0 / denom;
215                let diff = model - target;
216
217                // Gradient w.r.t. a
218                let da = -d2b / (denom * denom);
219                // Gradient w.r.t. b
220                let db = -2.0 * a * d2b * d.ln() / (denom * denom);
221
222                residual_a += diff * da;
223                residual_b += diff * db;
224                jacobian_aa += da * da;
225                jacobian_bb += db * db;
226            }
227
228            if jacobian_aa.abs() > 1e-15 {
229                a -= 0.5 * residual_a / jacobian_aa;
230            }
231            if jacobian_bb.abs() > 1e-15 {
232                b -= 0.5 * residual_b / jacobian_bb;
233            }
234
235            a = a.max(0.001);
236            b = b.max(0.001);
237
238            if residual_a.abs() < 1e-8 && residual_b.abs() < 1e-8 {
239                break;
240            }
241        }
242
243        (a, b)
244    }
245
246    /// Compute pairwise distances between all points
247    fn compute_distances<S>(&self, x: &ArrayBase<S, Ix2>) -> Array2<f64>
248    where
249        S: Data,
250        S::Elem: Float + NumCast,
251    {
252        let n_samples = x.shape()[0];
253        let n_features = x.shape()[1];
254        let mut distances = Array2::zeros((n_samples, n_samples));
255
256        for i in 0..n_samples {
257            for j in i + 1..n_samples {
258                let dist = match self.metric.as_str() {
259                    "manhattan" => {
260                        let mut d = 0.0;
261                        for k in 0..n_features {
262                            let diff = NumCast::from(x[[i, k]]).unwrap_or(0.0)
263                                - NumCast::from(x[[j, k]]).unwrap_or(0.0);
264                            d += diff.abs();
265                        }
266                        d
267                    }
268                    "cosine" => {
269                        let mut dot = 0.0;
270                        let mut norm_i = 0.0;
271                        let mut norm_j = 0.0;
272                        for k in 0..n_features {
273                            let vi: f64 = NumCast::from(x[[i, k]]).unwrap_or(0.0);
274                            let vj: f64 = NumCast::from(x[[j, k]]).unwrap_or(0.0);
275                            dot += vi * vj;
276                            norm_i += vi * vi;
277                            norm_j += vj * vj;
278                        }
279                        let denom = (norm_i * norm_j).sqrt();
280                        if denom > 1e-10 {
281                            1.0 - (dot / denom).clamp(-1.0, 1.0)
282                        } else {
283                            1.0
284                        }
285                    }
286                    _ => {
287                        // Default: euclidean
288                        let mut d = 0.0;
289                        for k in 0..n_features {
290                            let diff = NumCast::from(x[[i, k]]).unwrap_or(0.0)
291                                - NumCast::from(x[[j, k]]).unwrap_or(0.0);
292                            d += diff * diff;
293                        }
294                        d.sqrt()
295                    }
296                };
297
298                distances[[i, j]] = dist;
299                distances[[j, i]] = dist;
300            }
301        }
302
303        distances
304    }
305
306    /// Find k nearest neighbors for each point
307    fn find_neighbors(&self, distances: &Array2<f64>) -> (Array2<usize>, Array2<f64>) {
308        let n_samples = distances.shape()[0];
309        let k = self.n_neighbors;
310
311        let mut indices = Array2::zeros((n_samples, k));
312        let mut neighbor_distances = Array2::zeros((n_samples, k));
313
314        for i in 0..n_samples {
315            let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();
316
317            for j in 0..n_samples {
318                if i != j {
319                    let dist_fixed = (distances[[i, j]] * 1e9) as i64;
320                    heap.push((std::cmp::Reverse(dist_fixed), j));
321                }
322            }
323
324            for j in 0..k {
325                if let Some((std::cmp::Reverse(dist_fixed), idx)) = heap.pop() {
326                    indices[[i, j]] = idx;
327                    neighbor_distances[[i, j]] = dist_fixed as f64 / 1e9;
328                }
329            }
330        }
331
332        (indices, neighbor_distances)
333    }
334
    /// Compute fuzzy simplicial set (high dimensional graph)
    ///
    /// For each point, compute the smooth k-NN distance (rho) and
    /// bandwidth (sigma) to convert distances to membership strengths.
    /// Then form the fuzzy union: A + A^T - A * A^T
    ///
    /// `sigma` is found by binary search so that the total membership of the k
    /// neighbors equals log2(k), matching the smooth-kNN normalization of the
    /// reference algorithm. Returns a dense (n_samples, n_samples) matrix of
    /// edge weights in [0, 1].
    fn compute_graph(
        &self,
        knn_indices: &Array2<usize>,
        knn_distances: &Array2<f64>,
    ) -> Array2<f64> {
        let n_samples = knn_indices.shape()[0];
        let k = self.n_neighbors;
        // Directed membership matrix: row i holds strengths to i's k neighbors.
        let mut graph = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            // Compute rho: distance to the local_connectivity-th nearest neighbor
            // With local_connectivity = 1, rho = distance to 1st nearest neighbor
            let local_idx = (self.local_connectivity as usize)
                .saturating_sub(1)
                .min(k - 1);
            let rho = knn_distances[[i, local_idx]];

            // Binary search for sigma such that sum of memberships = log2(k)
            let target = (k as f64).ln() / (2.0f64).ln();
            let mut sigma_lo = 0.0;
            let mut sigma_hi = f64::INFINITY;
            let mut sigma = 1.0;

            for _ in 0..64 {
                // Total membership mass at the current sigma.
                let mut membership_sum = 0.0;
                for j in 0..k {
                    // Distances at or below rho contribute full membership (d clamped to 0).
                    let d = (knn_distances[[i, j]] - rho).max(0.0);
                    if sigma > 1e-15 {
                        membership_sum += (-d / sigma).exp();
                    }
                }

                if (membership_sum - target).abs() < 1e-5 {
                    break;
                }

                if membership_sum > target {
                    // Too much mass: shrink sigma.
                    sigma_hi = sigma;
                    sigma = (sigma_lo + sigma_hi) / 2.0;
                } else {
                    // Too little mass: grow sigma, doubling until an upper bound exists.
                    sigma_lo = sigma;
                    if sigma_hi == f64::INFINITY {
                        sigma *= 2.0;
                    } else {
                        sigma = (sigma_lo + sigma_hi) / 2.0;
                    }
                }
            }

            // Compute membership strengths
            for j in 0..k {
                let neighbor_idx = knn_indices[[i, j]];
                let d = (knn_distances[[i, j]] - rho).max(0.0);
                let strength = if sigma > 1e-15 {
                    (-d / sigma).exp()
                } else if d < 1e-15 {
                    // Degenerate sigma: keep only exact-distance (d == 0) neighbors.
                    1.0
                } else {
                    0.0
                };
                graph[[i, neighbor_idx]] = strength;
            }
        }

        // Symmetrize using the fuzzy set union:
        // union(A, B) = A + B - A * B
        // With mix ratio: mix * union + (1 - mix) * intersection
        let graph_t = graph.t().to_owned();

        if (self.set_op_mix_ratio - 1.0).abs() < 1e-10 {
            // Pure union
            &graph + &graph_t - &graph * &graph_t
        } else if self.set_op_mix_ratio.abs() < 1e-10 {
            // Pure intersection
            &graph * &graph_t
        } else {
            // Mixed
            let union = &graph + &graph_t - &graph * &graph_t;
            let intersection = &graph * &graph_t;
            &intersection * (1.0 - self.set_op_mix_ratio) + &union * self.set_op_mix_ratio
        }
    }
422
423    /// Initialize the low dimensional embedding using spectral method
424    fn initialize_embedding(&self, n_samples: usize, graph: &Array2<f64>) -> Result<Array2<f64>> {
425        if self.spectral_init && n_samples > self.n_components + 1 {
426            // Spectral initialization using normalized Laplacian eigenvectors
427            match self.spectral_init_from_graph(n_samples, graph) {
428                Ok(embedding) => return Ok(embedding),
429                Err(_) => {
430                    // Fall back to random initialization
431                }
432            }
433        }
434
435        // Random initialization
436        let mut rng = scirs2_core::random::rng();
437        let mut embedding = Array2::zeros((n_samples, self.n_components));
438        for i in 0..n_samples {
439            for j in 0..self.n_components {
440                embedding[[i, j]] = rng.random_range(0.0..1.0) * 10.0 - 5.0;
441            }
442        }
443
444        Ok(embedding)
445    }
446
447    /// Spectral initialization from the fuzzy simplicial set graph
448    fn spectral_init_from_graph(
449        &self,
450        n_samples: usize,
451        graph: &Array2<f64>,
452    ) -> Result<Array2<f64>> {
453        // Compute degree matrix
454        let mut degree = Array1::zeros(n_samples);
455        for i in 0..n_samples {
456            degree[i] = graph.row(i).sum();
457        }
458
459        // Check for isolated nodes
460        for i in 0..n_samples {
461            if degree[i] < 1e-10 {
462                return Err(TransformError::ComputationError(
463                    "Graph has isolated nodes, cannot use spectral initialization".to_string(),
464                ));
465            }
466        }
467
468        // Compute normalized Laplacian: L = I - D^{-1/2} W D^{-1/2}
469        let mut laplacian = Array2::zeros((n_samples, n_samples));
470        for i in 0..n_samples {
471            for j in 0..n_samples {
472                if i == j {
473                    laplacian[[i, j]] = 1.0;
474                } else {
475                    let norm_weight = graph[[i, j]] / (degree[i] * degree[j]).sqrt();
476                    laplacian[[i, j]] = -norm_weight;
477                }
478            }
479        }
480
481        // Eigendecomposition
482        let (eigenvalues, eigenvectors) =
483            eigh(&laplacian.view(), None).map_err(|e| TransformError::LinalgError(e))?;
484
485        // Sort eigenvalues in ascending order and pick eigenvectors 1..n_components+1
486        let mut indices: Vec<usize> = (0..n_samples).collect();
487        indices.sort_by(|&a, &b| {
488            eigenvalues[a]
489                .partial_cmp(&eigenvalues[b])
490                .unwrap_or(std::cmp::Ordering::Equal)
491        });
492
493        let mut embedding = Array2::zeros((n_samples, self.n_components));
494        for j in 0..self.n_components {
495            let idx = indices[j + 1]; // Skip the first (constant) eigenvector
496            let scale = 10.0; // Scale for spread
497            for i in 0..n_samples {
498                embedding[[i, j]] = eigenvectors[[i, idx]] * scale;
499            }
500        }
501
502        Ok(embedding)
503    }
504
505    /// Optimize the low dimensional embedding using SGD with negative sampling
506    fn optimize_embedding(
507        &self,
508        embedding: &mut Array2<f64>,
509        graph: &Array2<f64>,
510        n_epochs: usize,
511    ) {
512        let n_samples = embedding.shape()[0];
513        let mut rng = scirs2_core::random::rng();
514
515        // Create edge list from graph with weights
516        let mut edges = Vec::new();
517        let mut weights = Vec::new();
518        for i in 0..n_samples {
519            for j in 0..n_samples {
520                if graph[[i, j]] > 0.0 {
521                    edges.push((i, j));
522                    weights.push(graph[[i, j]]);
523                }
524            }
525        }
526
527        let n_edges = edges.len();
528        if n_edges == 0 {
529            return;
530        }
531
532        // Compute epochs per sample based on weights
533        let max_weight = weights.iter().cloned().fold(0.0f64, f64::max);
534        let epochs_per_sample: Vec<f64> = if max_weight > 0.0 {
535            weights
536                .iter()
537                .map(|&w| {
538                    let epoch_ratio = max_weight / w.max(1e-10);
539                    epoch_ratio.min(n_epochs as f64)
540                })
541                .collect()
542        } else {
543            vec![1.0; n_edges]
544        };
545
546        let mut epochs_per_negative_sample: Vec<f64> = epochs_per_sample
547            .iter()
548            .map(|&e| e / self.negative_sample_rate as f64)
549            .collect();
550
551        let mut epoch_of_next_sample: Vec<f64> = epochs_per_sample.clone();
552        let mut epoch_of_next_negative_sample: Vec<f64> = epochs_per_negative_sample.clone();
553
554        // Clipping constant for gradient
555        let clip_val = 4.0;
556
557        // Optimization loop
558        for epoch in 0..n_epochs {
559            let alpha = self.learning_rate * (1.0 - epoch as f64 / n_epochs as f64);
560
561            for edge_idx in 0..n_edges {
562                if epoch_of_next_sample[edge_idx] > epoch as f64 {
563                    continue;
564                }
565
566                let (i, j) = edges[edge_idx];
567
568                // Compute distance in embedding space
569                let mut dist_sq = 0.0;
570                for d in 0..self.n_components {
571                    let diff = embedding[[i, d]] - embedding[[j, d]];
572                    dist_sq += diff * diff;
573                }
574                dist_sq = dist_sq.max(1e-10);
575
576                // Attractive force
577                let grad_coeff = -2.0 * self.a * self.b * dist_sq.powf(self.b - 1.0)
578                    / (1.0 + self.a * dist_sq.powf(self.b));
579
580                for d in 0..self.n_components {
581                    let grad = (grad_coeff * (embedding[[i, d]] - embedding[[j, d]]))
582                        .clamp(-clip_val, clip_val);
583                    embedding[[i, d]] += alpha * grad;
584                    embedding[[j, d]] -= alpha * grad;
585                }
586
587                // Update next sample epoch
588                epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
589
590                // Negative sampling
591                let n_neg = self.negative_sample_rate;
592                for _ in 0..n_neg {
593                    if epoch_of_next_negative_sample[edge_idx] > epoch as f64 {
594                        break;
595                    }
596
597                    let k = rng.random_range(0..n_samples);
598                    if k == i {
599                        continue;
600                    }
601
602                    let mut neg_dist_sq = 0.0;
603                    for d in 0..self.n_components {
604                        let diff = embedding[[i, d]] - embedding[[k, d]];
605                        neg_dist_sq += diff * diff;
606                    }
607                    neg_dist_sq = neg_dist_sq.max(1e-10);
608
609                    // Repulsive force
610                    let grad_coeff = 2.0 * self.b
611                        / ((0.001 + neg_dist_sq) * (1.0 + self.a * neg_dist_sq.powf(self.b)));
612
613                    for d in 0..self.n_components {
614                        let grad = (grad_coeff * (embedding[[i, d]] - embedding[[k, d]]))
615                            .clamp(-clip_val, clip_val);
616                        embedding[[i, d]] += alpha * grad;
617                    }
618
619                    epoch_of_next_negative_sample[edge_idx] += epochs_per_negative_sample[edge_idx];
620                }
621            }
622        }
623    }
624
625    /// Fits the UMAP model to the input data
626    ///
627    /// # Arguments
628    /// * `x` - The input data, shape (n_samples, n_features)
629    ///
630    /// # Returns
631    /// * `Result<()>` - Ok if successful, Err otherwise
632    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
633    where
634        S: Data,
635        S::Elem: Float + NumCast + Send + Sync,
636    {
637        let (n_samples, n_features) = x.dim();
638
639        check_positive(self.n_neighbors, "n_neighbors")?;
640        check_positive(self.n_components, "n_components")?;
641        check_positive(self.n_epochs, "n_epochs")?;
642        checkshape(x, &[n_samples, n_features], "x")?;
643
644        if n_samples < self.n_neighbors {
645            return Err(TransformError::InvalidInput(format!(
646                "n_neighbors={} must be <= n_samples={}",
647                self.n_neighbors, n_samples
648            )));
649        }
650
651        // Store training data
652        let training_data = Array2::from_shape_fn((n_samples, n_features), |(i, j)| {
653            NumCast::from(x[[i, j]]).unwrap_or(0.0)
654        });
655        self.training_data = Some(training_data);
656
657        // Step 1: Compute pairwise distances
658        let distances = self.compute_distances(x);
659
660        // Step 2: Find k nearest neighbors
661        let (knn_indices, knn_distances) = self.find_neighbors(&distances);
662
663        // Step 3: Compute fuzzy simplicial set
664        let graph = self.compute_graph(&knn_indices, &knn_distances);
665        self.training_graph = Some(graph.clone());
666
667        // Step 4: Initialize low dimensional embedding
668        let mut embedding = self.initialize_embedding(n_samples, &graph)?;
669
670        // Step 5: Optimize the embedding
671        self.optimize_embedding(&mut embedding, &graph, self.n_epochs);
672
673        self.embedding = Some(embedding);
674
675        Ok(())
676    }
677
678    /// Transforms the input data using the fitted UMAP model
679    ///
680    /// # Arguments
681    /// * `x` - The input data, shape (n_samples, n_features)
682    ///
683    /// # Returns
684    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
685    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
686    where
687        S: Data,
688        S::Elem: Float + NumCast,
689    {
690        if self.embedding.is_none() {
691            return Err(TransformError::NotFitted(
692                "UMAP model has not been fitted".to_string(),
693            ));
694        }
695
696        let training_data = self
697            .training_data
698            .as_ref()
699            .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
700
701        let (_, n_features) = x.dim();
702        let (_, n_training_features) = training_data.dim();
703
704        if n_features != n_training_features {
705            return Err(TransformError::InvalidInput(format!(
706                "Input features {n_features} must match training features {n_training_features}"
707            )));
708        }
709
710        // If transforming the same data as training, return stored embedding
711        if self.is_same_data(x, training_data) {
712            return self
713                .embedding
714                .as_ref()
715                .cloned()
716                .ok_or_else(|| TransformError::NotFitted("Embedding not available".to_string()));
717        }
718
719        // Out-of-sample extension
720        self.transform_new_data(x)
721    }
722
    /// Fits the UMAP model to the input data and returns the embedding
    ///
    /// Equivalent to `fit` followed by `transform` on the same data.
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast + Send + Sync,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Returns the low dimensional embedding, or `None` before `fit`
    pub fn embedding(&self) -> Option<&Array2<f64>> {
        self.embedding.as_ref()
    }

    /// Returns the fuzzy simplicial set graph from training, or `None` before `fit`
    pub fn graph(&self) -> Option<&Array2<f64>> {
        self.training_graph.as_ref()
    }
742
743    /// Check if the input data is the same as training data
744    fn is_same_data<S>(&self, x: &ArrayBase<S, Ix2>, training_data: &Array2<f64>) -> bool
745    where
746        S: Data,
747        S::Elem: Float + NumCast,
748    {
749        if x.dim() != training_data.dim() {
750            return false;
751        }
752
753        let (n_samples, n_features) = x.dim();
754        for i in 0..n_samples {
755            for j in 0..n_features {
756                let x_val: f64 = NumCast::from(x[[i, j]]).unwrap_or(0.0);
757                if (x_val - training_data[[i, j]]).abs() > 1e-10 {
758                    return false;
759                }
760            }
761        }
762        true
763    }
764
765    /// Transform new data using out-of-sample extension (inverse distance weighting)
766    fn transform_new_data<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
767    where
768        S: Data,
769        S::Elem: Float + NumCast,
770    {
771        let training_data = self
772            .training_data
773            .as_ref()
774            .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
775        let training_embedding = self
776            .embedding
777            .as_ref()
778            .ok_or_else(|| TransformError::NotFitted("Embedding not available".to_string()))?;
779
780        let (n_new_samples, _) = x.dim();
781        let (n_training_samples, _) = training_data.dim();
782
783        let mut new_embedding = Array2::zeros((n_new_samples, self.n_components));
784
785        for i in 0..n_new_samples {
786            // Compute distances to all training samples
787            let mut distances: Vec<(f64, usize)> = Vec::with_capacity(n_training_samples);
788            for j in 0..n_training_samples {
789                let mut dist_sq = 0.0;
790                for k in 0..x.ncols() {
791                    let x_val: f64 = NumCast::from(x[[i, k]]).unwrap_or(0.0);
792                    let train_val = training_data[[j, k]];
793                    let diff = x_val - train_val;
794                    dist_sq += diff * diff;
795                }
796                distances.push((dist_sq.sqrt(), j));
797            }
798
799            // Sort and take k nearest neighbors
800            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
801            let k = self.n_neighbors.min(n_training_samples);
802
803            // Inverse distance weighting
804            let mut total_weight = 0.0;
805            let mut weighted_coords = vec![0.0; self.n_components];
806
807            for &(dist, train_idx) in distances.iter().take(k) {
808                let weight = if dist > 1e-10 {
809                    1.0 / (dist + 1e-10)
810                } else {
811                    1e10
812                };
813                total_weight += weight;
814
815                for dim in 0..self.n_components {
816                    weighted_coords[dim] += weight * training_embedding[[train_idx, dim]];
817                }
818            }
819
820            if total_weight > 0.0 {
821                for dim in 0..self.n_components {
822                    new_embedding[[i, dim]] = weighted_coords[dim] / total_weight;
823                }
824            }
825        }
826
827        Ok(new_embedding)
828    }
829}
830
831#[cfg(test)]
832mod tests {
833    use super::*;
834    use scirs2_core::ndarray::Array;
835
836    #[test]
837    fn test_umap_basic() {
838        let x = Array::from_shape_vec(
839            (10, 3),
840            vec![
841                1.0, 2.0, 3.0, 1.1, 2.1, 3.1, 1.2, 2.2, 3.2, 5.0, 6.0, 7.0, 5.1, 6.1, 7.1, 5.2,
842                6.2, 7.2, 9.0, 10.0, 11.0, 9.1, 10.1, 11.1, 9.2, 10.2, 11.2, 9.3, 10.3, 11.3,
843            ],
844        )
845        .expect("Failed to create test array");
846
847        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50);
848        let embedding = umap.fit_transform(&x).expect("UMAP fit_transform failed");
849
850        assert_eq!(embedding.shape(), &[10, 2]);
851        for val in embedding.iter() {
852            assert!(val.is_finite());
853        }
854    }
855
856    #[test]
857    fn test_umap_parameters() {
858        let x: Array2<f64> = Array::eye(5);
859
860        let mut umap = UMAP::new(2, 3, 0.5, 0.5, 100)
861            .with_random_state(42)
862            .with_metric("euclidean");
863
864        let embedding = umap.fit_transform(&x).expect("UMAP fit_transform failed");
865        assert_eq!(embedding.shape(), &[5, 3]);
866    }
867
868    #[test]
869    fn test_umap_spectral_init() {
870        let x = Array::from_shape_vec(
871            (8, 2),
872            vec![
873                0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
874            ],
875        )
876        .expect("Failed to create test array");
877
878        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50).with_spectral_init(true);
879        let embedding = umap.fit_transform(&x).expect("UMAP fit_transform failed");
880
881        assert_eq!(embedding.shape(), &[8, 2]);
882        for val in embedding.iter() {
883            assert!(val.is_finite());
884        }
885    }
886
887    #[test]
888    fn test_umap_random_init() {
889        let x = Array::from_shape_vec(
890            (8, 2),
891            vec![
892                0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
893            ],
894        )
895        .expect("Failed to create test array");
896
897        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50).with_spectral_init(false);
898        let embedding = umap.fit_transform(&x).expect("UMAP fit_transform failed");
899
900        assert_eq!(embedding.shape(), &[8, 2]);
901        for val in embedding.iter() {
902            assert!(val.is_finite());
903        }
904    }
905
906    #[test]
907    fn test_umap_negative_sampling() {
908        let x = Array::from_shape_vec(
909            (8, 2),
910            vec![
911                0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
912            ],
913        )
914        .expect("Failed to create test array");
915
916        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50).with_negative_sample_rate(10);
917        let embedding = umap.fit_transform(&x).expect("UMAP fit_transform failed");
918
919        assert_eq!(embedding.shape(), &[8, 2]);
920        for val in embedding.iter() {
921            assert!(val.is_finite());
922        }
923    }
924
925    #[test]
926    fn test_umap_out_of_sample() {
927        let x_train = Array::from_shape_vec(
928            (10, 3),
929            vec![
930                1.0, 2.0, 3.0, 1.1, 2.1, 3.1, 1.2, 2.2, 3.2, 5.0, 6.0, 7.0, 5.1, 6.1, 7.1, 5.2,
931                6.2, 7.2, 9.0, 10.0, 11.0, 9.1, 10.1, 11.1, 9.2, 10.2, 11.2, 9.3, 10.3, 11.3,
932            ],
933        )
934        .expect("Failed to create test array");
935
936        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50);
937        umap.fit(&x_train).expect("UMAP fit failed");
938
939        let x_test = Array::from_shape_vec((2, 3), vec![1.05, 2.05, 3.05, 9.05, 10.05, 11.05])
940            .expect("Failed to create test array");
941
942        let test_embedding = umap.transform(&x_test).expect("UMAP transform failed");
943        assert_eq!(test_embedding.shape(), &[2, 2]);
944        for val in test_embedding.iter() {
945            assert!(val.is_finite());
946        }
947    }
948
949    #[test]
950    fn test_umap_find_ab_params() {
951        let (a, b) = UMAP::find_ab_params(1.0, 0.1);
952        assert!(a > 0.0);
953        assert!(b > 0.0);
954
955        // The function 1/(1+a*d^(2b)) should be close to 1 at d=0
956        let val_at_zero = 1.0 / (1.0 + a * 0.0f64.powf(2.0 * b));
957        assert!((val_at_zero - 1.0).abs() < 1e-5);
958    }
959}