1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
22use scirs2_core::numeric::{Float, NumCast};
23use scirs2_core::random::Rng;
24use scirs2_core::validation::{check_positive, checkshape};
25use scirs2_linalg::eigh;
26use std::collections::BinaryHeap;
27
28use crate::error::{Result, TransformError};
29
/// Uniform Manifold Approximation and Projection (UMAP) for nonlinear
/// dimensionality reduction.
///
/// Holds the hyperparameters, the training data and fuzzy graph retained for
/// out-of-sample transforms, and the learned low-dimensional embedding.
#[derive(Debug, Clone)]
pub struct UMAP {
    /// Number of nearest neighbors used to build the fuzzy simplicial graph.
    n_neighbors: usize,
    /// Dimensionality of the output embedding.
    n_components: usize,
    /// Minimum distance between embedded points (controls cluster tightness).
    min_dist: f64,
    /// Effective scale of embedded points; combined with `min_dist` to fit the
    /// `(a, b)` curve parameters.
    spread: f64,
    /// Initial SGD learning rate (decays linearly over the epochs).
    learning_rate: f64,
    /// Number of optimization epochs.
    n_epochs: usize,
    /// Optional RNG seed. NOTE(review): stored but never consumed — the
    /// internal RNGs are created unseeded, so runs are not reproducible;
    /// confirm intended wiring.
    random_state: Option<u64>,
    /// f64 copy of the training data, kept for out-of-sample transforms.
    training_data: Option<Array2<f64>>,
    /// Symmetrized fuzzy k-NN graph from the last fit.
    training_graph: Option<Array2<f64>>,
    /// Distance metric name: "euclidean" (default), "manhattan", or "cosine";
    /// unrecognized names fall back to euclidean.
    metric: String,
    /// Learned embedding from the last fit.
    embedding: Option<Array2<f64>>,
    /// Number of negative samples drawn per positive edge during optimization.
    negative_sample_rate: usize,
    /// Whether to attempt spectral initialization before random fallback.
    spectral_init: bool,
    /// Curve parameter `a` of the low-dimensional similarity 1 / (1 + a d^(2b)).
    a: f64,
    /// Curve parameter `b` of the low-dimensional similarity 1 / (1 + a d^(2b)).
    b: f64,
    /// Index (1-based) of the neighbor assumed fully connected; clamped >= 1.
    local_connectivity: f64,
    /// Blend between fuzzy union (1.0) and fuzzy intersection (0.0) when
    /// symmetrizing the directed graph.
    set_op_mix_ratio: f64,
}
82
83impl UMAP {
    /// Create a UMAP model with the core hyperparameters.
    ///
    /// * `n_neighbors` — neighborhood size for the k-NN graph
    /// * `n_components` — output dimensionality
    /// * `min_dist` — minimum separation of points in the embedding
    /// * `learning_rate` — initial SGD step size
    /// * `n_epochs` — number of optimization epochs
    ///
    /// Remaining options take defaults (euclidean metric, spectral init on,
    /// `negative_sample_rate = 5`, `spread = 1.0`) and can be adjusted via the
    /// `with_*` builder methods.
    pub fn new(
        n_neighbors: usize,
        n_components: usize,
        min_dist: f64,
        learning_rate: f64,
        n_epochs: usize,
    ) -> Self {
        // The (a, b) similarity-curve parameters are derived from spread/min_dist.
        let spread = 1.0;
        let (a, b) = Self::find_ab_params(spread, min_dist);

        UMAP {
            n_neighbors,
            n_components,
            min_dist,
            spread,
            learning_rate,
            n_epochs,
            random_state: None,
            metric: "euclidean".to_string(),
            embedding: None,
            training_data: None,
            training_graph: None,
            negative_sample_rate: 5,
            spectral_init: true,
            a,
            b,
            local_connectivity: 1.0,
            set_op_mix_ratio: 1.0,
        }
    }
122
    /// Set the RNG seed. NOTE(review): the seed is stored but the internal
    /// RNGs are currently created unseeded — confirm intended wiring.
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set the distance metric ("euclidean", "manhattan", or "cosine";
    /// unknown names fall back to euclidean).
    pub fn with_metric(mut self, metric: &str) -> Self {
        self.metric = metric.to_string();
        self
    }

    /// Set how many negative samples are drawn per positive edge.
    pub fn with_negative_sample_rate(mut self, rate: usize) -> Self {
        self.negative_sample_rate = rate;
        self
    }

    /// Enable or disable spectral initialization of the embedding.
    pub fn with_spectral_init(mut self, use_spectral: bool) -> Self {
        self.spectral_init = use_spectral;
        self
    }

    /// Set local connectivity (clamped to at least 1.0).
    pub fn with_local_connectivity(mut self, local_connectivity: f64) -> Self {
        self.local_connectivity = local_connectivity.max(1.0);
        self
    }

    /// Set the union/intersection mix for graph symmetrization (clamped to [0, 1]).
    pub fn with_set_op_mix_ratio(mut self, ratio: f64) -> Self {
        self.set_op_mix_ratio = ratio.clamp(0.0, 1.0);
        self
    }

    /// Set the spread and refit the `(a, b)` curve parameters accordingly.
    pub fn with_spread(mut self, spread: f64) -> Self {
        self.spread = spread;
        let (a, b) = Self::find_ab_params(spread, self.min_dist);
        self.a = a;
        self.b = b;
        self
    }
167
168 fn find_ab_params(spread: f64, min_dist: f64) -> (f64, f64) {
174 if min_dist <= 0.0 || spread <= 0.0 {
175 return (1.0, 1.0);
176 }
177
178 let mut a = 1.0;
179 let mut b = 1.0;
180
181 if min_dist < spread {
188 b = min_dist.ln().abs() / (1.0 - min_dist).ln().abs().max(1e-10);
189 b = b.clamp(0.1, 10.0);
190 }
191
192 for _ in 0..100 {
194 let mut residual_a = 0.0;
195 let mut residual_b = 0.0;
196 let mut jacobian_aa = 0.0;
197 let mut jacobian_bb = 0.0;
198
199 let n_samples = 50;
200 for k in 0..n_samples {
201 let d = min_dist + (3.0 * spread) * (k as f64 / n_samples as f64);
202 if d < 1e-10 {
203 continue;
204 }
205
206 let target = if d <= min_dist {
207 1.0
208 } else {
209 (-(d - min_dist) / spread).exp()
210 };
211
212 let d2b = d.powf(2.0 * b);
213 let denom = 1.0 + a * d2b;
214 let model = 1.0 / denom;
215 let diff = model - target;
216
217 let da = -d2b / (denom * denom);
219 let db = -2.0 * a * d2b * d.ln() / (denom * denom);
221
222 residual_a += diff * da;
223 residual_b += diff * db;
224 jacobian_aa += da * da;
225 jacobian_bb += db * db;
226 }
227
228 if jacobian_aa.abs() > 1e-15 {
229 a -= 0.5 * residual_a / jacobian_aa;
230 }
231 if jacobian_bb.abs() > 1e-15 {
232 b -= 0.5 * residual_b / jacobian_bb;
233 }
234
235 a = a.max(0.001);
236 b = b.max(0.001);
237
238 if residual_a.abs() < 1e-8 && residual_b.abs() < 1e-8 {
239 break;
240 }
241 }
242
243 (a, b)
244 }
245
246 fn compute_distances<S>(&self, x: &ArrayBase<S, Ix2>) -> Array2<f64>
248 where
249 S: Data,
250 S::Elem: Float + NumCast,
251 {
252 let n_samples = x.shape()[0];
253 let n_features = x.shape()[1];
254 let mut distances = Array2::zeros((n_samples, n_samples));
255
256 for i in 0..n_samples {
257 for j in i + 1..n_samples {
258 let dist = match self.metric.as_str() {
259 "manhattan" => {
260 let mut d = 0.0;
261 for k in 0..n_features {
262 let diff = NumCast::from(x[[i, k]]).unwrap_or(0.0)
263 - NumCast::from(x[[j, k]]).unwrap_or(0.0);
264 d += diff.abs();
265 }
266 d
267 }
268 "cosine" => {
269 let mut dot = 0.0;
270 let mut norm_i = 0.0;
271 let mut norm_j = 0.0;
272 for k in 0..n_features {
273 let vi: f64 = NumCast::from(x[[i, k]]).unwrap_or(0.0);
274 let vj: f64 = NumCast::from(x[[j, k]]).unwrap_or(0.0);
275 dot += vi * vj;
276 norm_i += vi * vi;
277 norm_j += vj * vj;
278 }
279 let denom = (norm_i * norm_j).sqrt();
280 if denom > 1e-10 {
281 1.0 - (dot / denom).clamp(-1.0, 1.0)
282 } else {
283 1.0
284 }
285 }
286 _ => {
287 let mut d = 0.0;
289 for k in 0..n_features {
290 let diff = NumCast::from(x[[i, k]]).unwrap_or(0.0)
291 - NumCast::from(x[[j, k]]).unwrap_or(0.0);
292 d += diff * diff;
293 }
294 d.sqrt()
295 }
296 };
297
298 distances[[i, j]] = dist;
299 distances[[j, i]] = dist;
300 }
301 }
302
303 distances
304 }
305
306 fn find_neighbors(&self, distances: &Array2<f64>) -> (Array2<usize>, Array2<f64>) {
308 let n_samples = distances.shape()[0];
309 let k = self.n_neighbors;
310
311 let mut indices = Array2::zeros((n_samples, k));
312 let mut neighbor_distances = Array2::zeros((n_samples, k));
313
314 for i in 0..n_samples {
315 let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();
316
317 for j in 0..n_samples {
318 if i != j {
319 let dist_fixed = (distances[[i, j]] * 1e9) as i64;
320 heap.push((std::cmp::Reverse(dist_fixed), j));
321 }
322 }
323
324 for j in 0..k {
325 if let Some((std::cmp::Reverse(dist_fixed), idx)) = heap.pop() {
326 indices[[i, j]] = idx;
327 neighbor_distances[[i, j]] = dist_fixed as f64 / 1e9;
328 }
329 }
330 }
331
332 (indices, neighbor_distances)
333 }
334
    /// Build the fuzzy simplicial set (weighted graph) from the k-NN results.
    ///
    /// For each sample, the distance to its `local_connectivity`-th neighbor
    /// becomes the offset `rho`, and a per-sample bandwidth `sigma` is found by
    /// bisection so memberships sum to log2(k). The directed memberships are
    /// then symmetrized according to `set_op_mix_ratio`.
    fn compute_graph(
        &self,
        knn_indices: &Array2<usize>,
        knn_distances: &Array2<f64>,
    ) -> Array2<f64> {
        let n_samples = knn_indices.shape()[0];
        let k = self.n_neighbors;
        let mut graph = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            // rho_i: distance to the `local_connectivity`-th nearest neighbor,
            // clamped into the available neighbor range (0..k).
            let local_idx = (self.local_connectivity as usize)
                .saturating_sub(1)
                .min(k - 1);
            let rho = knn_distances[[i, local_idx]];

            // Bisection for sigma_i so that sum_j exp(-(d_ij - rho)/sigma)
            // equals log2(k) (smooth-kNN bandwidth calibration).
            let target = (k as f64).ln() / (2.0f64).ln();
            let mut sigma_lo = 0.0;
            let mut sigma_hi = f64::INFINITY;
            let mut sigma = 1.0;

            for _ in 0..64 {
                let mut membership_sum = 0.0;
                for j in 0..k {
                    // Distances inside rho contribute full membership (d = 0).
                    let d = (knn_distances[[i, j]] - rho).max(0.0);
                    if sigma > 1e-15 {
                        membership_sum += (-d / sigma).exp();
                    }
                }

                if (membership_sum - target).abs() < 1e-5 {
                    break;
                }

                if membership_sum > target {
                    // Too much total membership: shrink sigma.
                    sigma_hi = sigma;
                    sigma = (sigma_lo + sigma_hi) / 2.0;
                } else {
                    // Too little: grow sigma, doubling until an upper bound exists.
                    sigma_lo = sigma;
                    if sigma_hi == f64::INFINITY {
                        sigma *= 2.0;
                    } else {
                        sigma = (sigma_lo + sigma_hi) / 2.0;
                    }
                }
            }

            // Directed membership strengths for row i.
            for j in 0..k {
                let neighbor_idx = knn_indices[[i, j]];
                let d = (knn_distances[[i, j]] - rho).max(0.0);
                let strength = if sigma > 1e-15 {
                    (-d / sigma).exp()
                } else if d < 1e-15 {
                    // Degenerate sigma: full membership only at zero distance.
                    1.0
                } else {
                    0.0
                };
                graph[[i, neighbor_idx]] = strength;
            }
        }

        // Symmetrize: blend fuzzy union (probabilistic t-conorm) and fuzzy
        // intersection (element-wise product) by `set_op_mix_ratio`.
        let graph_t = graph.t().to_owned();

        if (self.set_op_mix_ratio - 1.0).abs() < 1e-10 {
            // Pure fuzzy union: A + A^T - A .* A^T.
            &graph + &graph_t - &graph * &graph_t
        } else if self.set_op_mix_ratio.abs() < 1e-10 {
            // Pure fuzzy intersection.
            &graph * &graph_t
        } else {
            let union = &graph + &graph_t - &graph * &graph_t;
            let intersection = &graph * &graph_t;
            &intersection * (1.0 - self.set_op_mix_ratio) + &union * self.set_op_mix_ratio
        }
    }
422
    /// Build the initial embedding: spectral layout from the graph when
    /// enabled and the problem is large enough, otherwise uniform random in
    /// [-5.0, 5.0).
    fn initialize_embedding(&self, n_samples: usize, graph: &Array2<f64>) -> Result<Array2<f64>> {
        // Spectral init needs at least n_components + 1 eigenvectors beyond
        // the trivial one.
        if self.spectral_init && n_samples > self.n_components + 1 {
            match self.spectral_init_from_graph(n_samples, graph) {
                Ok(embedding) => return Ok(embedding),
                Err(_) => {
                    // Deliberate fall-through to random initialization
                    // (e.g. isolated nodes or eigendecomposition failure).
                }
            }
        }

        // NOTE(review): `self.random_state` is not consulted here, so the
        // random fallback is not reproducible even when a seed was set —
        // confirm whether the scirs2 RNG should be seeded from it.
        let mut rng = scirs2_core::random::rng();
        let mut embedding = Array2::zeros((n_samples, self.n_components));
        for i in 0..n_samples {
            for j in 0..self.n_components {
                // Uniform in [-5.0, 5.0).
                embedding[[i, j]] = rng.random_range(0.0..1.0) * 10.0 - 5.0;
            }
        }

        Ok(embedding)
    }
446
    /// Spectral initialization: embed with the lowest non-trivial eigenvectors
    /// of the symmetric normalized graph Laplacian.
    ///
    /// # Errors
    /// Fails when the graph has an isolated node (zero degree) or the
    /// eigendecomposition fails; the caller falls back to random init.
    fn spectral_init_from_graph(
        &self,
        n_samples: usize,
        graph: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        // Node degrees (row sums of the symmetrized graph).
        let mut degree = Array1::zeros(n_samples);
        for i in 0..n_samples {
            degree[i] = graph.row(i).sum();
        }

        // Normalization below divides by sqrt(degree), so isolated nodes are fatal.
        for i in 0..n_samples {
            if degree[i] < 1e-10 {
                return Err(TransformError::ComputationError(
                    "Graph has isolated nodes, cannot use spectral initialization".to_string(),
                ));
            }
        }

        // Symmetric normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}.
        let mut laplacian = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            for j in 0..n_samples {
                if i == j {
                    laplacian[[i, j]] = 1.0;
                } else {
                    let norm_weight = graph[[i, j]] / (degree[i] * degree[j]).sqrt();
                    laplacian[[i, j]] = -norm_weight;
                }
            }
        }

        let (eigenvalues, eigenvectors) =
            eigh(&laplacian.view(), None).map_err(|e| TransformError::LinalgError(e))?;

        // Order eigenpairs by ascending eigenvalue.
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.sort_by(|&a, &b| {
            eigenvalues[a]
                .partial_cmp(&eigenvalues[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Skip the trivial first eigenvector (near-constant, eigenvalue ~0) and
        // take the next `n_components`, scaled up for the SGD optimizer.
        let mut embedding = Array2::zeros((n_samples, self.n_components));
        for j in 0..self.n_components {
            let idx = indices[j + 1];
            let scale = 10.0;
            for i in 0..n_samples {
                embedding[[i, j]] = eigenvectors[[i, idx]] * scale;
            }
        }

        Ok(embedding)
    }
504
    /// Optimize the embedding with negative-sampling SGD over the graph edges.
    ///
    /// Attractive updates pull edge endpoints together with gradients from the
    /// `(a, b)` similarity curve; repulsive updates push the head of each edge
    /// away from uniformly drawn negatives. The learning rate decays linearly
    /// to zero and per-coordinate gradients are clipped to [-4, 4].
    fn optimize_embedding(
        &self,
        embedding: &mut Array2<f64>,
        graph: &Array2<f64>,
        n_epochs: usize,
    ) {
        let n_samples = embedding.shape()[0];
        // NOTE(review): unseeded even when `random_state` is set — runs are
        // not reproducible; confirm whether this should consume the seed.
        let mut rng = scirs2_core::random::rng();

        // Collect the non-zero edges of the (dense) graph with their weights.
        let mut edges = Vec::new();
        let mut weights = Vec::new();
        for i in 0..n_samples {
            for j in 0..n_samples {
                if graph[[i, j]] > 0.0 {
                    edges.push((i, j));
                    weights.push(graph[[i, j]]);
                }
            }
        }

        let n_edges = edges.len();
        if n_edges == 0 {
            return;
        }

        // Heavier edges are sampled more often: an edge of weight w is visited
        // roughly every max_weight / w epochs (capped at n_epochs).
        let max_weight = weights.iter().cloned().fold(0.0f64, f64::max);
        let epochs_per_sample: Vec<f64> = if max_weight > 0.0 {
            weights
                .iter()
                .map(|&w| {
                    let epoch_ratio = max_weight / w.max(1e-10);
                    epoch_ratio.min(n_epochs as f64)
                })
                .collect()
        } else {
            vec![1.0; n_edges]
        };

        // Negative samples are scheduled `negative_sample_rate` times as often.
        let mut epochs_per_negative_sample: Vec<f64> = epochs_per_sample
            .iter()
            .map(|&e| e / self.negative_sample_rate as f64)
            .collect();

        let mut epoch_of_next_sample: Vec<f64> = epochs_per_sample.clone();
        let mut epoch_of_next_negative_sample: Vec<f64> = epochs_per_negative_sample.clone();

        // Per-coordinate gradient clip bound.
        let clip_val = 4.0;

        for epoch in 0..n_epochs {
            // Linearly decaying learning rate.
            let alpha = self.learning_rate * (1.0 - epoch as f64 / n_epochs as f64);

            for edge_idx in 0..n_edges {
                // Skip edges that are not yet due at this epoch.
                if epoch_of_next_sample[edge_idx] > epoch as f64 {
                    continue;
                }

                let (i, j) = edges[edge_idx];

                // Squared distance between the two endpoints, floored to avoid
                // division blow-ups for coincident points.
                let mut dist_sq = 0.0;
                for d in 0..self.n_components {
                    let diff = embedding[[i, d]] - embedding[[j, d]];
                    dist_sq += diff * diff;
                }
                dist_sq = dist_sq.max(1e-10);

                // Attractive gradient coefficient from the a/b curve.
                let grad_coeff = -2.0 * self.a * self.b * dist_sq.powf(self.b - 1.0)
                    / (1.0 + self.a * dist_sq.powf(self.b));

                // Move both endpoints symmetrically.
                for d in 0..self.n_components {
                    let grad = (grad_coeff * (embedding[[i, d]] - embedding[[j, d]]))
                        .clamp(-clip_val, clip_val);
                    embedding[[i, d]] += alpha * grad;
                    embedding[[j, d]] -= alpha * grad;
                }

                epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];

                // Repulsive updates against randomly drawn negatives;
                // only the head `i` is moved.
                let n_neg = self.negative_sample_rate;
                for _ in 0..n_neg {
                    if epoch_of_next_negative_sample[edge_idx] > epoch as f64 {
                        break;
                    }

                    let k = rng.random_range(0..n_samples);
                    if k == i {
                        continue;
                    }

                    let mut neg_dist_sq = 0.0;
                    for d in 0..self.n_components {
                        let diff = embedding[[i, d]] - embedding[[k, d]];
                        neg_dist_sq += diff * diff;
                    }
                    neg_dist_sq = neg_dist_sq.max(1e-10);

                    // Repulsive gradient coefficient; 0.001 regularizes the
                    // near-zero-distance case.
                    let grad_coeff = 2.0 * self.b
                        / ((0.001 + neg_dist_sq) * (1.0 + self.a * neg_dist_sq.powf(self.b)));

                    for d in 0..self.n_components {
                        let grad = (grad_coeff * (embedding[[i, d]] - embedding[[k, d]]))
                            .clamp(-clip_val, clip_val);
                        embedding[[i, d]] += alpha * grad;
                    }

                    epoch_of_next_negative_sample[edge_idx] += epochs_per_negative_sample[edge_idx];
                }
            }
        }
    }
624
625 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
633 where
634 S: Data,
635 S::Elem: Float + NumCast + Send + Sync,
636 {
637 let (n_samples, n_features) = x.dim();
638
639 check_positive(self.n_neighbors, "n_neighbors")?;
640 check_positive(self.n_components, "n_components")?;
641 check_positive(self.n_epochs, "n_epochs")?;
642 checkshape(x, &[n_samples, n_features], "x")?;
643
644 if n_samples < self.n_neighbors {
645 return Err(TransformError::InvalidInput(format!(
646 "n_neighbors={} must be <= n_samples={}",
647 self.n_neighbors, n_samples
648 )));
649 }
650
651 let training_data = Array2::from_shape_fn((n_samples, n_features), |(i, j)| {
653 NumCast::from(x[[i, j]]).unwrap_or(0.0)
654 });
655 self.training_data = Some(training_data);
656
657 let distances = self.compute_distances(x);
659
660 let (knn_indices, knn_distances) = self.find_neighbors(&distances);
662
663 let graph = self.compute_graph(&knn_indices, &knn_distances);
665 self.training_graph = Some(graph.clone());
666
667 let mut embedding = self.initialize_embedding(n_samples, &graph)?;
669
670 self.optimize_embedding(&mut embedding, &graph, self.n_epochs);
672
673 self.embedding = Some(embedding);
674
675 Ok(())
676 }
677
678 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
686 where
687 S: Data,
688 S::Elem: Float + NumCast,
689 {
690 if self.embedding.is_none() {
691 return Err(TransformError::NotFitted(
692 "UMAP model has not been fitted".to_string(),
693 ));
694 }
695
696 let training_data = self
697 .training_data
698 .as_ref()
699 .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
700
701 let (_, n_features) = x.dim();
702 let (_, n_training_features) = training_data.dim();
703
704 if n_features != n_training_features {
705 return Err(TransformError::InvalidInput(format!(
706 "Input features {n_features} must match training features {n_training_features}"
707 )));
708 }
709
710 if self.is_same_data(x, training_data) {
712 return self
713 .embedding
714 .as_ref()
715 .cloned()
716 .ok_or_else(|| TransformError::NotFitted("Embedding not available".to_string()));
717 }
718
719 self.transform_new_data(x)
721 }
722
723 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
725 where
726 S: Data,
727 S::Elem: Float + NumCast + Send + Sync,
728 {
729 self.fit(x)?;
730 self.transform(x)
731 }
732
    /// The fitted embedding (`None` before `fit`).
    pub fn embedding(&self) -> Option<&Array2<f64>> {
        self.embedding.as_ref()
    }

    /// The fuzzy simplicial graph from the last fit (`None` before `fit`).
    pub fn graph(&self) -> Option<&Array2<f64>> {
        self.training_graph.as_ref()
    }
742
743 fn is_same_data<S>(&self, x: &ArrayBase<S, Ix2>, training_data: &Array2<f64>) -> bool
745 where
746 S: Data,
747 S::Elem: Float + NumCast,
748 {
749 if x.dim() != training_data.dim() {
750 return false;
751 }
752
753 let (n_samples, n_features) = x.dim();
754 for i in 0..n_samples {
755 for j in 0..n_features {
756 let x_val: f64 = NumCast::from(x[[i, j]]).unwrap_or(0.0);
757 if (x_val - training_data[[i, j]]).abs() > 1e-10 {
758 return false;
759 }
760 }
761 }
762 true
763 }
764
    /// Embed previously unseen samples by inverse-distance-weighted
    /// interpolation over the embeddings of their `n_neighbors` nearest
    /// training points.
    fn transform_new_data<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let training_data = self
            .training_data
            .as_ref()
            .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
        let training_embedding = self
            .embedding
            .as_ref()
            .ok_or_else(|| TransformError::NotFitted("Embedding not available".to_string()))?;

        let (n_new_samples, _) = x.dim();
        let (n_training_samples, _) = training_data.dim();

        let mut new_embedding = Array2::zeros((n_new_samples, self.n_components));

        for i in 0..n_new_samples {
            // Euclidean distance from the new sample to every training sample.
            // NOTE(review): always euclidean here, regardless of `self.metric`
            // — confirm whether out-of-sample transform should honor the metric.
            let mut distances: Vec<(f64, usize)> = Vec::with_capacity(n_training_samples);
            for j in 0..n_training_samples {
                let mut dist_sq = 0.0;
                for k in 0..x.ncols() {
                    let x_val: f64 = NumCast::from(x[[i, k]]).unwrap_or(0.0);
                    let train_val = training_data[[j, k]];
                    let diff = x_val - train_val;
                    dist_sq += diff * diff;
                }
                distances.push((dist_sq.sqrt(), j));
            }

            // Take the k nearest training samples.
            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
            let k = self.n_neighbors.min(n_training_samples);

            // Inverse-distance-weighted average of their embedding coordinates.
            let mut total_weight = 0.0;
            let mut weighted_coords = vec![0.0; self.n_components];

            for &(dist, train_idx) in distances.iter().take(k) {
                let weight = if dist > 1e-10 {
                    1.0 / (dist + 1e-10)
                } else {
                    // Effectively a duplicate of a training point: dominate
                    // the average with a very large weight.
                    1e10
                };
                total_weight += weight;

                for dim in 0..self.n_components {
                    weighted_coords[dim] += weight * training_embedding[[train_idx, dim]];
                }
            }

            if total_weight > 0.0 {
                for dim in 0..self.n_components {
                    new_embedding[[i, dim]] = weighted_coords[dim] / total_weight;
                }
            }
        }

        Ok(new_embedding)
    }
829}
830
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    // End-to-end fit_transform on three well-separated 3-D clusters:
    // checks output shape and that every coordinate is finite.
    #[test]
    fn test_umap_basic() {
        let x = Array::from_shape_vec(
            (10, 3),
            vec![
                1.0, 2.0, 3.0, 1.1, 2.1, 3.1, 1.2, 2.2, 3.2, 5.0, 6.0, 7.0, 5.1, 6.1, 7.1, 5.2,
                6.2, 7.2, 9.0, 10.0, 11.0, 9.1, 10.1, 11.1, 9.2, 10.2, 11.2, 9.3, 10.3, 11.3,
            ],
        )
        .expect("Failed to create test array");

        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50);
        let embedding = umap.fit_transform(&x).expect("UMAP fit_transform failed");

        assert_eq!(embedding.shape(), &[10, 2]);
        for val in embedding.iter() {
            assert!(val.is_finite());
        }
    }

    // Builder methods compose and the output has the requested n_components.
    #[test]
    fn test_umap_parameters() {
        let x: Array2<f64> = Array::eye(5);

        let mut umap = UMAP::new(2, 3, 0.5, 0.5, 100)
            .with_random_state(42)
            .with_metric("euclidean");

        let embedding = umap.fit_transform(&x).expect("UMAP fit_transform failed");
        assert_eq!(embedding.shape(), &[5, 3]);
    }

    // Spectral initialization path on two square clusters.
    #[test]
    fn test_umap_spectral_init() {
        let x = Array::from_shape_vec(
            (8, 2),
            vec![
                0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
            ],
        )
        .expect("Failed to create test array");

        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50).with_spectral_init(true);
        let embedding = umap.fit_transform(&x).expect("UMAP fit_transform failed");

        assert_eq!(embedding.shape(), &[8, 2]);
        for val in embedding.iter() {
            assert!(val.is_finite());
        }
    }

    // Random initialization path (spectral disabled) on the same data.
    #[test]
    fn test_umap_random_init() {
        let x = Array::from_shape_vec(
            (8, 2),
            vec![
                0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
            ],
        )
        .expect("Failed to create test array");

        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50).with_spectral_init(false);
        let embedding = umap.fit_transform(&x).expect("UMAP fit_transform failed");

        assert_eq!(embedding.shape(), &[8, 2]);
        for val in embedding.iter() {
            assert!(val.is_finite());
        }
    }

    // Elevated negative sampling rate still produces a finite embedding.
    #[test]
    fn test_umap_negative_sampling() {
        let x = Array::from_shape_vec(
            (8, 2),
            vec![
                0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
            ],
        )
        .expect("Failed to create test array");

        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50).with_negative_sample_rate(10);
        let embedding = umap.fit_transform(&x).expect("UMAP fit_transform failed");

        assert_eq!(embedding.shape(), &[8, 2]);
        for val in embedding.iter() {
            assert!(val.is_finite());
        }
    }

    // Out-of-sample path: transform data not seen during fit.
    #[test]
    fn test_umap_out_of_sample() {
        let x_train = Array::from_shape_vec(
            (10, 3),
            vec![
                1.0, 2.0, 3.0, 1.1, 2.1, 3.1, 1.2, 2.2, 3.2, 5.0, 6.0, 7.0, 5.1, 6.1, 7.1, 5.2,
                6.2, 7.2, 9.0, 10.0, 11.0, 9.1, 10.1, 11.1, 9.2, 10.2, 11.2, 9.3, 10.3, 11.3,
            ],
        )
        .expect("Failed to create test array");

        let mut umap = UMAP::new(3, 2, 0.1, 1.0, 50);
        umap.fit(&x_train).expect("UMAP fit failed");

        let x_test = Array::from_shape_vec((2, 3), vec![1.05, 2.05, 3.05, 9.05, 10.05, 11.05])
            .expect("Failed to create test array");

        let test_embedding = umap.transform(&x_test).expect("UMAP transform failed");
        assert_eq!(test_embedding.shape(), &[2, 2]);
        for val in test_embedding.iter() {
            assert!(val.is_finite());
        }
    }

    // Curve fitting returns positive parameters and the model evaluates to 1
    // at distance zero.
    #[test]
    fn test_umap_find_ab_params() {
        let (a, b) = UMAP::find_ab_params(1.0, 0.1);
        assert!(a > 0.0);
        assert!(b > 0.0);

        let val_at_zero = 1.0 / (1.0 + a * 0.0f64.powf(2.0 * b));
        assert!((val_at_zero - 1.0).abs() < 1e-5);
    }
}