// umap_rs/embedding.rs

1use crate::config::UmapConfig;
2use crate::distances::EuclideanMetric;
3use crate::manifold::LearnedManifold;
4use crate::metric::Metric;
5use crate::optimizer::Optimizer;
6use crate::umap::find_ab_params::find_ab_params;
7use crate::umap::fuzzy_simplicial_set::FuzzySimplicialSet;
8use crate::umap::raise_disconnected_warning::raise_disconnected_warning;
9use dashmap::DashSet;
10use ndarray::Array2;
11use ndarray::ArrayView2;
12use rayon::iter::IntoParallelIterator;
13use rayon::iter::ParallelIterator;
14use serde::Deserialize;
15use serde::Serialize;
16use std::time::Instant;
17use tracing::info;
18
/// UMAP dimensionality reduction algorithm.
///
/// This struct holds the configuration and metrics for UMAP. It can be reused
/// to learn manifolds from multiple datasets with the same parameters.
///
/// # Example
///
/// ```ignore
/// use umap::{Umap, UmapConfig};
/// use ndarray::Array2;
///
/// let config = UmapConfig::default();
/// let umap = Umap::new(config);
///
/// // Learn the manifold structure
/// let manifold = umap.learn_manifold(
///     data.view(),
///     knn_indices.view(),
///     knn_dists.view(),
/// );
///
/// // Create an optimizer and run training
/// let mut opt = Optimizer::new(
///     manifold,
///     init,
///     500, // total epochs
///     config.optimization.repulsion_strength,
///     config.optimization.learning_rate,
///     config.optimization.negative_sample_rate,
///     &euclidean_metric,
/// );
///
/// while opt.remaining_epochs() > 0 {
///     opt.step_epochs(opt.remaining_epochs().min(10));
/// }
///
/// let fitted = opt.into_fitted(config);
/// let embedding = fitted.embedding();
/// ```
pub struct Umap {
  // Algorithm parameters shared by every fit performed with this instance.
  config: UmapConfig,
  // Distance metric for the input space (graph construction).
  metric: Box<dyn Metric>,
  // Distance metric for the output embedding space (optimization).
  output_metric: Box<dyn Metric>,
}
63
64impl Umap {
65  /// Create a new UMAP instance with default Euclidean metrics.
66  ///
67  /// Both the input space metric (for graph construction) and output space
68  /// metric (for optimization) are set to Euclidean distance.
69  ///
70  /// # Arguments
71  ///
72  /// * `config` - UMAP configuration parameters
73  pub fn new(config: UmapConfig) -> Self {
74    Self {
75      config,
76      metric: Box::new(EuclideanMetric),
77      output_metric: Box::new(EuclideanMetric),
78    }
79  }
80
81  /// Create a UMAP instance with custom distance metrics.
82  ///
83  /// # Arguments
84  ///
85  /// * `config` - UMAP configuration parameters
86  /// * `metric` - Distance metric for input space (graph construction)
87  /// * `output_metric` - Distance metric for output embedding space (optimization)
88  ///
89  /// # Example
90  ///
91  /// ```ignore
92  /// let umap = Umap::with_metrics(
93  ///     config,
94  ///     Box::new(MyCustomMetric),
95  ///     Box::new(EuclideanMetric),
96  /// );
97  /// ```
98  pub fn with_metrics(
99    config: UmapConfig,
100    metric: Box<dyn Metric>,
101    output_metric: Box<dyn Metric>,
102  ) -> Self {
103    Self {
104      config,
105      metric,
106      output_metric,
107    }
108  }
109
110  /// Learn the manifold structure from high-dimensional data.
111  ///
112  /// This is the expensive graph construction phase that builds a fuzzy
113  /// topological representation of the data. The result can be cached,
114  /// serialized, and reused for multiple different optimizations.
115  ///
116  /// This phase is deterministic (no randomness) and independent of the
117  /// target embedding dimensionality.
118  ///
119  /// # Arguments
120  ///
121  /// * `data` - Input data matrix (n_samples × n_features). Used for validation.
122  /// * `knn_indices` - Precomputed k-nearest neighbor indices (n_samples × n_neighbors).
123  ///   Each row contains indices of the k nearest neighbors for that sample.
124  /// * `knn_dists` - Precomputed k-nearest neighbor distances (n_samples × n_neighbors).
125  ///   Each row contains distances to the k nearest neighbors.
126  ///
127  /// # Returns
128  ///
129  /// A `LearnedManifold` containing the fuzzy simplicial set and local geometry.
130  ///
131  /// # Panics
132  ///
133  /// Panics if:
134  /// - Parameter validation fails (invalid ranges, incompatible sizes)
135  /// - Array shapes are incompatible
136  /// - Number of samples <= n_neighbors
137  ///
138  /// # Example
139  ///
140  /// ```ignore
141  /// let manifold = umap.learn_manifold(
142  ///     data.view(),
143  ///     knn_indices.view(),
144  ///     knn_dists.view(),
145  /// );
146  /// // Save for later use
147  /// save_manifold(&manifold)?;
148  /// ```
149  pub fn learn_manifold(
150    &self,
151    data: ArrayView2<f32>,
152    knn_indices: ArrayView2<u32>,
153    knn_dists: ArrayView2<f32>,
154  ) -> LearnedManifold {
155    let n_samples = data.shape()[0];
156
157    // Validate parameters
158    self.validate_parameters(n_samples, &knn_indices, &knn_dists);
159
160    // Determine a and b parameters
161    let (a, b) =
162      if let (Some(a_val), Some(b_val)) = (self.config.manifold.a, self.config.manifold.b) {
163        (a_val, b_val)
164      } else {
165        find_ab_params(self.config.manifold.spread, self.config.manifold.min_dist)
166      };
167
168    // Determine disconnection distance
169    let disconnection_distance = self
170      .config
171      .graph
172      .disconnection_distance
173      .unwrap_or_else(|| self.metric.disconnection_threshold());
174
175    // Find and mark disconnected edges
176    let started = Instant::now();
177    let knn_disconnections = DashSet::new();
178    (0..n_samples).into_par_iter().for_each(|row_no| {
179      let row = knn_dists.row(row_no);
180      for (col_no, &dist) in row.iter().enumerate() {
181        if dist >= disconnection_distance {
182          knn_disconnections.insert((row_no, col_no));
183        }
184      }
185    });
186    let edges_removed = knn_disconnections.len();
187    info!(
188      duration_ms = started.elapsed().as_millis(),
189      edges_removed, "disconnection detection complete"
190    );
191
192    // Build fuzzy simplicial set (the graph)
193    info!(
194      n_samples,
195      n_neighbors = self.config.graph.n_neighbors,
196      "starting fuzzy simplicial set"
197    );
198    let started = Instant::now();
199    let (graph, sigmas, rhos) = FuzzySimplicialSet::builder()
200      .n_samples(n_samples)
201      .n_neighbors(self.config.graph.n_neighbors)
202      .knn_indices(knn_indices)
203      .knn_dists(knn_dists)
204      .knn_disconnections(&knn_disconnections)
205      .local_connectivity(self.config.graph.local_connectivity)
206      .set_op_mix_ratio(self.config.graph.set_op_mix_ratio)
207      .apply_set_operations(self.config.graph.symmetrize)
208      .build()
209      .exec();
210    info!(
211      duration_ms = started.elapsed().as_millis(),
212      "fuzzy simplicial set complete"
213    );
214
215    // Check for disconnected vertices
216    let vertices_disconnected = graph
217      .outer_iterator()
218      .filter(|row| {
219        let sum: f32 = row.data().iter().sum();
220        sum == 0.0
221      })
222      .count();
223
224    raise_disconnected_warning(
225      edges_removed,
226      vertices_disconnected,
227      disconnection_distance,
228      n_samples,
229      0.1,
230    );
231
232    LearnedManifold {
233      graph,
234      sigmas,
235      rhos,
236      n_vertices: n_samples,
237      a,
238      b,
239    }
240  }
241
242  /// High-level convenience method that learns and optimizes in one call.
243  ///
244  /// This is equivalent to:
245  /// 1. `learn_manifold()` - build the graph
246  /// 2. `Optimizer::new()` - set up optimization
247  /// 3. Run all epochs
248  /// 4. `into_fitted()` - extract final model
249  ///
250  /// For checkpointing or more control, use the lower-level API instead.
251  ///
252  /// # Arguments
253  ///
254  /// * `data` - Input data matrix (n_samples × n_features)
255  /// * `knn_indices` - Precomputed k-nearest neighbor indices
256  /// * `knn_dists` - Precomputed k-nearest neighbor distances
257  /// * `init` - Initial embedding coordinates (n_samples × n_components)
258  ///
259  /// # Returns
260  ///
261  /// A `FittedUmap` containing the optimized embedding and learned manifold.
262  ///
263  /// # Example
264  ///
265  /// ```ignore
266  /// let fitted = umap.fit(
267  ///     data.view(),
268  ///     knn_indices.view(),
269  ///     knn_dists.view(),
270  ///     init.view(),
271  /// );
272  /// let embedding = fitted.embedding();
273  /// ```
274  pub fn fit(
275    &self,
276    data: ArrayView2<f32>,
277    knn_indices: ArrayView2<u32>,
278    knn_dists: ArrayView2<f32>,
279    init: ArrayView2<f32>,
280  ) -> FittedUmap {
281    let n_samples = data.shape()[0];
282
283    // Validate init array
284    if init.shape()[1] != self.config.n_components {
285      panic!(
286        "init has {} components but n_components is {}",
287        init.shape()[1],
288        self.config.n_components
289      );
290    }
291
292    if init.shape()[0] != n_samples {
293      panic!(
294        "init has {} samples but data has {} samples",
295        init.shape()[0],
296        n_samples
297      );
298    }
299
300    // Learn the manifold
301    let manifold = self.learn_manifold(data, knn_indices, knn_dists);
302
303    // Determine total epochs
304    let total_epochs = self
305      .config
306      .optimization
307      .n_epochs
308      .unwrap_or_else(|| if n_samples <= 10000 { 500 } else { 200 });
309
310    // Create optimizer
311    let metric_type = self.output_metric.metric_type();
312    let mut optimizer = Optimizer::new(
313      manifold,
314      init.to_owned(),
315      total_epochs,
316      &self.config,
317      metric_type,
318    );
319
320    // Run all epochs
321    optimizer.step_epochs(total_epochs, self.output_metric.as_ref());
322
323    // Extract final model
324    let mut fitted = optimizer.into_fitted(self.config.clone());
325
326    // Set disconnected vertices to NaN
327    for (i, row) in fitted.manifold.graph.outer_iterator().enumerate() {
328      let sum: f32 = row.data().iter().sum();
329      if sum == 0.0 {
330        for j in 0..fitted.embedding.shape()[1] {
331          fitted.embedding[(i, j)] = f32::NAN;
332        }
333      }
334    }
335
336    fitted
337  }
338
339  fn validate_parameters(
340    &self,
341    n_samples: usize,
342    knn_indices: &ArrayView2<u32>,
343    knn_dists: &ArrayView2<f32>,
344  ) {
345    // Validate graph parameters
346    if self.config.graph.set_op_mix_ratio < 0.0 || self.config.graph.set_op_mix_ratio > 1.0 {
347      panic!(
348        "set_op_mix_ratio must be between 0.0 and 1.0, got {}",
349        self.config.graph.set_op_mix_ratio
350      );
351    }
352
353    if self.config.graph.n_neighbors < 2 {
354      panic!(
355        "n_neighbors must be >= 2, got {}",
356        self.config.graph.n_neighbors
357      );
358    }
359
360    // Validate optimization parameters
361    if self.config.optimization.repulsion_strength < 0.0 {
362      panic!(
363        "repulsion_strength cannot be negative, got {}",
364        self.config.optimization.repulsion_strength
365      );
366    }
367
368    if self.config.manifold.min_dist > self.config.manifold.spread {
369      panic!(
370        "min_dist ({}) must be <= spread ({})",
371        self.config.manifold.min_dist, self.config.manifold.spread
372      );
373    }
374
375    if self.config.manifold.min_dist < 0.0 {
376      panic!(
377        "min_dist cannot be negative, got {}",
378        self.config.manifold.min_dist
379      );
380    }
381
382    // Validate optimization parameters
383    if self.config.optimization.learning_rate < 0.0 {
384      panic!(
385        "learning_rate must be positive, got {}",
386        self.config.optimization.learning_rate
387      );
388    }
389
390    if self.config.n_components < 1 {
391      panic!(
392        "n_components must be >= 1, got {}",
393        self.config.n_components
394      );
395    }
396
397    // Validate array shapes
398    if knn_dists.shape() != knn_indices.shape() {
399      panic!(
400        "knn_dists and knn_indices must have the same shape, got {:?} vs {:?}",
401        knn_dists.shape(),
402        knn_indices.shape()
403      );
404    }
405
406    if knn_dists.shape()[1] != self.config.graph.n_neighbors {
407      panic!(
408        "knn_dists has {} neighbors but n_neighbors is {}",
409        knn_dists.shape()[1],
410        self.config.graph.n_neighbors
411      );
412    }
413
414    if knn_dists.shape()[0] != n_samples {
415      panic!(
416        "knn_dists has {} samples but data has {} samples",
417        knn_dists.shape()[0],
418        n_samples
419      );
420    }
421
422    // Validate dataset size
423    if n_samples <= self.config.graph.n_neighbors {
424      panic!(
425        "Number of samples ({}) must be > n_neighbors ({})",
426        n_samples, self.config.graph.n_neighbors
427      );
428    }
429  }
430}
431
/// A fitted UMAP model containing the learned manifold and embedding.
///
/// This is a lightweight struct that holds only the final results, without
/// the heavy optimization state (epoch counters, preprocessed arrays, etc.).
///
/// The manifold can be serialized and reused for future work like transform().
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedUmap {
  // Final embedding coordinates, shape (n_samples, n_components).
  pub(crate) embedding: Array2<f32>,
  // Graph and local-geometry structure produced by `learn_manifold`.
  pub(crate) manifold: LearnedManifold,
  // Configuration that was used to produce this fit.
  pub(crate) config: UmapConfig,
}
444
445impl FittedUmap {
446  /// Get a view of the computed embedding.
447  ///
448  /// Returns a zero-copy view of the embedding coordinates. Each row
449  /// represents one input sample in the low-dimensional space.
450  ///
451  /// # Returns
452  ///
453  /// An array view of shape (n_samples, n_components) containing the
454  /// embedded coordinates.
455  ///
456  /// # Example
457  ///
458  /// ```ignore
459  /// let embedding = fitted.embedding();
460  /// println!("Embedding shape: {:?}", embedding.shape());
461  /// ```
462  pub fn embedding(&self) -> ArrayView2<'_, f32> {
463    self.embedding.view()
464  }
465
466  /// Consume the model and return the embedding, avoiding a copy.
467  ///
468  /// This method takes ownership of the model and returns the embedding
469  /// array directly, which is useful if you don't need the model anymore.
470  ///
471  /// # Returns
472  ///
473  /// The embedding array of shape (n_samples, n_components).
474  ///
475  /// # Example
476  ///
477  /// ```ignore
478  /// let embedding = fitted.into_embedding();
479  /// // fitted is now consumed
480  /// ```
481  pub fn into_embedding(self) -> Array2<f32> {
482    self.embedding
483  }
484
485  /// Get a reference to the learned manifold.
486  pub fn manifold(&self) -> &LearnedManifold {
487    &self.manifold
488  }
489
490  /// Get a reference to the configuration used for this fit.
491  pub fn config(&self) -> &UmapConfig {
492    &self.config
493  }
494
495  /// Transform new data points into the embedding space.
496  ///
497  /// **Status: Not yet implemented**
498  ///
499  /// This method will project new data points into the learned embedding space
500  /// using the manifold structure learned during fitting.
501  ///
502  /// # Arguments
503  ///
504  /// * `new_data` - New data points to transform (n_new_samples × n_features)
505  /// * `new_knn_indices` - KNN indices of new points to training points
506  /// * `new_knn_dists` - KNN distances of new points to training points
507  ///
508  /// # Returns
509  ///
510  /// Embeddings for the new data points (n_new_samples × n_components)
511  ///
512  /// # Panics
513  ///
514  /// Currently panics with "not yet implemented" message.
515  #[allow(unused_variables)]
516  pub fn transform(
517    &self,
518    new_data: ArrayView2<f32>,
519    new_knn_indices: ArrayView2<u32>,
520    new_knn_dists: ArrayView2<f32>,
521  ) -> Array2<f32> {
522    todo!("Transform not yet implemented - contributions welcome!")
523  }
524}