scirs2_transform/reduction/
umap.rs

1//! Uniform Manifold Approximation and Projection (UMAP) for dimensionality reduction
2//!
3//! UMAP is a non-linear dimensionality reduction technique that can be used for
4//! visualization similarly to t-SNE, but also for general non-linear dimension reduction.
5
6use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
7use scirs2_core::numeric::{Float, NumCast};
8use scirs2_core::random::Rng;
9use scirs2_core::validation::{check_positive, checkshape};
10use std::collections::BinaryHeap;
11
12use crate::error::{Result, TransformError};
13
/// UMAP (Uniform Manifold Approximation and Projection) dimensionality reduction
///
/// UMAP constructs a high dimensional graph representation of the data then optimizes
/// a low dimensional graph to be as structurally similar as possible.
///
/// An instance is configured through [`UMAP::new`] plus the `with_*` builder
/// methods, trained with `fit` / `fit_transform`, and afterwards keeps the
/// training data, the fuzzy graph and the embedding so that `transform` can
/// place new samples.
#[derive(Debug, Clone)]
pub struct UMAP {
    /// Number of nearest neighbours kept per sample when building the k-NN
    /// graph; also the number of training neighbours averaged when new data is
    /// transformed
    n_neighbors: usize,
    /// Number of components (dimensions) in the low dimensional space
    n_components: usize,
    /// Minimum-distance setting handed to `find_ab_params` at construction to
    /// derive `a` and `b`; it is not read again afterwards
    #[allow(dead_code)]
    mindist: f64,
    /// Spread setting handed to `find_ab_params` together with `mindist`;
    /// `new` always sets it to 1.0
    #[allow(dead_code)]
    spread: f64,
    /// Initial learning rate; the optimizer decays it linearly to zero over the epochs
    learning_rate: f64,
    /// Number of epochs for optimization
    n_epochs: usize,
    /// Random seed set through `with_random_state`
    ///
    /// NOTE(review): the methods in this file draw from
    /// `scirs2_core::random::rng()` and never read this field, so it does not
    /// currently make runs reproducible; confirm whether that is intended.
    random_state: Option<u64>,
    /// Copy of the training data (converted to `f64`) for out-of-sample extension
    training_data: Option<Array2<f64>>,
    /// Symmetrized fuzzy graph computed by `fit`, kept for out-of-sample extension
    training_graph: Option<Array2<f64>>,
    /// Name of the distance metric set through `with_metric`
    ///
    /// NOTE(review): distance computation in this file is always Euclidean and
    /// does not consult this field.
    metric: String,
    /// The low dimensional embedding of the training data (`None` until fitted)
    embedding: Option<Array2<f64>>,
    /// Curve parameters of the low dimensional similarity `1 / (1 + a * d^(2b))`,
    /// derived from `mindist` and `spread` at construction
    a: f64,
    b: f64,
}
48
49impl UMAP {
50    /// Creates a new UMAP instance
51    ///
52    /// # Arguments
53    /// * `n_neighbors` - Number of neighbors to consider for local structure (default: 15)
54    /// * `n_components` - Number of dimensions in the low dimensional space (default: 2)
55    /// * `mindist` - Minimum distance between points in low dimensional space (default: 0.1)
56    /// * `learning_rate` - Learning rate for optimization (default: 1.0)
57    /// * `n_epochs` - Number of epochs for optimization (default: 200)
58    pub fn new(
59        n_neighbors: usize,
60        n_components: usize,
61        mindist: f64,
62        learning_rate: f64,
63        n_epochs: usize,
64    ) -> Self {
65        // Compute a and b parameters based on mindist and spread
66        let spread = 1.0;
67        let (a, b) = Self::find_ab_params(spread, mindist);
68
69        UMAP {
70            n_neighbors,
71            n_components,
72            mindist,
73            spread,
74            learning_rate,
75            n_epochs,
76            random_state: None,
77            metric: "euclidean".to_string(),
78            embedding: None,
79            training_data: None,
80            training_graph: None,
81            a,
82            b,
83        }
84    }
85
86    /// Sets the random state for reproducibility
87    pub fn with_random_state(mut self, seed: u64) -> Self {
88        self.random_state = Some(seed);
89        self
90    }
91
92    /// Sets the distance metric
93    pub fn with_metric(mut self, metric: &str) -> Self {
94        self.metric = metric.to_string();
95        self
96    }
97
98    /// Find a and b parameters to approximate the fuzzy set membership function
99    fn find_ab_params(_spread: f64, mindist: f64) -> (f64, f64) {
100        // Binary search to find good values of a and b
101        let mut a = 1.0;
102        let mut b = 1.0;
103
104        // Initial guess based on mindist and _spread
105        if mindist > 0.0 {
106            b = mindist.ln() / (1.0 - mindist).ln();
107        }
108
109        // Refine using Newton's method
110        for _ in 0..64 {
111            let val = 1.0 / (1.0 + a * mindist.powf(2.0 * b));
112            let grad_a = -mindist.powf(2.0 * b) / (1.0 + a * mindist.powf(2.0 * b)).powi(2);
113            let grad_b = -2.0 * a * mindist.powf(2.0 * b) * mindist.ln()
114                / (1.0 + a * mindist.powf(2.0 * b)).powi(2);
115
116            if (val - 0.5).abs() < 1e-5 {
117                break;
118            }
119
120            a -= (val - 0.5) / grad_a;
121            b -= (val - 0.5) / grad_b;
122
123            a = a.max(0.001);
124            b = b.max(0.001);
125        }
126
127        (a, b)
128    }
129
130    /// Compute pairwise distances between all points
131    fn compute_distances<S>(&self, x: &ArrayBase<S, Ix2>) -> Array2<f64>
132    where
133        S: Data,
134        S::Elem: Float + NumCast,
135    {
136        let nsamples = x.shape()[0];
137        let mut distances = Array2::zeros((nsamples, nsamples));
138
139        // Compute pairwise Euclidean distances
140        for i in 0..nsamples {
141            for j in i + 1..nsamples {
142                let mut dist = 0.0;
143                for k in 0..x.shape()[1] {
144                    let diff = NumCast::from(x[[i, k]]).unwrap_or(0.0)
145                        - NumCast::from(x[[j, k]]).unwrap_or(0.0);
146                    dist += diff * diff;
147                }
148                dist = dist.sqrt();
149                distances[[i, j]] = dist;
150                distances[[j, i]] = dist;
151            }
152        }
153
154        distances
155    }
156
157    /// Find k nearest neighbors for each point
158    fn find_neighbors(&self, distances: &Array2<f64>) -> (Array2<usize>, Array2<f64>) {
159        let nsamples = distances.shape()[0];
160        let k = self.n_neighbors;
161
162        let mut indices = Array2::zeros((nsamples, k));
163        let mut neighbor_distances = Array2::zeros((nsamples, k));
164
165        for i in 0..nsamples {
166            // Use a min heap to find k smallest distances
167            let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();
168
169            for j in 0..nsamples {
170                if i != j {
171                    // Convert to fixed point for comparison
172                    let dist_fixed = (distances[[i, j]] * 1e9) as i64;
173                    heap.push((std::cmp::Reverse(dist_fixed), j));
174                }
175            }
176
177            // Extract k nearest neighbors
178            for j in 0..k {
179                if let Some((std::cmp::Reverse(dist_fixed), idx)) = heap.pop() {
180                    indices[[i, j]] = idx;
181                    neighbor_distances[[i, j]] = dist_fixed as f64 / 1e9;
182                }
183            }
184        }
185
186        (indices, neighbor_distances)
187    }
188
189    /// Compute fuzzy simplicial set (high dimensional graph)
190    fn compute_graph(
191        &self,
192        knn_indices: &Array2<usize>,
193        knn_distances: &Array2<f64>,
194    ) -> Array2<f64> {
195        let nsamples = knn_indices.shape()[0];
196        let mut graph = Array2::zeros((nsamples, nsamples));
197
198        // For each point, compute membership strengths to its neighbors
199        for i in 0..nsamples {
200            // Find rho (distance to nearest neighbor)
201            let rho = knn_distances[[i, 0]];
202
203            // Binary search for sigma
204            let mut sigma = 1.0;
205            let target = self.n_neighbors as f64;
206
207            for _ in 0..64 {
208                let mut sum = 0.0;
209                for j in 1..self.n_neighbors {
210                    let d = (knn_distances[[i, j]] - rho).max(0.0);
211                    sum += (-d / sigma).exp();
212                }
213
214                if (sum - target).abs() < 1e-5 {
215                    break;
216                }
217
218                if sum > target {
219                    sigma *= 2.0;
220                } else {
221                    sigma /= 2.0;
222                }
223            }
224
225            // Compute membership strengths
226            for j in 0..self.n_neighbors {
227                let neighbor_idx = knn_indices[[i, j]];
228                let d = (knn_distances[[i, j]] - rho).max(0.0);
229                let strength = (-d / sigma).exp();
230                graph[[i, neighbor_idx]] = strength;
231            }
232        }
233
234        // Symmetrize the graph
235        let graph_transpose = graph.t().to_owned();
236        &graph + &graph_transpose - &graph * &graph_transpose
237    }
238
239    /// Initialize the low dimensional embedding
240    fn initialize_embedding(&self, nsamples: usize) -> Array2<f64> {
241        let mut rng = scirs2_core::random::rng();
242
243        // Initialize with small random values
244        let mut embedding = Array2::zeros((nsamples, self.n_components));
245        for i in 0..nsamples {
246            for j in 0..self.n_components {
247                embedding[[i, j]] = rng.gen_range(0.0..1.0) * 10.0 - 5.0;
248            }
249        }
250
251        embedding
252    }
253
254    /// Optimize the low dimensional embedding
255    fn optimize_embedding(
256        &self,
257        embedding: &mut Array2<f64>,
258        graph: &Array2<f64>,
259        n_epochs: usize,
260    ) {
261        let nsamples = embedding.shape()[0];
262        let mut rng = scirs2_core::random::rng();
263
264        // Create edge list from graph
265        let mut edges = Vec::new();
266        let mut weights = Vec::new();
267        for i in 0..nsamples {
268            for j in 0..nsamples {
269                if graph[[i, j]] > 0.0 {
270                    edges.push((i, j));
271                    weights.push(graph[[i, j]]);
272                }
273            }
274        }
275
276        let n_edges = edges.len();
277
278        // Optimization loop
279        for epoch in 0..n_epochs {
280            // Adjust learning rate
281            let alpha = self.learning_rate * (1.0 - epoch as f64 / n_epochs as f64);
282
283            // Sample edges for this epoch
284            for _ in 0..n_edges {
285                // Sample an edge
286                let edge_idx = rng.gen_range(0..n_edges);
287                let (i, j) = edges[edge_idx];
288
289                // Compute distance in embedding space
290                let mut dist_sq = 0.0;
291                for d in 0..self.n_components {
292                    let diff = embedding[[i, d]] - embedding[[j, d]];
293                    dist_sq += diff * diff;
294                }
295                let dist = dist_sq.sqrt();
296
297                // Attractive force
298                if dist > 0.0 {
299                    let attraction = -2.0 * self.a * self.b * dist.powf(2.0 * self.b - 2.0)
300                        / (1.0 + self.a * dist.powf(2.0 * self.b));
301
302                    for d in 0..self.n_components {
303                        let grad = attraction * (embedding[[i, d]] - embedding[[j, d]]) / dist;
304                        embedding[[i, d]] += alpha * grad * weights[edge_idx];
305                        embedding[[j, d]] -= alpha * grad * weights[edge_idx];
306                    }
307                }
308
309                // Repulsive force - sample a negative edge
310                let k = rng.gen_range(0..nsamples);
311                if k != i && k != j {
312                    let mut neg_dist_sq = 0.0;
313                    for d in 0..self.n_components {
314                        let diff = embedding[[i, d]] - embedding[[k, d]];
315                        neg_dist_sq += diff * diff;
316                    }
317                    let neg_dist = neg_dist_sq.sqrt();
318
319                    if neg_dist > 0.0 {
320                        let repulsion = 2.0 * self.b
321                            / (1.0 + self.a * neg_dist.powf(2.0 * self.b))
322                            / (1.0 + neg_dist * neg_dist);
323
324                        for d in 0..self.n_components {
325                            let grad =
326                                repulsion * (embedding[[i, d]] - embedding[[k, d]]) / neg_dist;
327                            embedding[[i, d]] += alpha * grad;
328                            embedding[[k, d]] -= alpha * grad;
329                        }
330                    }
331                }
332            }
333        }
334    }
335
336    /// Fits the UMAP model to the input data
337    ///
338    /// # Arguments
339    /// * `x` - The input data, shape (nsamples, n_features)
340    ///
341    /// # Returns
342    /// * `Result<()>` - Ok if successful, Err otherwise
343    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
344    where
345        S: Data,
346        S::Elem: Float + NumCast + Send + Sync,
347    {
348        let (nsamples, n_features) = x.dim();
349
350        // Validate inputs
351        check_positive(self.n_neighbors, "n_neighbors")?;
352        check_positive(self.n_components, "n_components")?;
353        check_positive(self.n_epochs, "n_epochs")?;
354        checkshape(x, &[nsamples, n_features], "x")?;
355
356        if nsamples < self.n_neighbors {
357            return Err(TransformError::InvalidInput(format!(
358                "n_neighbors={} must be <= nsamples={}",
359                self.n_neighbors, nsamples
360            )));
361        }
362
363        // Store training data for out-of-sample extension
364        let training_data = Array2::from_shape_fn((nsamples, n_features), |(i, j)| {
365            NumCast::from(x[[i, j]]).unwrap_or(0.0)
366        });
367        self.training_data = Some(training_data);
368
369        // Step 1: Compute pairwise distances
370        let distances = self.compute_distances(x);
371
372        // Step 2: Find k nearest neighbors
373        let (knn_indices, knn_distances) = self.find_neighbors(&distances);
374
375        // Step 3: Compute fuzzy simplicial set
376        let graph = self.compute_graph(&knn_indices, &knn_distances);
377        self.training_graph = Some(graph.clone());
378
379        // Step 4: Initialize low dimensional embedding
380        let mut embedding = self.initialize_embedding(nsamples);
381
382        // Step 5: Optimize the embedding
383        self.optimize_embedding(&mut embedding, &graph, self.n_epochs);
384
385        self.embedding = Some(embedding);
386
387        Ok(())
388    }
389
390    /// Transforms the input data using the fitted UMAP model
391    ///
392    /// # Arguments
393    /// * `x` - The input data, shape (nsamples, n_features)
394    ///
395    /// # Returns
396    /// * `Result<Array2<f64>>` - The transformed data, shape (nsamples, n_components)
397    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
398    where
399        S: Data,
400        S::Elem: Float + NumCast,
401    {
402        if self.embedding.is_none() {
403            return Err(TransformError::NotFitted(
404                "UMAP model has not been fitted".to_string(),
405            ));
406        }
407
408        let training_data = self
409            .training_data
410            .as_ref()
411            .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
412
413        let (_n_new_samples, n_features) = x.dim();
414        let (_, n_training_features) = training_data.dim();
415
416        if n_features != n_training_features {
417            return Err(TransformError::InvalidInput(format!(
418                "Input features {n_features} must match training features {n_training_features}"
419            )));
420        }
421
422        // If transforming the same data as training, return stored embedding
423        if self.is_same_data(x, training_data) {
424            return Ok(self.embedding.as_ref().unwrap().clone());
425        }
426
427        // Implement out-of-sample extension using weighted average of nearest neighbors
428        self.transform_new_data(x)
429    }
430
431    /// Fits the UMAP model to the input data and returns the embedding
432    ///
433    /// # Arguments
434    /// * `x` - The input data, shape (nsamples, n_features)
435    ///
436    /// # Returns
437    /// * `Result<Array2<f64>>` - The transformed data, shape (nsamples, n_components)
438    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
439    where
440        S: Data,
441        S::Elem: Float + NumCast + Send + Sync,
442    {
443        self.fit(x)?;
444        self.transform(x)
445    }
446
447    /// Returns the low dimensional embedding
448    pub fn embedding(&self) -> Option<&Array2<f64>> {
449        self.embedding.as_ref()
450    }
451
452    /// Check if the input data is the same as training data
453    fn is_same_data<S>(&self, x: &ArrayBase<S, Ix2>, trainingdata: &Array2<f64>) -> bool
454    where
455        S: Data,
456        S::Elem: Float + NumCast,
457    {
458        if x.dim() != trainingdata.dim() {
459            return false;
460        }
461
462        let (nsamples, n_features) = x.dim();
463        for i in 0..nsamples {
464            for j in 0..n_features {
465                let x_val = NumCast::from(x[[i, j]]).unwrap_or(0.0);
466                if (x_val - trainingdata[[i, j]]).abs() > 1e-10 {
467                    return false;
468                }
469            }
470        }
471        true
472    }
473
474    /// Transform new data using out-of-sample extension
475    fn transform_new_data<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
476    where
477        S: Data,
478        S::Elem: Float + NumCast,
479    {
480        let training_data = self.training_data.as_ref().unwrap();
481        let training_embedding = self.embedding.as_ref().unwrap();
482
483        let (n_new_samples_, _) = x.dim();
484        let (n_training_samples_, _) = training_data.dim();
485
486        // For each new sample, find k nearest neighbors in training data
487        let mut new_embedding = Array2::zeros((n_new_samples_, self.n_components));
488
489        for i in 0..n_new_samples_ {
490            // Compute distances to all training samples
491            let mut distances = Vec::new();
492            for j in 0..n_training_samples_ {
493                let mut dist_sq = 0.0;
494                for k in 0..x.ncols() {
495                    let x_val = NumCast::from(x[[i, k]]).unwrap_or(0.0);
496                    let train_val = training_data[[j, k]];
497                    let diff = x_val - train_val;
498                    dist_sq += diff * diff;
499                }
500                distances.push((dist_sq.sqrt(), j));
501            }
502
503            // Sort and take k nearest neighbors
504            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
505            let k = self.n_neighbors.min(n_training_samples_);
506
507            // Compute weights based on distances (inverse distance weighting)
508            let mut total_weight = 0.0;
509            let mut weighted_coords = vec![0.0; self.n_components];
510
511            for (dist, train_idx) in distances.iter().take(k) {
512                let weight = if *dist > 1e-10 {
513                    1.0 / (*dist + 1e-10)
514                } else {
515                    1e10
516                };
517                total_weight += weight;
518
519                for dim in 0..self.n_components {
520                    weighted_coords[dim] += weight * training_embedding[[*train_idx, dim]];
521                }
522            }
523
524            // Normalize weights and set coordinates
525            if total_weight > 0.0 {
526                for dim in 0..self.n_components {
527                    new_embedding[[i, dim]] = weighted_coords[dim] / total_weight;
528                }
529            }
530        }
531
532        Ok(new_embedding)
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Smoke test on three well separated clusters: the embedding must have
    /// one row per sample, the requested width, and only finite values.
    #[test]
    fn test_umap_basic() {
        let values = vec![
            1.0, 2.0, 3.0, 1.1, 2.1, 3.1, 1.2, 2.2, 3.2, 5.0, 6.0, 7.0, 5.1, 6.1, 7.1, 5.2, 6.2,
            7.2, 9.0, 10.0, 11.0, 9.1, 10.1, 11.1, 9.2, 10.2, 11.2, 9.3, 10.3, 11.3,
        ];
        let x = Array::from_shape_vec((10, 3), values).unwrap();

        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50);
        let embedding = umap.fit_transform(&x).unwrap();

        assert_eq!(embedding.shape(), &[10, 2]);
        assert!(embedding.iter().all(|val| val.is_finite()));
    }

    /// The builder methods can be chained and a non-default configuration
    /// still yields an embedding of the requested shape.
    #[test]
    fn test_umap_parameters() {
        let x: Array2<f64> = Array::eye(5);

        let builder = UMAP::new(2, 3, 0.5, 0.5, 100);
        let mut umap = builder.with_random_state(42).with_metric("euclidean");

        let embedding = umap.fit_transform(&x).unwrap();
        assert_eq!(embedding.shape(), &[5, 3]);
    }
}
578}