1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
23use scirs2_core::numeric::{Float, NumCast};
24use scirs2_core::validation::{check_positive, checkshape};
25use scirs2_linalg::eigh;
26use std::collections::BinaryHeap;
27use std::f64;
28
29use crate::error::{Result, TransformError};
30
/// Strategy used to turn the neighborhood graph into all-pairs shortest-path
/// (geodesic) distances.
///
/// The enum is a plain tag without payload, so it is `Copy` and `Eq` in
/// addition to `Clone`/`PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShortestPathAlgorithm {
    /// Run a heap-based Dijkstra search from every vertex.
    Dijkstra,
    /// Run the triple-loop Floyd-Warshall relaxation over the full matrix.
    FloydWarshall,
}
39
40#[derive(Debug, Clone)]
58pub struct Isomap {
59 n_neighbors: usize,
61 n_components: usize,
63 neighbor_mode: String,
65 epsilon: f64,
67 path_algorithm: ShortestPathAlgorithm,
69 embedding: Option<Array2<f64>>,
71 training_data: Option<Array2<f64>>,
73 geodesic_distances: Option<Array2<f64>>,
75 residual_variance: Option<f64>,
77}
78
79impl Isomap {
80 pub fn new(n_neighbors: usize, n_components: usize) -> Self {
86 Isomap {
87 n_neighbors,
88 n_components,
89 neighbor_mode: "knn".to_string(),
90 epsilon: 0.0,
91 path_algorithm: ShortestPathAlgorithm::Dijkstra,
92 embedding: None,
93 training_data: None,
94 geodesic_distances: None,
95 residual_variance: None,
96 }
97 }
98
99 pub fn with_epsilon(mut self, epsilon: f64) -> Self {
101 self.neighbor_mode = "epsilon".to_string();
102 self.epsilon = epsilon;
103 self
104 }
105
106 pub fn with_path_algorithm(mut self, algorithm: ShortestPathAlgorithm) -> Self {
108 self.path_algorithm = algorithm;
109 self
110 }
111
112 fn compute_distances<S>(&self, x: &ArrayBase<S, Ix2>) -> Array2<f64>
114 where
115 S: Data,
116 S::Elem: Float + NumCast,
117 {
118 let n_samples = x.shape()[0];
119 let mut distances = Array2::zeros((n_samples, n_samples));
120
121 for i in 0..n_samples {
122 for j in i + 1..n_samples {
123 let mut dist = 0.0;
124 for k in 0..x.shape()[1] {
125 let diff: f64 = NumCast::from(x[[i, k]]).unwrap_or(0.0)
126 - NumCast::from(x[[j, k]]).unwrap_or(0.0);
127 dist += diff * diff;
128 }
129 dist = dist.sqrt();
130 distances[[i, j]] = dist;
131 distances[[j, i]] = dist;
132 }
133 }
134
135 distances
136 }
137
138 fn construct_graph(&self, distances: &Array2<f64>) -> Array2<f64> {
140 let n_samples = distances.shape()[0];
141 let mut graph = Array2::from_elem((n_samples, n_samples), f64::INFINITY);
142
143 for i in 0..n_samples {
145 graph[[i, i]] = 0.0;
146 }
147
148 if self.neighbor_mode == "knn" {
149 for i in 0..n_samples {
150 let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();
152
153 for j in 0..n_samples {
154 if i != j {
155 let dist_fixed = (distances[[i, j]] * 1e9) as i64;
156 heap.push((std::cmp::Reverse(dist_fixed), j));
157 }
158 }
159
160 for _ in 0..self.n_neighbors {
161 if let Some((_, j)) = heap.pop() {
162 graph[[i, j]] = distances[[i, j]];
163 graph[[j, i]] = distances[[j, i]]; }
165 }
166 }
167 } else {
168 for i in 0..n_samples {
170 for j in i + 1..n_samples {
171 if distances[[i, j]] <= self.epsilon {
172 graph[[i, j]] = distances[[i, j]];
173 graph[[j, i]] = distances[[j, i]];
174 }
175 }
176 }
177 }
178
179 graph
180 }
181
182 fn compute_shortest_paths_dijkstra(&self, graph: &Array2<f64>) -> Result<Array2<f64>> {
187 let n = graph.shape()[0];
188 let mut dist = Array2::from_elem((n, n), f64::INFINITY);
189
190 for i in 0..n {
192 dist[[i, i]] = 0.0;
193 }
194
195 let mut adjacency: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
197 for i in 0..n {
198 for j in 0..n {
199 if i != j && graph[[i, j]] < f64::INFINITY {
200 adjacency[i].push((j, graph[[i, j]]));
201 }
202 }
203 }
204
205 for source in 0..n {
207 let mut heap: BinaryHeap<std::cmp::Reverse<(i64, usize)>> = BinaryHeap::new();
209 let mut visited = vec![false; n];
210
211 dist[[source, source]] = 0.0;
212 heap.push(std::cmp::Reverse((0, source)));
213
214 while let Some(std::cmp::Reverse((d_fixed, u))) = heap.pop() {
215 if visited[u] {
216 continue;
217 }
218 visited[u] = true;
219
220 let d_u = d_fixed as f64 / 1e9;
221
222 for &(v, weight) in &adjacency[u] {
223 let new_dist = d_u + weight;
224 if new_dist < dist[[source, v]] {
225 dist[[source, v]] = new_dist;
226 let d_fixed_new = (new_dist * 1e9) as i64;
227 heap.push(std::cmp::Reverse((d_fixed_new, v)));
228 }
229 }
230 }
231 }
232
233 for i in 0..n {
235 for j in 0..n {
236 if dist[[i, j]].is_infinite() {
237 return Err(TransformError::InvalidInput(
238 "Graph is not connected. Try increasing n_neighbors or epsilon."
239 .to_string(),
240 ));
241 }
242 }
243 }
244
245 Ok(dist)
246 }
247
248 fn compute_shortest_paths_floyd_warshall(&self, graph: &Array2<f64>) -> Result<Array2<f64>> {
250 let n = graph.shape()[0];
251 let mut dist = graph.clone();
252
253 for k in 0..n {
254 for i in 0..n {
255 for j in 0..n {
256 if dist[[i, k]] + dist[[k, j]] < dist[[i, j]] {
257 dist[[i, j]] = dist[[i, k]] + dist[[k, j]];
258 }
259 }
260 }
261 }
262
263 for i in 0..n {
265 for j in 0..n {
266 if dist[[i, j]].is_infinite() {
267 return Err(TransformError::InvalidInput(
268 "Graph is not connected. Try increasing n_neighbors or epsilon."
269 .to_string(),
270 ));
271 }
272 }
273 }
274
275 Ok(dist)
276 }
277
278 fn compute_shortest_paths(&self, graph: &Array2<f64>) -> Result<Array2<f64>> {
280 match self.path_algorithm {
281 ShortestPathAlgorithm::Dijkstra => self.compute_shortest_paths_dijkstra(graph),
282 ShortestPathAlgorithm::FloydWarshall => {
283 self.compute_shortest_paths_floyd_warshall(graph)
284 }
285 }
286 }
287
288 fn classical_mds(&self, distances: &Array2<f64>) -> Result<Array2<f64>> {
290 let n = distances.shape()[0];
291
292 let squared_distances = distances.mapv(|d| d * d);
294
295 let row_means = squared_distances.mean_axis(Axis(1)).ok_or_else(|| {
296 TransformError::ComputationError("Failed to compute row means".to_string())
297 })?;
298
299 let col_means = squared_distances.mean_axis(Axis(0)).ok_or_else(|| {
300 TransformError::ComputationError("Failed to compute column means".to_string())
301 })?;
302
303 let grand_mean = row_means.mean().ok_or_else(|| {
304 TransformError::ComputationError("Failed to compute grand mean".to_string())
305 })?;
306
307 let mut gram = Array2::zeros((n, n));
309 for i in 0..n {
310 for j in 0..n {
311 gram[[i, j]] =
312 -0.5 * (squared_distances[[i, j]] - row_means[i] - col_means[j] + grand_mean);
313 }
314 }
315
316 let gram_symmetric = 0.5 * (&gram + &gram.t());
318
319 let (eigenvalues, eigenvectors) =
321 eigh(&gram_symmetric.view(), None).map_err(|e| TransformError::LinalgError(e))?;
322
323 let mut indices: Vec<usize> = (0..n).collect();
325 indices.sort_by(|&i, &j| {
326 eigenvalues[j]
327 .partial_cmp(&eigenvalues[i])
328 .unwrap_or(std::cmp::Ordering::Equal)
329 });
330
331 let mut embedding = Array2::zeros((n, self.n_components));
333 for j in 0..self.n_components {
334 let idx = indices[j];
335 let scale = eigenvalues[idx].max(0.0).sqrt();
336
337 for i in 0..n {
338 embedding[[i, j]] = eigenvectors[[i, idx]] * scale;
339 }
340 }
341
342 Ok(embedding)
343 }
344
345 fn compute_residual_variance(
347 &self,
348 geodesic_distances: &Array2<f64>,
349 embedding: &Array2<f64>,
350 ) -> f64 {
351 let n = embedding.shape()[0];
352
353 let mut embedding_distances = Array2::zeros((n, n));
355 for i in 0..n {
356 for j in i + 1..n {
357 let mut dist_sq = 0.0;
358 for k in 0..embedding.shape()[1] {
359 let diff = embedding[[i, k]] - embedding[[j, k]];
360 dist_sq += diff * diff;
361 }
362 let dist = dist_sq.sqrt();
363 embedding_distances[[i, j]] = dist;
364 embedding_distances[[j, i]] = dist;
365 }
366 }
367
368 let mut sum_geodesic = 0.0;
370 let mut sum_embedding = 0.0;
371 let mut sum_geo_sq = 0.0;
372 let mut sum_emb_sq = 0.0;
373 let mut sum_product = 0.0;
374 let mut count = 0.0;
375
376 for i in 0..n {
377 for j in i + 1..n {
378 let g = geodesic_distances[[i, j]];
379 let e = embedding_distances[[i, j]];
380 sum_geodesic += g;
381 sum_embedding += e;
382 sum_geo_sq += g * g;
383 sum_emb_sq += e * e;
384 sum_product += g * e;
385 count += 1.0;
386 }
387 }
388
389 if count == 0.0 {
390 return 1.0;
391 }
392
393 let mean_geo = sum_geodesic / count;
394 let mean_emb = sum_embedding / count;
395
396 let var_geo = sum_geo_sq / count - mean_geo * mean_geo;
397 let var_emb = sum_emb_sq / count - mean_emb * mean_emb;
398 let cov = sum_product / count - mean_geo * mean_emb;
399
400 let denom = (var_geo * var_emb).sqrt();
401 if denom > 1e-10 {
402 let r = (cov / denom).clamp(-1.0, 1.0);
403 (1.0 - r * r).max(0.0)
406 } else {
407 1.0
408 }
409 }
410
411 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
413 where
414 S: Data,
415 S::Elem: Float + NumCast,
416 {
417 let (n_samples, n_features) = x.dim();
418
419 check_positive(self.n_neighbors, "n_neighbors")?;
420 check_positive(self.n_components, "n_components")?;
421 checkshape(x, &[n_samples, n_features], "x")?;
422
423 if n_samples < self.n_neighbors {
424 return Err(TransformError::InvalidInput(format!(
425 "n_neighbors={} must be <= n_samples={}",
426 self.n_neighbors, n_samples
427 )));
428 }
429
430 if self.n_components >= n_samples {
431 return Err(TransformError::InvalidInput(format!(
432 "n_components={} must be < n_samples={}",
433 self.n_components, n_samples
434 )));
435 }
436
437 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
438
439 let distances = self.compute_distances(&x_f64.view());
441
442 let graph = self.construct_graph(&distances);
444
445 let geodesic_distances = self.compute_shortest_paths(&graph)?;
447
448 let embedding = self.classical_mds(&geodesic_distances)?;
450
451 let residual_var = self.compute_residual_variance(&geodesic_distances, &embedding);
453
454 self.embedding = Some(embedding);
455 self.training_data = Some(x_f64);
456 self.geodesic_distances = Some(geodesic_distances);
457 self.residual_variance = Some(residual_var);
458
459 Ok(())
460 }
461
462 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
464 where
465 S: Data,
466 S::Elem: Float + NumCast,
467 {
468 if self.embedding.is_none() {
469 return Err(TransformError::NotFitted(
470 "Isomap model has not been fitted".to_string(),
471 ));
472 }
473
474 let training_data = self
475 .training_data
476 .as_ref()
477 .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
478
479 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
480
481 if self.is_same_data(&x_f64, training_data) {
482 return self
483 .embedding
484 .as_ref()
485 .cloned()
486 .ok_or_else(|| TransformError::NotFitted("Embedding not available".to_string()));
487 }
488
489 self.landmark_mds(&x_f64)
490 }
491
492 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
494 where
495 S: Data,
496 S::Elem: Float + NumCast,
497 {
498 self.fit(x)?;
499 self.transform(x)
500 }
501
502 pub fn embedding(&self) -> Option<&Array2<f64>> {
504 self.embedding.as_ref()
505 }
506
507 pub fn geodesic_distances(&self) -> Option<&Array2<f64>> {
509 self.geodesic_distances.as_ref()
510 }
511
512 pub fn residual_variance(&self) -> Option<f64> {
514 self.residual_variance
515 }
516
517 fn is_same_data(&self, x: &Array2<f64>, training_data: &Array2<f64>) -> bool {
519 if x.dim() != training_data.dim() {
520 return false;
521 }
522 let (n_samples, n_features) = x.dim();
523 for i in 0..n_samples {
524 for j in 0..n_features {
525 if (x[[i, j]] - training_data[[i, j]]).abs() > 1e-10 {
526 return false;
527 }
528 }
529 }
530 true
531 }
532
533 fn landmark_mds(&self, x_new: &Array2<f64>) -> Result<Array2<f64>> {
535 let training_data = self
536 .training_data
537 .as_ref()
538 .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
539 let training_embedding = self
540 .embedding
541 .as_ref()
542 .ok_or_else(|| TransformError::NotFitted("Embedding not available".to_string()))?;
543
544 let (n_new, n_features) = x_new.dim();
545 let (n_training, _) = training_data.dim();
546
547 if n_features != training_data.ncols() {
548 return Err(TransformError::InvalidInput(format!(
549 "Input features {} must match training features {}",
550 n_features,
551 training_data.ncols()
552 )));
553 }
554
555 let mut distances_to_training = Array2::zeros((n_new, n_training));
557 for i in 0..n_new {
558 for j in 0..n_training {
559 let mut dist_sq = 0.0;
560 for k in 0..n_features {
561 let diff = x_new[[i, k]] - training_data[[j, k]];
562 dist_sq += diff * diff;
563 }
564 distances_to_training[[i, j]] = dist_sq.sqrt();
565 }
566 }
567
568 let mut new_embedding = Array2::zeros((n_new, self.n_components));
570
571 for i in 0..n_new {
572 let mut landmark_dists: Vec<(f64, usize)> = (0..n_training)
574 .map(|j| (distances_to_training[[i, j]], j))
575 .collect();
576 landmark_dists
577 .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
578
579 let k_landmarks = (n_training / 2).max(self.n_components + 1).min(n_training);
580
581 let mut total_weight = 0.0;
582 let mut weighted_coords = vec![0.0; self.n_components];
583
584 for &(dist, landmark_idx) in landmark_dists.iter().take(k_landmarks) {
585 let weight = if dist > 1e-10 {
586 1.0 / (dist * dist + 1e-10)
587 } else {
588 1e10
589 };
590 total_weight += weight;
591
592 for dim in 0..self.n_components {
593 weighted_coords[dim] += weight * training_embedding[[landmark_idx, dim]];
594 }
595 }
596
597 if total_weight > 0.0 {
598 for dim in 0..self.n_components {
599 new_embedding[[i, dim]] = weighted_coords[dim] / total_weight;
600 }
601 }
602 }
603
604 Ok(new_embedding)
605 }
606}
607
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds an `n_points x 3` array whose i-th row is `point(i)`.
    fn samples_from<F>(n_points: usize, point: F) -> Array2<f64>
    where
        F: Fn(usize) -> [f64; 3],
    {
        let mut data = Vec::new();
        for i in 0..n_points {
            data.extend_from_slice(&point(i));
        }
        Array::from_shape_vec((n_points, 3), data).expect("Failed to create array")
    }

    /// One turn of a circle in x/y with a linear drift in z.
    fn rising_circle(n_points: usize) -> Array2<f64> {
        samples_from(n_points, |i| {
            let t = i as f64 / n_points as f64 * 2.0 * std::f64::consts::PI;
            [t.cos(), t.sin(), i as f64 * 0.1]
        })
    }

    /// Points on the straight line `(t, 2t, 3t)` for `t` in `[0, 1)`.
    fn straight_line(n_points: usize) -> Array2<f64> {
        samples_from(n_points, |i| {
            let t = i as f64 / n_points as f64;
            [t, t * 2.0, t * 3.0]
        })
    }

    /// Checks the embedding shape and that every entry is finite.
    fn assert_finite_with_shape(embedding: &Array2<f64>, rows: usize, cols: usize) {
        assert_eq!(embedding.shape(), &[rows, cols]);
        for val in embedding.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_isomap_basic() {
        let n_points = 20;
        let x = samples_from(n_points, |i| {
            let t = i as f64 / n_points as f64 * 3.0 * std::f64::consts::PI;
            [t.sin(), 2.0 * (i as f64 / n_points as f64), t.cos()]
        });

        let mut isomap = Isomap::new(5, 2);
        let embedding = isomap
            .fit_transform(&x)
            .expect("Isomap fit_transform failed");

        assert_finite_with_shape(&embedding, n_points, 2);
    }

    #[test]
    fn test_isomap_dijkstra() {
        let n_points = 15;
        let x = rising_circle(n_points);

        let mut isomap = Isomap::new(4, 2).with_path_algorithm(ShortestPathAlgorithm::Dijkstra);
        let embedding = isomap
            .fit_transform(&x)
            .expect("Isomap fit_transform failed");

        assert_finite_with_shape(&embedding, n_points, 2);
    }

    #[test]
    fn test_isomap_floyd_warshall() {
        let n_points = 15;
        let x = rising_circle(n_points);

        let mut isomap =
            Isomap::new(4, 2).with_path_algorithm(ShortestPathAlgorithm::FloydWarshall);
        let embedding = isomap
            .fit_transform(&x)
            .expect("Isomap fit_transform failed");

        assert_finite_with_shape(&embedding, n_points, 2);
    }

    #[test]
    fn test_isomap_epsilon_ball() {
        let x: Array2<f64> = Array::eye(5);

        let mut isomap = Isomap::new(3, 2).with_epsilon(1.5);
        let result = isomap.fit_transform(&x);

        assert!(result.is_ok());
        let embedding = result.expect("Isomap fit_transform failed");
        assert_eq!(embedding.shape(), &[5, 2]);
    }

    #[test]
    fn test_isomap_disconnected_graph() {
        // Two tight pairs far apart: one neighbor each cannot bridge them.
        let x = scirs2_core::ndarray::array![[0.0, 0.0], [0.1, 0.1], [10.0, 10.0], [10.1, 10.1],];

        let mut isomap = Isomap::new(1, 2);
        let result = isomap.fit(&x);

        assert!(result.is_err());
        match result {
            Err(TransformError::InvalidInput(msg)) => {
                assert!(msg.contains("Graph is not connected"));
            }
            Err(_) => panic!("Expected InvalidInput error for disconnected graph"),
            Ok(()) => {}
        }
    }

    #[test]
    fn test_isomap_residual_variance() {
        let x = straight_line(20);

        let mut isomap = Isomap::new(5, 2);
        let _ = isomap
            .fit_transform(&x)
            .expect("Isomap fit_transform failed");

        let rv = isomap.residual_variance();
        assert!(rv.is_some());
        let rv_val = rv.expect("Residual variance should exist");
        assert!(rv_val >= 0.0);
        assert!(rv_val <= 1.0);
    }

    #[test]
    fn test_isomap_out_of_sample() {
        let x = straight_line(20);

        let mut isomap = Isomap::new(5, 2);
        isomap.fit(&x).expect("Isomap fit failed");

        let x_new = Array::from_shape_vec((2, 3), vec![0.25, 0.5, 0.75, 0.75, 1.5, 2.25])
            .expect("Failed to create test array");

        let new_embedding = isomap.transform(&x_new).expect("Isomap transform failed");
        assert_finite_with_shape(&new_embedding, 2, 2);
    }
}