1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9use scirs2_core::validation::{check_positive, checkshape};
10use scirs2_linalg::eigh;
11use std::collections::BinaryHeap;
12use std::f64;
13
14use crate::error::{Result, TransformError};
15
/// Isomap model: graph-construction settings plus the state stored by a successful `fit`.
#[derive(Debug, Clone)]
pub struct Isomap {
    /// Number of nearest neighbours linked to each sample when `neighbor_mode` is `"knn"`.
    n_neighbors: usize,
    /// Number of columns of the produced embedding.
    n_components: usize,
    /// Graph construction mode: `"knn"` (set by `new`) or `"epsilon"` (set by `with_epsilon`).
    neighbor_mode: String,
    /// Distance threshold used to link samples when `neighbor_mode` is not `"knn"`.
    epsilon: f64,
    /// Embedding of the training samples; `None` until `fit` succeeds.
    embedding: Option<Array2<f64>>,
    /// Training samples converted to `f64`; `None` until `fit` succeeds.
    training_data: Option<Array2<f64>>,
    /// Pairwise shortest-path distances over the neighbourhood graph; `None` until `fit` succeeds.
    geodesic_distances: Option<Array2<f64>>,
}
38
39impl Isomap {
40 pub fn new(n_neighbors: usize, ncomponents: usize) -> Self {
46 Isomap {
47 n_neighbors,
48 n_components: ncomponents,
49 neighbor_mode: "knn".to_string(),
50 epsilon: 0.0,
51 embedding: None,
52 training_data: None,
53 geodesic_distances: None,
54 }
55 }
56
57 pub fn with_epsilon(mut self, epsilon: f64) -> Self {
59 self.neighbor_mode = "epsilon".to_string();
60 self.epsilon = epsilon;
61 self
62 }
63
64 fn compute_distances<S>(&self, x: &ArrayBase<S, Ix2>) -> Array2<f64>
66 where
67 S: Data,
68 S::Elem: Float + NumCast,
69 {
70 let n_samples = x.shape()[0];
71 let mut distances = Array2::zeros((n_samples, n_samples));
72
73 for i in 0..n_samples {
74 for j in i + 1..n_samples {
75 let mut dist = 0.0;
76 for k in 0..x.shape()[1] {
77 let diff = NumCast::from(x[[i, k]]).unwrap_or(0.0)
78 - NumCast::from(x[[j, k]]).unwrap_or(0.0);
79 dist += diff * diff;
80 }
81 dist = dist.sqrt();
82 distances[[i, j]] = dist;
83 distances[[j, i]] = dist;
84 }
85 }
86
87 distances
88 }
89
90 fn construct_graph(&self, distances: &Array2<f64>) -> Array2<f64> {
92 let n_samples = distances.shape()[0];
93 let mut graph = Array2::from_elem((n_samples, n_samples), f64::INFINITY);
94
95 for i in 0..n_samples {
97 graph[[i, i]] = 0.0;
98 }
99
100 if self.neighbor_mode == "knn" {
101 for i in 0..n_samples {
103 let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();
105
106 for j in 0..n_samples {
107 if i != j {
108 let dist_fixed = (distances[[i, j]] * 1e9) as i64;
109 heap.push((std::cmp::Reverse(dist_fixed), j));
110 }
111 }
112
113 for _ in 0..self.n_neighbors {
115 if let Some((_, j)) = heap.pop() {
116 graph[[i, j]] = distances[[i, j]];
117 graph[[j, i]] = distances[[j, i]]; }
119 }
120 }
121 } else {
122 for i in 0..n_samples {
124 for j in i + 1..n_samples {
125 if distances[[i, j]] <= self.epsilon {
126 graph[[i, j]] = distances[[i, j]];
127 graph[[j, i]] = distances[[j, i]];
128 }
129 }
130 }
131 }
132
133 graph
134 }
135
136 fn compute_shortest_paths(&self, graph: &Array2<f64>) -> Result<Array2<f64>> {
138 let n = graph.shape()[0];
139 let mut dist = graph.clone();
140
141 for k in 0..n {
143 for i in 0..n {
144 for j in 0..n {
145 if dist[[i, k]] + dist[[k, j]] < dist[[i, j]] {
146 dist[[i, j]] = dist[[i, k]] + dist[[k, j]];
147 }
148 }
149 }
150 }
151
152 for i in 0..n {
154 for j in 0..n {
155 if dist[[i, j]].is_infinite() {
156 return Err(TransformError::InvalidInput(
157 "Graph is not connected. Try increasing n_neighbors or epsilon."
158 .to_string(),
159 ));
160 }
161 }
162 }
163
164 Ok(dist)
165 }
166
167 fn classical_mds(&self, distances: &Array2<f64>) -> Result<Array2<f64>> {
169 let n = distances.shape()[0];
170
171 let squared_distances = distances.mapv(|d| d * d);
173
174 let row_means = squared_distances.mean_axis(Axis(1)).unwrap();
176
177 let col_means = squared_distances.mean_axis(Axis(0)).unwrap();
179
180 let grand_mean = row_means.mean().unwrap();
182
183 let mut gram = Array2::zeros((n, n));
185 for i in 0..n {
186 for j in 0..n {
187 gram[[i, j]] =
188 -0.5 * (squared_distances[[i, j]] - row_means[i] - col_means[j] + grand_mean);
189 }
190 }
191
192 let gram_symmetric = 0.5 * (&gram + &gram.t());
194
195 let (eigenvalues, eigenvectors) = match eigh(&gram_symmetric.view(), None) {
197 Ok(result) => result,
198 Err(e) => return Err(TransformError::LinalgError(e)),
199 };
200
201 let mut indices: Vec<usize> = (0..n).collect();
203 indices.sort_by(|&i, &j| eigenvalues[j].partial_cmp(&eigenvalues[i]).unwrap());
204
205 let mut embedding = Array2::zeros((n, self.n_components));
207 for j in 0..self.n_components {
208 let idx = indices[j];
209 let scale = eigenvalues[idx].max(0.0).sqrt();
210
211 for i in 0..n {
212 embedding[[i, j]] = eigenvectors[[i, idx]] * scale;
213 }
214 }
215
216 Ok(embedding)
217 }
218
219 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
227 where
228 S: Data,
229 S::Elem: Float + NumCast,
230 {
231 let (n_samples, n_features) = x.dim();
232
233 check_positive(self.n_neighbors, "n_neighbors")?;
235 check_positive(self.n_components, "n_components")?;
236 checkshape(x, &[n_samples, n_features], "x")?;
237
238 if n_samples < self.n_neighbors {
239 return Err(TransformError::InvalidInput(format!(
240 "n_neighbors={} must be <= n_samples={}",
241 self.n_neighbors, n_samples
242 )));
243 }
244
245 if self.n_components >= n_samples {
246 return Err(TransformError::InvalidInput(format!(
247 "n_components={} must be < n_samples={}",
248 self.n_components, n_samples
249 )));
250 }
251
252 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
254
255 let distances = self.compute_distances(&x_f64.view());
257
258 let graph = self.construct_graph(&distances);
260
261 let geodesic_distances = self.compute_shortest_paths(&graph)?;
263
264 let embedding = self.classical_mds(&geodesic_distances)?;
266
267 self.embedding = Some(embedding);
268 self.training_data = Some(x_f64);
269 self.geodesic_distances = Some(geodesic_distances);
270
271 Ok(())
272 }
273
274 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
284 where
285 S: Data,
286 S::Elem: Float + NumCast,
287 {
288 if self.embedding.is_none() {
289 return Err(TransformError::NotFitted(
290 "Isomap model has not been fitted".to_string(),
291 ));
292 }
293
294 let training_data = self
295 .training_data
296 .as_ref()
297 .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
298
299 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
300
301 if self.is_same_data(&x_f64, training_data) {
303 return Ok(self.embedding.as_ref().unwrap().clone());
304 }
305
306 self.landmark_mds(&x_f64)
308 }
309
310 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
318 where
319 S: Data,
320 S::Elem: Float + NumCast,
321 {
322 self.fit(x)?;
323 self.transform(x)
324 }
325
326 pub fn embedding(&self) -> Option<&Array2<f64>> {
328 self.embedding.as_ref()
329 }
330
331 pub fn geodesic_distances(&self) -> Option<&Array2<f64>> {
333 self.geodesic_distances.as_ref()
334 }
335
336 fn is_same_data(&self, x: &Array2<f64>, trainingdata: &Array2<f64>) -> bool {
338 if x.dim() != trainingdata.dim() {
339 return false;
340 }
341
342 let (n_samples, n_features) = x.dim();
343 for i in 0..n_samples {
344 for j in 0..n_features {
345 if (x[[i, j]] - trainingdata[[i, j]]).abs() > 1e-10 {
346 return false;
347 }
348 }
349 }
350 true
351 }
352
353 fn landmark_mds(&self, xnew: &Array2<f64>) -> Result<Array2<f64>> {
355 let training_data = self.training_data.as_ref().unwrap();
356 let training_embedding = self.embedding.as_ref().unwrap();
357 let geodesic_distances = self.geodesic_distances.as_ref().unwrap();
358
359 let (n_new, n_features) = xnew.dim();
360 let (n_training_, _) = training_data.dim();
361
362 if n_features != training_data.ncols() {
363 return Err(TransformError::InvalidInput(format!(
364 "Input features {} must match training features {}",
365 n_features,
366 training_data.ncols()
367 )));
368 }
369
370 let mut distances_to_training = Array2::zeros((n_new, n_training_));
372 for i in 0..n_new {
373 for j in 0..n_training_ {
374 let mut dist_sq = 0.0;
375 for k in 0..n_features {
376 let diff = xnew[[i, k]] - training_data[[j, k]];
377 dist_sq += diff * diff;
378 }
379 distances_to_training[[i, j]] = dist_sq.sqrt();
380 }
381 }
382
383 let mut new_embedding = Array2::zeros((n_new, self.n_components));
387
388 for i in 0..n_new {
389 let coords = self.solve_landmark_coordinates(
391 &distances_to_training.row(i),
392 training_embedding,
393 geodesic_distances,
394 )?;
395
396 for j in 0..self.n_components {
397 new_embedding[[i, j]] = coords[j];
398 }
399 }
400
401 Ok(new_embedding)
402 }
403
404 fn solve_landmark_coordinates(
406 &self,
407 distances_to_landmarks: &scirs2_core::ndarray::ArrayView1<f64>,
408 landmark_embedding: &Array2<f64>,
409 _geodesic_distances: &Array2<f64>,
410 ) -> Result<Array1<f64>> {
411 let n_landmarks = landmark_embedding.nrows();
412
413 let k_landmarks = (n_landmarks / 2)
415 .max(self.n_components + 1)
416 .min(n_landmarks);
417
418 let mut landmark_dists: Vec<(f64, usize)> = distances_to_landmarks
420 .indexed_iter()
421 .map(|(idx, &dist)| (dist, idx))
422 .collect();
423 landmark_dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
424
425 let selected_landmarks: Vec<usize> = landmark_dists
427 .into_iter()
428 .take(k_landmarks)
429 .map(|(_, idx)| idx)
430 .collect();
431
432 let mut a = Array2::zeros((k_landmarks, self.n_components));
436 let mut b = Array1::zeros(k_landmarks);
437 let mut weights = Array1::zeros(k_landmarks);
438
439 for (row_idx, &landmark_idx) in selected_landmarks.iter().enumerate() {
441 let dist_to_landmark = distances_to_landmarks[landmark_idx];
442 let weight = if dist_to_landmark > 1e-10 {
443 1.0 / (dist_to_landmark + 1e-10)
444 } else {
445 1e10
446 };
447 weights[row_idx] = weight;
448
449 b[row_idx] = dist_to_landmark * weight;
451
452 for dim in 0..self.n_components {
454 a[[row_idx, dim]] = landmark_embedding[[landmark_idx, dim]] * weight;
455 }
456 }
457
458 let mut at_wa = Array2::zeros((self.n_components, self.n_components));
461 let mut at_wb = Array1::zeros(self.n_components);
462
463 for i in 0..self.n_components {
464 for j in 0..self.n_components {
465 for k in 0..k_landmarks {
466 at_wa[[i, j]] += a[[k, i]] * weights[k] * a[[k, j]];
467 }
468 }
469 for k in 0..k_landmarks {
470 at_wb[i] += a[[k, i]] * weights[k] * b[k];
471 }
472 }
473
474 for i in 0..self.n_components {
476 at_wa[[i, i]] += 1e-10;
477 }
478
479 self.solve_linear_system(&at_wa, &at_wb)
481 }
482
483 fn solve_linear_system(&self, a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
485 let n = a.nrows();
486 let mut a_copy = a.clone();
487 let mut b_copy = b.clone();
488
489 for i in 0..n {
491 let mut max_row = i;
493 for k in i + 1..n {
494 if a_copy[[k, i]].abs() > a_copy[[max_row, i]].abs() {
495 max_row = k;
496 }
497 }
498
499 if max_row != i {
501 for j in 0..n {
502 let temp = a_copy[[i, j]];
503 a_copy[[i, j]] = a_copy[[max_row, j]];
504 a_copy[[max_row, j]] = temp;
505 }
506 let temp = b_copy[i];
507 b_copy[i] = b_copy[max_row];
508 b_copy[max_row] = temp;
509 }
510
511 if a_copy[[i, i]].abs() < 1e-12 {
513 return Err(TransformError::ComputationError(
514 "Singular matrix in landmark MDS".to_string(),
515 ));
516 }
517
518 for k in i + 1..n {
520 let factor = a_copy[[k, i]] / a_copy[[i, i]];
521 for j in i..n {
522 a_copy[[k, j]] -= factor * a_copy[[i, j]];
523 }
524 b_copy[k] -= factor * b_copy[i];
525 }
526 }
527
528 let mut x = Array1::zeros(n);
530 for i in (0..n).rev() {
531 x[i] = b_copy[i];
532 for j in i + 1..n {
533 x[i] -= a_copy[[i, j]] * x[j];
534 }
535 x[i] /= a_copy[[i, i]];
536 }
537
538 Ok(x)
539 }
540}
541
542#[cfg(test)]
543mod tests {
544 use super::*;
545 use scirs2_core::ndarray::Array;
546
547 #[test]
548 fn test_isomap_basic() {
549 let n_points = 20;
551 let mut data = Vec::new();
552
553 for i in 0..n_points {
554 let t = i as f64 / n_points as f64 * 3.0 * std::f64::consts::PI;
555 let x = t.sin();
556 let y = 2.0 * (i as f64 / n_points as f64);
557 let z = t.cos();
558 data.extend_from_slice(&[x, y, z]);
559 }
560
561 let x = Array::from_shape_vec((n_points, 3), data).unwrap();
562
563 let mut isomap = Isomap::new(5, 2);
565 let embedding = isomap.fit_transform(&x).unwrap();
566
567 assert_eq!(embedding.shape(), &[n_points, 2]);
569
570 for val in embedding.iter() {
572 assert!(val.is_finite());
573 }
574 }
575
576 #[test]
577 fn test_isomap_epsilon_ball() {
578 let x: Array2<f64> = Array::eye(5);
579
580 let mut isomap = Isomap::new(3, 2).with_epsilon(1.5);
581 let result = isomap.fit_transform(&x);
582
583 assert!(result.is_ok());
585
586 let embedding = result.unwrap();
587 assert_eq!(embedding.shape(), &[5, 2]);
588 }
589
590 #[test]
591 fn test_isomap_disconnected_graph() {
592 let x = scirs2_core::ndarray::array![
594 [0.0, 0.0], [0.1, 0.1], [10.0, 10.0], [10.1, 10.1], ];
599
600 let mut isomap = Isomap::new(1, 2);
602 let result = isomap.fit(&x);
603
604 assert!(result.is_err());
606 if let Err(e) = result {
607 match e {
609 TransformError::InvalidInput(msg) => {
610 assert!(msg.contains("Graph is not connected"));
611 }
612 _ => panic!("Expected InvalidInput error for disconnected graph"),
613 }
614 }
615 }
616}