1use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
7use scirs2_core::numeric::{Float, NumCast};
8use scirs2_core::random::Rng;
9use scirs2_core::validation::{check_positive, checkshape};
10use std::collections::BinaryHeap;
11
12use crate::error::{Result, TransformError};
13
/// Uniform Manifold Approximation and Projection (UMAP) model for non-linear
/// dimensionality reduction.
///
/// Construct with [`UMAP::new`], optionally configure with `with_random_state` /
/// `with_metric`, then call `fit` or `fit_transform`. After fitting, `transform` maps
/// data into the learned low-dimensional space.
#[derive(Debug, Clone)]
pub struct UMAP {
    /// Number of nearest neighbours (the sample itself excluded) used to build the graph.
    n_neighbors: usize,
    /// Dimensionality of the output embedding (number of columns).
    n_components: usize,
    // Only consumed in `new` (to derive `a`/`b`); kept for inspection via `Debug`,
    // which the dead_code lint does not count as a read.
    /// Minimum-distance parameter used to derive the curve parameters `a` and `b`.
    #[allow(dead_code)]
    mindist: f64,
    // Same situation as `mindist`: only used at construction time.
    /// Spread of the target curve used to derive `a` and `b` (fixed to 1.0 by `new`).
    #[allow(dead_code)]
    spread: f64,
    /// Initial step size of the optimiser; decayed linearly to zero over the epochs.
    learning_rate: f64,
    /// Number of optimisation epochs run by `fit`.
    n_epochs: usize,
    /// Optional RNG seed supplied through `with_random_state`.
    random_state: Option<u64>,
    /// `f64` copy of the data passed to the last successful `fit`.
    training_data: Option<Array2<f64>>,
    /// Symmetrised fuzzy membership graph (`nsamples x nsamples`) built during `fit`.
    training_graph: Option<Array2<f64>>,
    /// Name of the distance metric; defaults to `"euclidean"`. Distances in this file
    /// are computed as Euclidean distances.
    metric: String,
    /// Learned low-dimensional coordinates, one row per training sample.
    embedding: Option<Array2<f64>>,
    /// Curve parameter `a` of the low-dimensional similarity `1 / (1 + a * d^(2b))`.
    a: f64,
    /// Curve parameter `b` of the low-dimensional similarity `1 / (1 + a * d^(2b))`.
    b: f64,
}
48
49impl UMAP {
50 pub fn new(
59 n_neighbors: usize,
60 n_components: usize,
61 mindist: f64,
62 learning_rate: f64,
63 n_epochs: usize,
64 ) -> Self {
65 let spread = 1.0;
67 let (a, b) = Self::find_ab_params(spread, mindist);
68
69 UMAP {
70 n_neighbors,
71 n_components,
72 mindist,
73 spread,
74 learning_rate,
75 n_epochs,
76 random_state: None,
77 metric: "euclidean".to_string(),
78 embedding: None,
79 training_data: None,
80 training_graph: None,
81 a,
82 b,
83 }
84 }
85
86 pub fn with_random_state(mut self, seed: u64) -> Self {
88 self.random_state = Some(seed);
89 self
90 }
91
92 pub fn with_metric(mut self, metric: &str) -> Self {
94 self.metric = metric.to_string();
95 self
96 }
97
98 fn find_ab_params(_spread: f64, mindist: f64) -> (f64, f64) {
100 let mut a = 1.0;
102 let mut b = 1.0;
103
104 if mindist > 0.0 {
106 b = mindist.ln() / (1.0 - mindist).ln();
107 }
108
109 for _ in 0..64 {
111 let val = 1.0 / (1.0 + a * mindist.powf(2.0 * b));
112 let grad_a = -mindist.powf(2.0 * b) / (1.0 + a * mindist.powf(2.0 * b)).powi(2);
113 let grad_b = -2.0 * a * mindist.powf(2.0 * b) * mindist.ln()
114 / (1.0 + a * mindist.powf(2.0 * b)).powi(2);
115
116 if (val - 0.5).abs() < 1e-5 {
117 break;
118 }
119
120 a -= (val - 0.5) / grad_a;
121 b -= (val - 0.5) / grad_b;
122
123 a = a.max(0.001);
124 b = b.max(0.001);
125 }
126
127 (a, b)
128 }
129
130 fn compute_distances<S>(&self, x: &ArrayBase<S, Ix2>) -> Array2<f64>
132 where
133 S: Data,
134 S::Elem: Float + NumCast,
135 {
136 let nsamples = x.shape()[0];
137 let mut distances = Array2::zeros((nsamples, nsamples));
138
139 for i in 0..nsamples {
141 for j in i + 1..nsamples {
142 let mut dist = 0.0;
143 for k in 0..x.shape()[1] {
144 let diff = NumCast::from(x[[i, k]]).unwrap_or(0.0)
145 - NumCast::from(x[[j, k]]).unwrap_or(0.0);
146 dist += diff * diff;
147 }
148 dist = dist.sqrt();
149 distances[[i, j]] = dist;
150 distances[[j, i]] = dist;
151 }
152 }
153
154 distances
155 }
156
157 fn find_neighbors(&self, distances: &Array2<f64>) -> (Array2<usize>, Array2<f64>) {
159 let nsamples = distances.shape()[0];
160 let k = self.n_neighbors;
161
162 let mut indices = Array2::zeros((nsamples, k));
163 let mut neighbor_distances = Array2::zeros((nsamples, k));
164
165 for i in 0..nsamples {
166 let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();
168
169 for j in 0..nsamples {
170 if i != j {
171 let dist_fixed = (distances[[i, j]] * 1e9) as i64;
173 heap.push((std::cmp::Reverse(dist_fixed), j));
174 }
175 }
176
177 for j in 0..k {
179 if let Some((std::cmp::Reverse(dist_fixed), idx)) = heap.pop() {
180 indices[[i, j]] = idx;
181 neighbor_distances[[i, j]] = dist_fixed as f64 / 1e9;
182 }
183 }
184 }
185
186 (indices, neighbor_distances)
187 }
188
189 fn compute_graph(
191 &self,
192 knn_indices: &Array2<usize>,
193 knn_distances: &Array2<f64>,
194 ) -> Array2<f64> {
195 let nsamples = knn_indices.shape()[0];
196 let mut graph = Array2::zeros((nsamples, nsamples));
197
198 for i in 0..nsamples {
200 let rho = knn_distances[[i, 0]];
202
203 let mut sigma = 1.0;
205 let target = self.n_neighbors as f64;
206
207 for _ in 0..64 {
208 let mut sum = 0.0;
209 for j in 1..self.n_neighbors {
210 let d = (knn_distances[[i, j]] - rho).max(0.0);
211 sum += (-d / sigma).exp();
212 }
213
214 if (sum - target).abs() < 1e-5 {
215 break;
216 }
217
218 if sum > target {
219 sigma *= 2.0;
220 } else {
221 sigma /= 2.0;
222 }
223 }
224
225 for j in 0..self.n_neighbors {
227 let neighbor_idx = knn_indices[[i, j]];
228 let d = (knn_distances[[i, j]] - rho).max(0.0);
229 let strength = (-d / sigma).exp();
230 graph[[i, neighbor_idx]] = strength;
231 }
232 }
233
234 let graph_transpose = graph.t().to_owned();
236 &graph + &graph_transpose - &graph * &graph_transpose
237 }
238
239 fn initialize_embedding(&self, nsamples: usize) -> Array2<f64> {
241 let mut rng = scirs2_core::random::rng();
242
243 let mut embedding = Array2::zeros((nsamples, self.n_components));
245 for i in 0..nsamples {
246 for j in 0..self.n_components {
247 embedding[[i, j]] = rng.gen_range(0.0..1.0) * 10.0 - 5.0;
248 }
249 }
250
251 embedding
252 }
253
254 fn optimize_embedding(
256 &self,
257 embedding: &mut Array2<f64>,
258 graph: &Array2<f64>,
259 n_epochs: usize,
260 ) {
261 let nsamples = embedding.shape()[0];
262 let mut rng = scirs2_core::random::rng();
263
264 let mut edges = Vec::new();
266 let mut weights = Vec::new();
267 for i in 0..nsamples {
268 for j in 0..nsamples {
269 if graph[[i, j]] > 0.0 {
270 edges.push((i, j));
271 weights.push(graph[[i, j]]);
272 }
273 }
274 }
275
276 let n_edges = edges.len();
277
278 for epoch in 0..n_epochs {
280 let alpha = self.learning_rate * (1.0 - epoch as f64 / n_epochs as f64);
282
283 for _ in 0..n_edges {
285 let edge_idx = rng.gen_range(0..n_edges);
287 let (i, j) = edges[edge_idx];
288
289 let mut dist_sq = 0.0;
291 for d in 0..self.n_components {
292 let diff = embedding[[i, d]] - embedding[[j, d]];
293 dist_sq += diff * diff;
294 }
295 let dist = dist_sq.sqrt();
296
297 if dist > 0.0 {
299 let attraction = -2.0 * self.a * self.b * dist.powf(2.0 * self.b - 2.0)
300 / (1.0 + self.a * dist.powf(2.0 * self.b));
301
302 for d in 0..self.n_components {
303 let grad = attraction * (embedding[[i, d]] - embedding[[j, d]]) / dist;
304 embedding[[i, d]] += alpha * grad * weights[edge_idx];
305 embedding[[j, d]] -= alpha * grad * weights[edge_idx];
306 }
307 }
308
309 let k = rng.gen_range(0..nsamples);
311 if k != i && k != j {
312 let mut neg_dist_sq = 0.0;
313 for d in 0..self.n_components {
314 let diff = embedding[[i, d]] - embedding[[k, d]];
315 neg_dist_sq += diff * diff;
316 }
317 let neg_dist = neg_dist_sq.sqrt();
318
319 if neg_dist > 0.0 {
320 let repulsion = 2.0 * self.b
321 / (1.0 + self.a * neg_dist.powf(2.0 * self.b))
322 / (1.0 + neg_dist * neg_dist);
323
324 for d in 0..self.n_components {
325 let grad =
326 repulsion * (embedding[[i, d]] - embedding[[k, d]]) / neg_dist;
327 embedding[[i, d]] += alpha * grad;
328 embedding[[k, d]] -= alpha * grad;
329 }
330 }
331 }
332 }
333 }
334 }
335
336 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
344 where
345 S: Data,
346 S::Elem: Float + NumCast + Send + Sync,
347 {
348 let (nsamples, n_features) = x.dim();
349
350 check_positive(self.n_neighbors, "n_neighbors")?;
352 check_positive(self.n_components, "n_components")?;
353 check_positive(self.n_epochs, "n_epochs")?;
354 checkshape(x, &[nsamples, n_features], "x")?;
355
356 if nsamples < self.n_neighbors {
357 return Err(TransformError::InvalidInput(format!(
358 "n_neighbors={} must be <= nsamples={}",
359 self.n_neighbors, nsamples
360 )));
361 }
362
363 let training_data = Array2::from_shape_fn((nsamples, n_features), |(i, j)| {
365 NumCast::from(x[[i, j]]).unwrap_or(0.0)
366 });
367 self.training_data = Some(training_data);
368
369 let distances = self.compute_distances(x);
371
372 let (knn_indices, knn_distances) = self.find_neighbors(&distances);
374
375 let graph = self.compute_graph(&knn_indices, &knn_distances);
377 self.training_graph = Some(graph.clone());
378
379 let mut embedding = self.initialize_embedding(nsamples);
381
382 self.optimize_embedding(&mut embedding, &graph, self.n_epochs);
384
385 self.embedding = Some(embedding);
386
387 Ok(())
388 }
389
390 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
398 where
399 S: Data,
400 S::Elem: Float + NumCast,
401 {
402 if self.embedding.is_none() {
403 return Err(TransformError::NotFitted(
404 "UMAP model has not been fitted".to_string(),
405 ));
406 }
407
408 let training_data = self
409 .training_data
410 .as_ref()
411 .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
412
413 let (_n_new_samples, n_features) = x.dim();
414 let (_, n_training_features) = training_data.dim();
415
416 if n_features != n_training_features {
417 return Err(TransformError::InvalidInput(format!(
418 "Input features {n_features} must match training features {n_training_features}"
419 )));
420 }
421
422 if self.is_same_data(x, training_data) {
424 return Ok(self.embedding.as_ref().unwrap().clone());
425 }
426
427 self.transform_new_data(x)
429 }
430
431 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
439 where
440 S: Data,
441 S::Elem: Float + NumCast + Send + Sync,
442 {
443 self.fit(x)?;
444 self.transform(x)
445 }
446
447 pub fn embedding(&self) -> Option<&Array2<f64>> {
449 self.embedding.as_ref()
450 }
451
452 fn is_same_data<S>(&self, x: &ArrayBase<S, Ix2>, trainingdata: &Array2<f64>) -> bool
454 where
455 S: Data,
456 S::Elem: Float + NumCast,
457 {
458 if x.dim() != trainingdata.dim() {
459 return false;
460 }
461
462 let (nsamples, n_features) = x.dim();
463 for i in 0..nsamples {
464 for j in 0..n_features {
465 let x_val = NumCast::from(x[[i, j]]).unwrap_or(0.0);
466 if (x_val - trainingdata[[i, j]]).abs() > 1e-10 {
467 return false;
468 }
469 }
470 }
471 true
472 }
473
474 fn transform_new_data<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
476 where
477 S: Data,
478 S::Elem: Float + NumCast,
479 {
480 let training_data = self.training_data.as_ref().unwrap();
481 let training_embedding = self.embedding.as_ref().unwrap();
482
483 let (n_new_samples_, _) = x.dim();
484 let (n_training_samples_, _) = training_data.dim();
485
486 let mut new_embedding = Array2::zeros((n_new_samples_, self.n_components));
488
489 for i in 0..n_new_samples_ {
490 let mut distances = Vec::new();
492 for j in 0..n_training_samples_ {
493 let mut dist_sq = 0.0;
494 for k in 0..x.ncols() {
495 let x_val = NumCast::from(x[[i, k]]).unwrap_or(0.0);
496 let train_val = training_data[[j, k]];
497 let diff = x_val - train_val;
498 dist_sq += diff * diff;
499 }
500 distances.push((dist_sq.sqrt(), j));
501 }
502
503 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
505 let k = self.n_neighbors.min(n_training_samples_);
506
507 let mut total_weight = 0.0;
509 let mut weighted_coords = vec![0.0; self.n_components];
510
511 for (dist, train_idx) in distances.iter().take(k) {
512 let weight = if *dist > 1e-10 {
513 1.0 / (*dist + 1e-10)
514 } else {
515 1e10
516 };
517 total_weight += weight;
518
519 for dim in 0..self.n_components {
520 weighted_coords[dim] += weight * training_embedding[[*train_idx, dim]];
521 }
522 }
523
524 if total_weight > 0.0 {
526 for dim in 0..self.n_components {
527 new_embedding[[i, dim]] = weighted_coords[dim] / total_weight;
528 }
529 }
530 }
531
532 Ok(new_embedding)
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Three tight groups of 3-D points should embed into a finite 10 x 2 layout.
    #[test]
    fn test_umap_basic() {
        let values = vec![
            1.0, 2.0, 3.0, //
            1.1, 2.1, 3.1, //
            1.2, 2.2, 3.2, //
            5.0, 6.0, 7.0, //
            5.1, 6.1, 7.1, //
            5.2, 6.2, 7.2, //
            9.0, 10.0, 11.0, //
            9.1, 10.1, 11.1, //
            9.2, 10.2, 11.2, //
            9.3, 10.3, 11.3,
        ];
        let x = Array::from_shape_vec((10, 3), values).expect("10 x 3 needs exactly 30 values");

        let mut model = UMAP::new(3, 2, 0.1, 1.0, 50);
        let embedding = model
            .fit_transform(&x)
            .expect("fit_transform should succeed on valid input");

        assert_eq!(embedding.shape(), &[10, 2]);
        assert!(
            embedding.iter().all(|v| v.is_finite()),
            "embedding contains non-finite values"
        );
    }

    /// Builder options are accepted and the embedding width follows `n_components`.
    #[test]
    fn test_umap_parameters() {
        let x: Array2<f64> = Array::eye(5);

        let mut model = UMAP::new(2, 3, 0.5, 0.5, 100)
            .with_random_state(42)
            .with_metric("euclidean");
        let embedding = model
            .fit_transform(&x)
            .expect("fit_transform should succeed on the identity matrix");

        assert_eq!(embedding.shape(), &[5, 3]);
    }
}