umap_rs/embedding.rs
1use crate::config::UmapConfig;
2use crate::distances::EuclideanMetric;
3use crate::manifold::LearnedManifold;
4use crate::metric::Metric;
5use crate::optimizer::Optimizer;
6use crate::umap::find_ab_params::find_ab_params;
7use crate::umap::fuzzy_simplicial_set::FuzzySimplicialSet;
8use crate::umap::raise_disconnected_warning::raise_disconnected_warning;
9use dashmap::DashSet;
10use ndarray::Array2;
11use ndarray::ArrayView2;
12use rayon::iter::IntoParallelIterator;
13use rayon::iter::ParallelIterator;
14use serde::Deserialize;
15use serde::Serialize;
16use std::time::Instant;
17use tracing::info;
18
19/// UMAP dimensionality reduction algorithm.
20///
21/// This struct holds the configuration and metrics for UMAP. It can be reused
22/// to learn manifolds from multiple datasets with the same parameters.
23///
24/// # Example
25///
26/// ```ignore
27/// use umap::{Umap, UmapConfig};
28/// use ndarray::Array2;
29///
30/// let config = UmapConfig::default();
31/// let umap = Umap::new(config);
32///
33/// // Learn the manifold structure
34/// let manifold = umap.learn_manifold(
35/// data.view(),
36/// knn_indices.view(),
37/// knn_dists.view(),
38/// );
39///
40/// // Create an optimizer and run training
41/// let mut opt = Optimizer::new(
42/// manifold,
43/// init,
44/// 500, // total epochs
45/// config.optimization.repulsion_strength,
46/// config.optimization.learning_rate,
47/// config.optimization.negative_sample_rate,
48/// &euclidean_metric,
49/// );
50///
51/// while opt.remaining_epochs() > 0 {
52/// opt.step_epochs(opt.remaining_epochs().min(10));
53/// }
54///
55/// let fitted = opt.into_fitted(config);
56/// let embedding = fitted.embedding();
57/// ```
58pub struct Umap {
59 config: UmapConfig,
60 metric: Box<dyn Metric>,
61 output_metric: Box<dyn Metric>,
62}
63
64impl Umap {
65 /// Create a new UMAP instance with default Euclidean metrics.
66 ///
67 /// Both the input space metric (for graph construction) and output space
68 /// metric (for optimization) are set to Euclidean distance.
69 ///
70 /// # Arguments
71 ///
72 /// * `config` - UMAP configuration parameters
73 pub fn new(config: UmapConfig) -> Self {
74 Self {
75 config,
76 metric: Box::new(EuclideanMetric),
77 output_metric: Box::new(EuclideanMetric),
78 }
79 }
80
81 /// Create a UMAP instance with custom distance metrics.
82 ///
83 /// # Arguments
84 ///
85 /// * `config` - UMAP configuration parameters
86 /// * `metric` - Distance metric for input space (graph construction)
87 /// * `output_metric` - Distance metric for output embedding space (optimization)
88 ///
89 /// # Example
90 ///
91 /// ```ignore
92 /// let umap = Umap::with_metrics(
93 /// config,
94 /// Box::new(MyCustomMetric),
95 /// Box::new(EuclideanMetric),
96 /// );
97 /// ```
98 pub fn with_metrics(
99 config: UmapConfig,
100 metric: Box<dyn Metric>,
101 output_metric: Box<dyn Metric>,
102 ) -> Self {
103 Self {
104 config,
105 metric,
106 output_metric,
107 }
108 }
109
110 /// Learn the manifold structure from high-dimensional data.
111 ///
112 /// This is the expensive graph construction phase that builds a fuzzy
113 /// topological representation of the data. The result can be cached,
114 /// serialized, and reused for multiple different optimizations.
115 ///
116 /// This phase is deterministic (no randomness) and independent of the
117 /// target embedding dimensionality.
118 ///
119 /// # Arguments
120 ///
121 /// * `data` - Input data matrix (n_samples × n_features). Used for validation.
122 /// * `knn_indices` - Precomputed k-nearest neighbor indices (n_samples × n_neighbors).
123 /// Each row contains indices of the k nearest neighbors for that sample.
124 /// * `knn_dists` - Precomputed k-nearest neighbor distances (n_samples × n_neighbors).
125 /// Each row contains distances to the k nearest neighbors.
126 ///
127 /// # Returns
128 ///
129 /// A `LearnedManifold` containing the fuzzy simplicial set and local geometry.
130 ///
131 /// # Panics
132 ///
133 /// Panics if:
134 /// - Parameter validation fails (invalid ranges, incompatible sizes)
135 /// - Array shapes are incompatible
136 /// - Number of samples <= n_neighbors
137 ///
138 /// # Example
139 ///
140 /// ```ignore
141 /// let manifold = umap.learn_manifold(
142 /// data.view(),
143 /// knn_indices.view(),
144 /// knn_dists.view(),
145 /// );
146 /// // Save for later use
147 /// save_manifold(&manifold)?;
148 /// ```
149 pub fn learn_manifold(
150 &self,
151 data: ArrayView2<f32>,
152 knn_indices: ArrayView2<u32>,
153 knn_dists: ArrayView2<f32>,
154 ) -> LearnedManifold {
155 let n_samples = data.shape()[0];
156
157 // Validate parameters
158 self.validate_parameters(n_samples, &knn_indices, &knn_dists);
159
160 // Determine a and b parameters
161 let (a, b) =
162 if let (Some(a_val), Some(b_val)) = (self.config.manifold.a, self.config.manifold.b) {
163 (a_val, b_val)
164 } else {
165 find_ab_params(self.config.manifold.spread, self.config.manifold.min_dist)
166 };
167
168 // Determine disconnection distance
169 let disconnection_distance = self
170 .config
171 .graph
172 .disconnection_distance
173 .unwrap_or_else(|| self.metric.disconnection_threshold());
174
175 // Find and mark disconnected edges
176 let started = Instant::now();
177 let knn_disconnections = DashSet::new();
178 (0..n_samples).into_par_iter().for_each(|row_no| {
179 let row = knn_dists.row(row_no);
180 for (col_no, &dist) in row.iter().enumerate() {
181 if dist >= disconnection_distance {
182 knn_disconnections.insert((row_no, col_no));
183 }
184 }
185 });
186 let edges_removed = knn_disconnections.len();
187 info!(
188 duration_ms = started.elapsed().as_millis(),
189 edges_removed, "disconnection detection complete"
190 );
191
192 // Build fuzzy simplicial set (the graph)
193 info!(
194 n_samples,
195 n_neighbors = self.config.graph.n_neighbors,
196 "starting fuzzy simplicial set"
197 );
198 let started = Instant::now();
199 let (graph, sigmas, rhos) = FuzzySimplicialSet::builder()
200 .n_samples(n_samples)
201 .n_neighbors(self.config.graph.n_neighbors)
202 .knn_indices(knn_indices)
203 .knn_dists(knn_dists)
204 .knn_disconnections(&knn_disconnections)
205 .local_connectivity(self.config.graph.local_connectivity)
206 .set_op_mix_ratio(self.config.graph.set_op_mix_ratio)
207 .apply_set_operations(self.config.graph.symmetrize)
208 .build()
209 .exec();
210 info!(
211 duration_ms = started.elapsed().as_millis(),
212 "fuzzy simplicial set complete"
213 );
214
215 // Check for disconnected vertices
216 let vertices_disconnected = graph
217 .outer_iterator()
218 .filter(|row| {
219 let sum: f32 = row.data().iter().sum();
220 sum == 0.0
221 })
222 .count();
223
224 raise_disconnected_warning(
225 edges_removed,
226 vertices_disconnected,
227 disconnection_distance,
228 n_samples,
229 0.1,
230 );
231
232 LearnedManifold {
233 graph,
234 sigmas,
235 rhos,
236 n_vertices: n_samples,
237 a,
238 b,
239 }
240 }
241
242 /// High-level convenience method that learns and optimizes in one call.
243 ///
244 /// This is equivalent to:
245 /// 1. `learn_manifold()` - build the graph
246 /// 2. `Optimizer::new()` - set up optimization
247 /// 3. Run all epochs
248 /// 4. `into_fitted()` - extract final model
249 ///
250 /// For checkpointing or more control, use the lower-level API instead.
251 ///
252 /// # Arguments
253 ///
254 /// * `data` - Input data matrix (n_samples × n_features)
255 /// * `knn_indices` - Precomputed k-nearest neighbor indices
256 /// * `knn_dists` - Precomputed k-nearest neighbor distances
257 /// * `init` - Initial embedding coordinates (n_samples × n_components)
258 ///
259 /// # Returns
260 ///
261 /// A `FittedUmap` containing the optimized embedding and learned manifold.
262 ///
263 /// # Example
264 ///
265 /// ```ignore
266 /// let fitted = umap.fit(
267 /// data.view(),
268 /// knn_indices.view(),
269 /// knn_dists.view(),
270 /// init.view(),
271 /// );
272 /// let embedding = fitted.embedding();
273 /// ```
274 pub fn fit(
275 &self,
276 data: ArrayView2<f32>,
277 knn_indices: ArrayView2<u32>,
278 knn_dists: ArrayView2<f32>,
279 init: ArrayView2<f32>,
280 ) -> FittedUmap {
281 let n_samples = data.shape()[0];
282
283 // Validate init array
284 if init.shape()[1] != self.config.n_components {
285 panic!(
286 "init has {} components but n_components is {}",
287 init.shape()[1],
288 self.config.n_components
289 );
290 }
291
292 if init.shape()[0] != n_samples {
293 panic!(
294 "init has {} samples but data has {} samples",
295 init.shape()[0],
296 n_samples
297 );
298 }
299
300 // Learn the manifold
301 let manifold = self.learn_manifold(data, knn_indices, knn_dists);
302
303 // Determine total epochs
304 let total_epochs = self
305 .config
306 .optimization
307 .n_epochs
308 .unwrap_or_else(|| if n_samples <= 10000 { 500 } else { 200 });
309
310 // Create optimizer
311 let metric_type = self.output_metric.metric_type();
312 let mut optimizer = Optimizer::new(
313 manifold,
314 init.to_owned(),
315 total_epochs,
316 &self.config,
317 metric_type,
318 );
319
320 // Run all epochs
321 optimizer.step_epochs(total_epochs, self.output_metric.as_ref());
322
323 // Extract final model
324 let mut fitted = optimizer.into_fitted(self.config.clone());
325
326 // Set disconnected vertices to NaN
327 for (i, row) in fitted.manifold.graph.outer_iterator().enumerate() {
328 let sum: f32 = row.data().iter().sum();
329 if sum == 0.0 {
330 for j in 0..fitted.embedding.shape()[1] {
331 fitted.embedding[(i, j)] = f32::NAN;
332 }
333 }
334 }
335
336 fitted
337 }
338
339 fn validate_parameters(
340 &self,
341 n_samples: usize,
342 knn_indices: &ArrayView2<u32>,
343 knn_dists: &ArrayView2<f32>,
344 ) {
345 // Validate graph parameters
346 if self.config.graph.set_op_mix_ratio < 0.0 || self.config.graph.set_op_mix_ratio > 1.0 {
347 panic!(
348 "set_op_mix_ratio must be between 0.0 and 1.0, got {}",
349 self.config.graph.set_op_mix_ratio
350 );
351 }
352
353 if self.config.graph.n_neighbors < 2 {
354 panic!(
355 "n_neighbors must be >= 2, got {}",
356 self.config.graph.n_neighbors
357 );
358 }
359
360 // Validate optimization parameters
361 if self.config.optimization.repulsion_strength < 0.0 {
362 panic!(
363 "repulsion_strength cannot be negative, got {}",
364 self.config.optimization.repulsion_strength
365 );
366 }
367
368 if self.config.manifold.min_dist > self.config.manifold.spread {
369 panic!(
370 "min_dist ({}) must be <= spread ({})",
371 self.config.manifold.min_dist, self.config.manifold.spread
372 );
373 }
374
375 if self.config.manifold.min_dist < 0.0 {
376 panic!(
377 "min_dist cannot be negative, got {}",
378 self.config.manifold.min_dist
379 );
380 }
381
382 // Validate optimization parameters
383 if self.config.optimization.learning_rate < 0.0 {
384 panic!(
385 "learning_rate must be positive, got {}",
386 self.config.optimization.learning_rate
387 );
388 }
389
390 if self.config.n_components < 1 {
391 panic!(
392 "n_components must be >= 1, got {}",
393 self.config.n_components
394 );
395 }
396
397 // Validate array shapes
398 if knn_dists.shape() != knn_indices.shape() {
399 panic!(
400 "knn_dists and knn_indices must have the same shape, got {:?} vs {:?}",
401 knn_dists.shape(),
402 knn_indices.shape()
403 );
404 }
405
406 if knn_dists.shape()[1] != self.config.graph.n_neighbors {
407 panic!(
408 "knn_dists has {} neighbors but n_neighbors is {}",
409 knn_dists.shape()[1],
410 self.config.graph.n_neighbors
411 );
412 }
413
414 if knn_dists.shape()[0] != n_samples {
415 panic!(
416 "knn_dists has {} samples but data has {} samples",
417 knn_dists.shape()[0],
418 n_samples
419 );
420 }
421
422 // Validate dataset size
423 if n_samples <= self.config.graph.n_neighbors {
424 panic!(
425 "Number of samples ({}) must be > n_neighbors ({})",
426 n_samples, self.config.graph.n_neighbors
427 );
428 }
429 }
430}
431
432/// A fitted UMAP model containing the learned manifold and embedding.
433///
434/// This is a lightweight struct that holds only the final results, without
435/// the heavy optimization state (epoch counters, preprocessed arrays, etc.).
436///
437/// The manifold can be serialized and reused for future work like transform().
438#[derive(Debug, Clone, Serialize, Deserialize)]
439pub struct FittedUmap {
440 pub(crate) embedding: Array2<f32>,
441 pub(crate) manifold: LearnedManifold,
442 pub(crate) config: UmapConfig,
443}
444
445impl FittedUmap {
446 /// Get a view of the computed embedding.
447 ///
448 /// Returns a zero-copy view of the embedding coordinates. Each row
449 /// represents one input sample in the low-dimensional space.
450 ///
451 /// # Returns
452 ///
453 /// An array view of shape (n_samples, n_components) containing the
454 /// embedded coordinates.
455 ///
456 /// # Example
457 ///
458 /// ```ignore
459 /// let embedding = fitted.embedding();
460 /// println!("Embedding shape: {:?}", embedding.shape());
461 /// ```
462 pub fn embedding(&self) -> ArrayView2<'_, f32> {
463 self.embedding.view()
464 }
465
466 /// Consume the model and return the embedding, avoiding a copy.
467 ///
468 /// This method takes ownership of the model and returns the embedding
469 /// array directly, which is useful if you don't need the model anymore.
470 ///
471 /// # Returns
472 ///
473 /// The embedding array of shape (n_samples, n_components).
474 ///
475 /// # Example
476 ///
477 /// ```ignore
478 /// let embedding = fitted.into_embedding();
479 /// // fitted is now consumed
480 /// ```
481 pub fn into_embedding(self) -> Array2<f32> {
482 self.embedding
483 }
484
485 /// Get a reference to the learned manifold.
486 pub fn manifold(&self) -> &LearnedManifold {
487 &self.manifold
488 }
489
490 /// Get a reference to the configuration used for this fit.
491 pub fn config(&self) -> &UmapConfig {
492 &self.config
493 }
494
495 /// Transform new data points into the embedding space.
496 ///
497 /// **Status: Not yet implemented**
498 ///
499 /// This method will project new data points into the learned embedding space
500 /// using the manifold structure learned during fitting.
501 ///
502 /// # Arguments
503 ///
504 /// * `new_data` - New data points to transform (n_new_samples × n_features)
505 /// * `new_knn_indices` - KNN indices of new points to training points
506 /// * `new_knn_dists` - KNN distances of new points to training points
507 ///
508 /// # Returns
509 ///
510 /// Embeddings for the new data points (n_new_samples × n_components)
511 ///
512 /// # Panics
513 ///
514 /// Currently panics with "not yet implemented" message.
515 #[allow(unused_variables)]
516 pub fn transform(
517 &self,
518 new_data: ArrayView2<f32>,
519 new_knn_indices: ArrayView2<u32>,
520 new_knn_dists: ArrayView2<f32>,
521 ) -> Array2<f32> {
522 todo!("Transform not yet implemented - contributions welcome!")
523 }
524}