1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9use scirs2_core::validation::{check_positive, checkshape};
10use scirs2_linalg::eigh;
11use std::collections::BinaryHeap;
12use std::f64;
13
14use crate::error::{Result, TransformError};
15#[derive(Debug, Clone)]
23pub struct Isomap {
24 n_neighbors: usize,
26 n_components: usize,
28 neighbor_mode: String,
30 epsilon: f64,
32 embedding: Option<Array2<f64>>,
34 training_data: Option<Array2<f64>>,
36 geodesic_distances: Option<Array2<f64>>,
38}
39
40impl Isomap {
41 pub fn new(n_neighbors: usize, ncomponents: usize) -> Self {
47 Isomap {
48 n_neighbors,
49 n_components: ncomponents,
50 neighbor_mode: "knn".to_string(),
51 epsilon: 0.0,
52 embedding: None,
53 training_data: None,
54 geodesic_distances: None,
55 }
56 }
57
58 pub fn with_epsilon(mut self, epsilon: f64) -> Self {
60 self.neighbor_mode = "epsilon".to_string();
61 self.epsilon = epsilon;
62 self
63 }
64
65 fn compute_distances<S>(&self, x: &ArrayBase<S, Ix2>) -> Array2<f64>
67 where
68 S: Data,
69 S::Elem: Float + NumCast,
70 {
71 let n_samples = x.shape()[0];
72 let mut distances = Array2::zeros((n_samples, n_samples));
73
74 for i in 0..n_samples {
75 for j in i + 1..n_samples {
76 let mut dist = 0.0;
77 for k in 0..x.shape()[1] {
78 let diff = NumCast::from(x[[i, k]]).unwrap_or(0.0)
79 - NumCast::from(x[[j, k]]).unwrap_or(0.0);
80 dist += diff * diff;
81 }
82 dist = dist.sqrt();
83 distances[[i, j]] = dist;
84 distances[[j, i]] = dist;
85 }
86 }
87
88 distances
89 }
90
91 fn construct_graph(&self, distances: &Array2<f64>) -> Array2<f64> {
93 let n_samples = distances.shape()[0];
94 let mut graph = Array2::from_elem((n_samples, n_samples), f64::INFINITY);
95
96 for i in 0..n_samples {
98 graph[[i, i]] = 0.0;
99 }
100
101 if self.neighbor_mode == "knn" {
102 for i in 0..n_samples {
104 let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();
106
107 for j in 0..n_samples {
108 if i != j {
109 let dist_fixed = (distances[[i, j]] * 1e9) as i64;
110 heap.push((std::cmp::Reverse(dist_fixed), j));
111 }
112 }
113
114 for _ in 0..self.n_neighbors {
116 if let Some((_, j)) = heap.pop() {
117 graph[[i, j]] = distances[[i, j]];
118 graph[[j, i]] = distances[[j, i]]; }
120 }
121 }
122 } else {
123 for i in 0..n_samples {
125 for j in i + 1..n_samples {
126 if distances[[i, j]] <= self.epsilon {
127 graph[[i, j]] = distances[[i, j]];
128 graph[[j, i]] = distances[[j, i]];
129 }
130 }
131 }
132 }
133
134 graph
135 }
136
137 fn compute_shortest_paths(&self, graph: &Array2<f64>) -> Result<Array2<f64>> {
139 let n = graph.shape()[0];
140 let mut dist = graph.clone();
141
142 for k in 0..n {
144 for i in 0..n {
145 for j in 0..n {
146 if dist[[i, k]] + dist[[k, j]] < dist[[i, j]] {
147 dist[[i, j]] = dist[[i, k]] + dist[[k, j]];
148 }
149 }
150 }
151 }
152
153 for i in 0..n {
155 for j in 0..n {
156 if dist[[i, j]].is_infinite() {
157 return Err(TransformError::InvalidInput(
158 "Graph is not connected. Try increasing n_neighbors or epsilon."
159 .to_string(),
160 ));
161 }
162 }
163 }
164
165 Ok(dist)
166 }
167
168 fn classical_mds(&self, distances: &Array2<f64>) -> Result<Array2<f64>> {
170 let n = distances.shape()[0];
171
172 let squared_distances = distances.mapv(|d| d * d);
174
175 let row_means = squared_distances.mean_axis(Axis(1)).unwrap();
177
178 let col_means = squared_distances.mean_axis(Axis(0)).unwrap();
180
181 let grand_mean = row_means.mean().unwrap();
183
184 let mut gram = Array2::zeros((n, n));
186 for i in 0..n {
187 for j in 0..n {
188 gram[[i, j]] =
189 -0.5 * (squared_distances[[i, j]] - row_means[i] - col_means[j] + grand_mean);
190 }
191 }
192
193 let gram_symmetric = 0.5 * (&gram + &gram.t());
195
196 let (eigenvalues, eigenvectors) = match eigh(&gram_symmetric.view(), None) {
198 Ok(result) => result,
199 Err(e) => return Err(TransformError::LinalgError(e)),
200 };
201
202 let mut indices: Vec<usize> = (0..n).collect();
204 indices.sort_by(|&i, &j| eigenvalues[j].partial_cmp(&eigenvalues[i]).unwrap());
205
206 let mut embedding = Array2::zeros((n, self.n_components));
208 for j in 0..self.n_components {
209 let idx = indices[j];
210 let scale = eigenvalues[idx].max(0.0).sqrt();
211
212 for i in 0..n {
213 embedding[[i, j]] = eigenvectors[[i, idx]] * scale;
214 }
215 }
216
217 Ok(embedding)
218 }
219
220 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
228 where
229 S: Data,
230 S::Elem: Float + NumCast,
231 {
232 let (n_samples, n_features) = x.dim();
233
234 check_positive(self.n_neighbors, "n_neighbors")?;
236 check_positive(self.n_components, "n_components")?;
237 checkshape(x, &[n_samples, n_features], "x")?;
238
239 if n_samples < self.n_neighbors {
240 return Err(TransformError::InvalidInput(format!(
241 "n_neighbors={} must be <= n_samples={}",
242 self.n_neighbors, n_samples
243 )));
244 }
245
246 if self.n_components >= n_samples {
247 return Err(TransformError::InvalidInput(format!(
248 "n_components={} must be < n_samples={}",
249 self.n_components, n_samples
250 )));
251 }
252
253 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
255
256 let distances = self.compute_distances(&x_f64.view());
258
259 let graph = self.construct_graph(&distances);
261
262 let geodesic_distances = self.compute_shortest_paths(&graph)?;
264
265 let embedding = self.classical_mds(&geodesic_distances)?;
267
268 self.embedding = Some(embedding);
269 self.training_data = Some(x_f64);
270 self.geodesic_distances = Some(geodesic_distances);
271
272 Ok(())
273 }
274
275 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
285 where
286 S: Data,
287 S::Elem: Float + NumCast,
288 {
289 if self.embedding.is_none() {
290 return Err(TransformError::NotFitted(
291 "Isomap model has not been fitted".to_string(),
292 ));
293 }
294
295 let training_data = self
296 .training_data
297 .as_ref()
298 .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
299
300 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
301
302 if self.is_same_data(&x_f64, training_data) {
304 return Ok(self.embedding.as_ref().unwrap().clone());
305 }
306
307 self.landmark_mds(&x_f64)
309 }
310
311 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
319 where
320 S: Data,
321 S::Elem: Float + NumCast,
322 {
323 self.fit(x)?;
324 self.transform(x)
325 }
326
327 pub fn embedding(&self) -> Option<&Array2<f64>> {
329 self.embedding.as_ref()
330 }
331
332 pub fn geodesic_distances(&self) -> Option<&Array2<f64>> {
334 self.geodesic_distances.as_ref()
335 }
336
337 fn is_same_data(&self, x: &Array2<f64>, trainingdata: &Array2<f64>) -> bool {
339 if x.dim() != trainingdata.dim() {
340 return false;
341 }
342
343 let (n_samples, n_features) = x.dim();
344 for i in 0..n_samples {
345 for j in 0..n_features {
346 if (x[[i, j]] - trainingdata[[i, j]]).abs() > 1e-10 {
347 return false;
348 }
349 }
350 }
351 true
352 }
353
354 fn landmark_mds(&self, xnew: &Array2<f64>) -> Result<Array2<f64>> {
356 let training_data = self.training_data.as_ref().unwrap();
357 let training_embedding = self.embedding.as_ref().unwrap();
358 let geodesic_distances = self.geodesic_distances.as_ref().unwrap();
359
360 let (n_new, n_features) = xnew.dim();
361 let (n_training_, _) = training_data.dim();
362
363 if n_features != training_data.ncols() {
364 return Err(TransformError::InvalidInput(format!(
365 "Input features {} must match training features {}",
366 n_features,
367 training_data.ncols()
368 )));
369 }
370
371 let mut distances_to_training = Array2::zeros((n_new, n_training_));
373 for i in 0..n_new {
374 for j in 0..n_training_ {
375 let mut dist_sq = 0.0;
376 for k in 0..n_features {
377 let diff = xnew[[i, k]] - training_data[[j, k]];
378 dist_sq += diff * diff;
379 }
380 distances_to_training[[i, j]] = dist_sq.sqrt();
381 }
382 }
383
384 let mut new_embedding = Array2::zeros((n_new, self.n_components));
388
389 for i in 0..n_new {
390 let coords = self.solve_landmark_coordinates(
392 &distances_to_training.row(i),
393 training_embedding,
394 geodesic_distances,
395 )?;
396
397 for j in 0..self.n_components {
398 new_embedding[[i, j]] = coords[j];
399 }
400 }
401
402 Ok(new_embedding)
403 }
404
405 fn solve_landmark_coordinates(
407 &self,
408 distances_to_landmarks: &scirs2_core::ndarray::ArrayView1<f64>,
409 landmark_embedding: &Array2<f64>,
410 _geodesic_distances: &Array2<f64>,
411 ) -> Result<Array1<f64>> {
412 let n_landmarks = landmark_embedding.nrows();
413
414 let k_landmarks = (n_landmarks / 2)
416 .max(self.n_components + 1)
417 .min(n_landmarks);
418
419 let mut landmark_dists: Vec<(f64, usize)> = distances_to_landmarks
421 .indexed_iter()
422 .map(|(idx, &dist)| (dist, idx))
423 .collect();
424 landmark_dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
425
426 let selected_landmarks: Vec<usize> = landmark_dists
428 .into_iter()
429 .take(k_landmarks)
430 .map(|(_, idx)| idx)
431 .collect();
432
433 let mut a = Array2::zeros((k_landmarks, self.n_components));
437 let mut b = Array1::zeros(k_landmarks);
438 let mut weights = Array1::zeros(k_landmarks);
439
440 for (row_idx, &landmark_idx) in selected_landmarks.iter().enumerate() {
442 let dist_to_landmark = distances_to_landmarks[landmark_idx];
443 let weight = if dist_to_landmark > 1e-10 {
444 1.0 / (dist_to_landmark + 1e-10)
445 } else {
446 1e10
447 };
448 weights[row_idx] = weight;
449
450 b[row_idx] = dist_to_landmark * weight;
452
453 for dim in 0..self.n_components {
455 a[[row_idx, dim]] = landmark_embedding[[landmark_idx, dim]] * weight;
456 }
457 }
458
459 let mut at_wa = Array2::zeros((self.n_components, self.n_components));
462 let mut at_wb = Array1::zeros(self.n_components);
463
464 for i in 0..self.n_components {
465 for j in 0..self.n_components {
466 for k in 0..k_landmarks {
467 at_wa[[i, j]] += a[[k, i]] * weights[k] * a[[k, j]];
468 }
469 }
470 for k in 0..k_landmarks {
471 at_wb[i] += a[[k, i]] * weights[k] * b[k];
472 }
473 }
474
475 for i in 0..self.n_components {
477 at_wa[[i, i]] += 1e-10;
478 }
479
480 self.solve_linear_system(&at_wa, &at_wb)
482 }
483
484 fn solve_linear_system(&self, a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
486 let n = a.nrows();
487 let mut a_copy = a.clone();
488 let mut b_copy = b.clone();
489
490 for i in 0..n {
492 let mut max_row = i;
494 for k in i + 1..n {
495 if a_copy[[k, i]].abs() > a_copy[[max_row, i]].abs() {
496 max_row = k;
497 }
498 }
499
500 if max_row != i {
502 for j in 0..n {
503 let temp = a_copy[[i, j]];
504 a_copy[[i, j]] = a_copy[[max_row, j]];
505 a_copy[[max_row, j]] = temp;
506 }
507 let temp = b_copy[i];
508 b_copy[i] = b_copy[max_row];
509 b_copy[max_row] = temp;
510 }
511
512 if a_copy[[i, i]].abs() < 1e-12 {
514 return Err(TransformError::ComputationError(
515 "Singular matrix in landmark MDS".to_string(),
516 ));
517 }
518
519 for k in i + 1..n {
521 let factor = a_copy[[k, i]] / a_copy[[i, i]];
522 for j in i..n {
523 a_copy[[k, j]] -= factor * a_copy[[i, j]];
524 }
525 b_copy[k] -= factor * b_copy[i];
526 }
527 }
528
529 let mut x = Array1::zeros(n);
531 for i in (0..n).rev() {
532 x[i] = b_copy[i];
533 for j in i + 1..n {
534 x[i] -= a_copy[[i, j]] * x[j];
535 }
536 x[i] /= a_copy[[i, i]];
537 }
538
539 Ok(x)
540 }
541}
542
543#[cfg(test)]
544mod tests {
545 use super::*;
546 use scirs2_core::ndarray::Array;
547
/// Embeds 20 points sampled along a 3-D helix into 2-D with a 5-NN graph
/// and checks the output shape and that every coordinate is finite.
#[test]
fn test_isomap_basic() {
    let n_points = 20;

    let data: Vec<f64> = (0..n_points)
        .flat_map(|i| {
            let frac = i as f64 / n_points as f64;
            let t = frac * 3.0 * std::f64::consts::PI;
            [t.sin(), 2.0 * frac, t.cos()]
        })
        .collect();

    let x = Array::from_shape_vec((n_points, 3), data).unwrap();

    let mut isomap = Isomap::new(5, 2);
    let embedding = isomap.fit_transform(&x).unwrap();

    assert_eq!(embedding.shape(), &[n_points, 2]);
    assert!(embedding.iter().all(|val| val.is_finite()));
}
576
/// Epsilon-ball mode on the 5x5 identity matrix: fitting with radius 1.5
/// must succeed and give a 5x2 embedding.
#[test]
fn test_isomap_epsilon_ball() {
    let points: Array2<f64> = Array::eye(5);

    let mut model = Isomap::new(3, 2).with_epsilon(1.5);
    let outcome = model.fit_transform(&points);

    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap().shape(), &[5, 2]);
}
590
/// Two far-apart pairs of points with a 1-NN graph form two components, so
/// fitting must fail with an `InvalidInput` error about connectivity.
#[test]
fn test_isomap_disconnected_graph() {
    let x = scirs2_core::ndarray::array![
        [0.0, 0.0],
        [0.1, 0.1],
        [10.0, 10.0],
        [10.1, 10.1],
    ];

    let mut isomap = Isomap::new(1, 2);
    let result = isomap.fit(&x);

    assert!(result.is_err());
    match result {
        Err(TransformError::InvalidInput(msg)) => {
            assert!(msg.contains("Graph is not connected"));
        }
        Err(_) => panic!("Expected InvalidInput error for disconnected graph"),
        // Ruled out by the `is_err` assertion above.
        Ok(()) => {}
    }
}
617}