scirs2_transform/reduction/isomap.rs

//! Isomap (Isometric Feature Mapping) for non-linear dimensionality reduction
//!
//! Isomap is a non-linear dimensionality reduction method that preserves geodesic
//! distances between all points. It extends MDS by using geodesic distances instead
//! of Euclidean distances.

use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::validation::{check_positive, checkshape};
use scirs2_linalg::eigh;
use std::collections::BinaryHeap;
use std::f64;

use crate::error::{Result, TransformError};
// use statrs::statistics::Statistics; // TODO: Add statrs dependency - needs generic type fixes

/// Isomap (Isometric Feature Mapping) dimensionality reduction
///
/// Isomap seeks a lower-dimensional embedding that maintains geodesic distances
/// between all points. It uses graph distances to approximate geodesic distances
/// on the manifold.
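///
/// # Examples
///
/// A minimal usage sketch (marked `ignore`; the import path is an assumption and
/// should be adjusted to wherever this crate exposes `Isomap`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// // Synthetic data: 20 samples with 3 features
/// let x = Array2::<f64>::from_shape_fn((20, 3), |(i, j)| (i * 3 + j) as f64 * 0.1);
///
/// let mut isomap = Isomap::new(5, 2);
/// let embedding = isomap.fit_transform(&x).unwrap();
/// assert_eq!(embedding.shape(), &[20, 2]);
/// ```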
#[derive(Debug, Clone)]
pub struct Isomap {
    /// Number of neighbors to use for graph construction
    n_neighbors: usize,
    /// Number of components for dimensionality reduction
    n_components: usize,
    /// Whether to use k-neighbors or epsilon-ball for graph construction
    neighbor_mode: String,
    /// Epsilon for epsilon-ball graph construction
    epsilon: f64,
    /// The embedding vectors
    embedding: Option<Array2<f64>>,
    /// Training data for out-of-sample extension
    training_data: Option<Array2<f64>>,
    /// Geodesic distances from training data
    geodesic_distances: Option<Array2<f64>>,
}

impl Isomap {
    /// Creates a new Isomap instance
    ///
    /// # Arguments
    /// * `n_neighbors` - Number of neighbors for graph construction
    /// * `n_components` - Number of dimensions in the embedding space
    pub fn new(n_neighbors: usize, n_components: usize) -> Self {
        Isomap {
            n_neighbors,
            n_components,
            neighbor_mode: "knn".to_string(),
            epsilon: 0.0,
            embedding: None,
            training_data: None,
            geodesic_distances: None,
        }
    }

    /// Use epsilon-ball instead of k-nearest neighbors
    pub fn with_epsilon(mut self, epsilon: f64) -> Self {
        self.neighbor_mode = "epsilon".to_string();
        self.epsilon = epsilon;
        self
    }

    /// Compute pairwise Euclidean distances
    fn compute_distances<S>(&self, x: &ArrayBase<S, Ix2>) -> Array2<f64>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let n_samples = x.shape()[0];
        let mut distances = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            for j in i + 1..n_samples {
                let mut dist = 0.0;
                for k in 0..x.shape()[1] {
                    let diff = NumCast::from(x[[i, k]]).unwrap_or(0.0)
                        - NumCast::from(x[[j, k]]).unwrap_or(0.0);
                    dist += diff * diff;
                }
                dist = dist.sqrt();
                distances[[i, j]] = dist;
                distances[[j, i]] = dist;
            }
        }

        distances
    }

    /// Construct the neighborhood graph
    fn construct_graph(&self, distances: &Array2<f64>) -> Array2<f64> {
        let n_samples = distances.shape()[0];
        let mut graph = Array2::from_elem((n_samples, n_samples), f64::INFINITY);

        // Set diagonal to 0
        for i in 0..n_samples {
            graph[[i, i]] = 0.0;
        }

        if self.neighbor_mode == "knn" {
            // K-nearest neighbors graph
            for i in 0..n_samples {
                // Find k nearest neighbors (f64 is not Ord, so use a fixed-point
                // integer key with Reverse to get a min-heap)
                let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();

                for j in 0..n_samples {
                    if i != j {
                        let dist_fixed = (distances[[i, j]] * 1e9) as i64;
                        heap.push((std::cmp::Reverse(dist_fixed), j));
                    }
                }

                // Connect to the k nearest neighbors
                for _ in 0..self.n_neighbors {
                    if let Some((_, j)) = heap.pop() {
                        graph[[i, j]] = distances[[i, j]];
                        graph[[j, i]] = distances[[j, i]]; // Keep the graph symmetric
                    }
                }
            }
        } else {
            // Epsilon-ball graph
            for i in 0..n_samples {
                for j in i + 1..n_samples {
                    if distances[[i, j]] <= self.epsilon {
                        graph[[i, j]] = distances[[i, j]];
                        graph[[j, i]] = distances[[j, i]];
                    }
                }
            }
        }

        graph
    }

    /// Compute shortest paths using the Floyd-Warshall algorithm
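    ///
    /// This is the O(n³) dynamic-programming update
    /// `d(i, j) <- min(d(i, j), d(i, k) + d(k, j))` over all intermediate nodes `k`;
    /// any pair left at infinity afterwards means the neighborhood graph is disconnected.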
    fn compute_shortest_paths(&self, graph: &Array2<f64>) -> Result<Array2<f64>> {
        let n = graph.shape()[0];
        let mut dist = graph.clone();

        // Floyd-Warshall algorithm
        for k in 0..n {
            for i in 0..n {
                for j in 0..n {
                    if dist[[i, k]] + dist[[k, j]] < dist[[i, j]] {
                        dist[[i, j]] = dist[[i, k]] + dist[[k, j]];
                    }
                }
            }
        }

        // Check if the graph is connected
        for i in 0..n {
            for j in 0..n {
                if dist[[i, j]].is_infinite() {
                    return Err(TransformError::InvalidInput(
                        "Graph is not connected. Try increasing n_neighbors or epsilon."
                            .to_string(),
                    ));
                }
            }
        }

        Ok(dist)
    }

    /// Apply classical MDS to the geodesic distance matrix
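    ///
    /// The squared distances are double-centered into a Gram matrix
    /// `B = -1/2 * J * D^2 * J` with `J = I - (1/n) * 1 * 1^T` (implemented via row,
    /// column, and grand means), and the embedding is formed from the top
    /// `n_components` eigenpairs of `B` as `y_j = sqrt(max(lambda_j, 0)) * v_j`.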
    fn classical_mds(&self, distances: &Array2<f64>) -> Result<Array2<f64>> {
        let n = distances.shape()[0];

        // Double center the squared distance matrix
        let squared_distances = distances.mapv(|d| d * d);

        // Row means
        let row_means = squared_distances.mean_axis(Axis(1)).unwrap();

        // Column means
        let col_means = squared_distances.mean_axis(Axis(0)).unwrap();

        // Grand mean
        let grand_mean = row_means.mean().unwrap();

        // Double centering
        let mut gram = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                gram[[i, j]] =
                    -0.5 * (squared_distances[[i, j]] - row_means[i] - col_means[j] + grand_mean);
            }
        }

        // Ensure symmetry by averaging with the transpose (mitigates floating-point error)
        let gram_symmetric = 0.5 * (&gram + &gram.t());

        // Eigendecomposition
        let (eigenvalues, eigenvectors) = match eigh(&gram_symmetric.view(), None) {
            Ok(result) => result,
            Err(e) => return Err(TransformError::LinalgError(e)),
        };

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvalues[j].partial_cmp(&eigenvalues[i]).unwrap());

        // Extract the top n_components eigenvectors, scaled by the square root of
        // the (clamped non-negative) eigenvalues
        let mut embedding = Array2::zeros((n, self.n_components));
        for j in 0..self.n_components {
            let idx = indices[j];
            let scale = eigenvalues[idx].max(0.0).sqrt();

            for i in 0..n {
                embedding[[i, j]] = eigenvectors[[i, idx]] * scale;
            }
        }

        Ok(embedding)
    }

    /// Fits the Isomap model to the input data
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let (n_samples, n_features) = x.dim();

        // Validate inputs
        check_positive(self.n_neighbors, "n_neighbors")?;
        check_positive(self.n_components, "n_components")?;
        checkshape(x, &[n_samples, n_features], "x")?;

        if n_samples < self.n_neighbors {
            return Err(TransformError::InvalidInput(format!(
                "n_neighbors={} must be <= n_samples={}",
                self.n_neighbors, n_samples
            )));
        }

        if self.n_components >= n_samples {
            return Err(TransformError::InvalidInput(format!(
                "n_components={} must be < n_samples={}",
                self.n_components, n_samples
            )));
        }

        // Convert input to f64
        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));

        // Step 1: Compute pairwise distances
        let distances = self.compute_distances(&x_f64.view());

        // Step 2: Construct the neighborhood graph
        let graph = self.construct_graph(&distances);

        // Step 3: Compute shortest paths (geodesic distances)
        let geodesic_distances = self.compute_shortest_paths(&graph)?;

        // Step 4: Apply classical MDS
        let embedding = self.classical_mds(&geodesic_distances)?;

        self.embedding = Some(embedding);
        self.training_data = Some(x_f64);
        self.geodesic_distances = Some(geodesic_distances);

        Ok(())
    }

    /// Transforms the input data using the fitted Isomap model
    ///
    /// For new points, this uses the Landmark MDS approach
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
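    ///
    /// # Examples
    ///
    /// A sketch of out-of-sample use (marked `ignore`; `train` and `new_points` are
    /// assumed to be `Array2<f64>` values with the same number of columns):
    ///
    /// ```ignore
    /// let mut isomap = Isomap::new(5, 2);
    /// isomap.fit(&train)?;
    /// // Points not seen during fitting are embedded via the landmark-MDS extension
    /// let projected = isomap.transform(&new_points)?;
    /// assert_eq!(projected.ncols(), 2);
    /// ```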
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        if self.embedding.is_none() {
            return Err(TransformError::NotFitted(
                "Isomap model has not been fitted".to_string(),
            ));
        }

        let training_data = self
            .training_data
            .as_ref()
            .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;

        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));

        // If this is the training data, return the precomputed embedding
        if self.is_same_data(&x_f64, training_data) {
            return Ok(self.embedding.as_ref().unwrap().clone());
        }

        // Otherwise use the landmark-MDS out-of-sample extension
        self.landmark_mds(&x_f64)
    }

    /// Fits the Isomap model and transforms the data
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Returns the embedding
    pub fn embedding(&self) -> Option<&Array2<f64>> {
        self.embedding.as_ref()
    }

    /// Returns the geodesic distances computed during fitting
    pub fn geodesic_distances(&self) -> Option<&Array2<f64>> {
        self.geodesic_distances.as_ref()
    }

    /// Check if the input data is the same as the training data
    fn is_same_data(&self, x: &Array2<f64>, training_data: &Array2<f64>) -> bool {
        if x.dim() != training_data.dim() {
            return false;
        }

        let (n_samples, n_features) = x.dim();
        for i in 0..n_samples {
            for j in 0..n_features {
                if (x[[i, j]] - training_data[[i, j]]).abs() > 1e-10 {
                    return false;
                }
            }
        }
        true
    }

    /// Implement Landmark MDS for out-of-sample extension
    fn landmark_mds(&self, x_new: &Array2<f64>) -> Result<Array2<f64>> {
        let training_data = self.training_data.as_ref().unwrap();
        let training_embedding = self.embedding.as_ref().unwrap();
        let geodesic_distances = self.geodesic_distances.as_ref().unwrap();

        let (n_new, n_features) = x_new.dim();
        let (n_training, _) = training_data.dim();

        if n_features != training_data.ncols() {
            return Err(TransformError::InvalidInput(format!(
                "Input features {} must match training features {}",
                n_features,
                training_data.ncols()
            )));
        }

        // Step 1: Compute distances from the new points to all training points
        let mut distances_to_training = Array2::zeros((n_new, n_training));
        for i in 0..n_new {
            for j in 0..n_training {
                let mut dist_sq = 0.0;
                for k in 0..n_features {
                    let diff = x_new[[i, k]] - training_data[[j, k]];
                    dist_sq += diff * diff;
                }
                distances_to_training[[i, j]] = dist_sq.sqrt();
            }
        }

        // Step 2: Apply the landmark MDS step.
        // For each new point, find the coordinates that minimize stress
        // with respect to the known training points.
        let mut new_embedding = Array2::zeros((n_new, self.n_components));

        for i in 0..n_new {
            // Use weighted least squares to find optimal coordinates
            let coords = self.solve_landmark_coordinates(
                &distances_to_training.row(i),
                training_embedding,
                geodesic_distances,
            )?;

            for j in 0..self.n_components {
                new_embedding[[i, j]] = coords[j];
            }
        }

        Ok(new_embedding)
    }

    /// Solve for landmark coordinates using weighted least squares
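    ///
    /// Builds the weighted normal equations `A^T W A x = A^T W b` from the k nearest
    /// landmarks (plus a small ridge term on the diagonal) and solves the resulting
    /// `n_components x n_components` system with Gaussian elimination.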
    fn solve_landmark_coordinates(
        &self,
        distances_to_landmarks: &scirs2_core::ndarray::ArrayView1<f64>,
        landmark_embedding: &Array2<f64>,
        _geodesic_distances: &Array2<f64>,
    ) -> Result<Array1<f64>> {
        let n_landmarks = landmark_embedding.nrows();

        // Use a subset of landmarks for efficiency (select the k nearest)
        let k_landmarks = (n_landmarks / 2)
            .max(self.n_components + 1)
            .min(n_landmarks);

        // Find the k nearest landmarks
        let mut landmark_dists: Vec<(f64, usize)> = distances_to_landmarks
            .indexed_iter()
            .map(|(idx, &dist)| (dist, idx))
            .collect();
        landmark_dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

        // Keep the k nearest landmarks
        let selected_landmarks: Vec<usize> = landmark_dists
            .into_iter()
            .take(k_landmarks)
            .map(|(_, idx)| idx)
            .collect();

        // Build the system A * x = b, where x are the coordinates, using the
        // constraint that distances in the embedding space should approximate
        // the geodesic distances
        let mut a = Array2::zeros((k_landmarks, self.n_components));
        let mut b = Array1::zeros(k_landmarks);
        let mut weights = Array1::zeros(k_landmarks);

        // For each selected landmark, create a constraint equation
        for (row_idx, &landmark_idx) in selected_landmarks.iter().enumerate() {
            let dist_to_landmark = distances_to_landmarks[landmark_idx];
            let weight = if dist_to_landmark > 1e-10 {
                1.0 / (dist_to_landmark + 1e-10)
            } else {
                1e10
            };
            weights[row_idx] = weight;

            // Target distance is the distance from the new point to this landmark
            b[row_idx] = dist_to_landmark * weight;

            // Coefficients are the landmark coordinates (weighted)
            for dim in 0..self.n_components {
                a[[row_idx, dim]] = landmark_embedding[[landmark_idx, dim]] * weight;
            }
        }

        // Solve the weighted least squares system A^T W A x = A^T W b,
        // where W is the diagonal weight matrix
        let mut at_wa = Array2::zeros((self.n_components, self.n_components));
        let mut at_wb = Array1::zeros(self.n_components);

        for i in 0..self.n_components {
            for j in 0..self.n_components {
                for k in 0..k_landmarks {
                    at_wa[[i, j]] += a[[k, i]] * weights[k] * a[[k, j]];
                }
            }
            for k in 0..k_landmarks {
                at_wb[i] += a[[k, i]] * weights[k] * b[k];
            }
        }

        // Add a small ridge term to prevent a singular matrix
        for i in 0..self.n_components {
            at_wa[[i, i]] += 1e-10;
        }

        // Solve using simple Gaussian elimination for small systems
        self.solve_linear_system(&at_wa, &at_wb)
    }

    /// Simple linear system solver for small matrices
    fn solve_linear_system(&self, a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
        let n = a.nrows();
        let mut a_copy = a.clone();
        let mut b_copy = b.clone();

        // Gaussian elimination with partial pivoting
        for i in 0..n {
            // Find pivot
            let mut max_row = i;
            for k in i + 1..n {
                if a_copy[[k, i]].abs() > a_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            // Swap rows
            if max_row != i {
                for j in 0..n {
                    let temp = a_copy[[i, j]];
                    a_copy[[i, j]] = a_copy[[max_row, j]];
                    a_copy[[max_row, j]] = temp;
                }
                let temp = b_copy[i];
                b_copy[i] = b_copy[max_row];
                b_copy[max_row] = temp;
            }

            // Check for singular matrix
            if a_copy[[i, i]].abs() < 1e-12 {
                return Err(TransformError::ComputationError(
                    "Singular matrix in landmark MDS".to_string(),
                ));
            }

            // Eliminate
            for k in i + 1..n {
                let factor = a_copy[[k, i]] / a_copy[[i, i]];
                for j in i..n {
                    a_copy[[k, j]] -= factor * a_copy[[i, j]];
                }
                b_copy[k] -= factor * b_copy[i];
            }
        }

        // Back substitution
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            x[i] = b_copy[i];
            for j in i + 1..n {
                x[i] -= a_copy[[i, j]] * x[j];
            }
            x[i] /= a_copy[[i, i]];
        }

        Ok(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_isomap_basic() {
        // Create a simple helix-like 3D curve dataset
        let n_points = 20;
        let mut data = Vec::new();

        for i in 0..n_points {
            let t = i as f64 / n_points as f64 * 3.0 * std::f64::consts::PI;
            let x = t.sin();
            let y = 2.0 * (i as f64 / n_points as f64);
            let z = t.cos();
            data.extend_from_slice(&[x, y, z]);
        }

        let x = Array::from_shape_vec((n_points, 3), data).unwrap();

        // Fit Isomap
        let mut isomap = Isomap::new(5, 2);
        let embedding = isomap.fit_transform(&x).unwrap();

        // Check shape
        assert_eq!(embedding.shape(), &[n_points, 2]);

        // Check that values are finite
        for val in embedding.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_isomap_epsilon_ball() {
        let x: Array2<f64> = Array::eye(5);

        let mut isomap = Isomap::new(3, 2).with_epsilon(1.5);
        let result = isomap.fit_transform(&x);

        // This should work: distinct rows of the identity matrix are all at
        // distance sqrt(2) from each other, so epsilon = 1.5 yields a connected graph
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.shape(), &[5, 2]);
    }

    #[test]
    fn test_isomap_disconnected_graph() {
        // Create clearly disconnected data: two separate clusters
        let x = scirs2_core::ndarray::array![
            [0.0, 0.0],   // Cluster 1
            [0.1, 0.1],   // Cluster 1
            [10.0, 10.0], // Cluster 2 (far away)
            [10.1, 10.1], // Cluster 2
        ];

        // With only 1 neighbor, the two clusters won't connect
        let mut isomap = Isomap::new(1, 2);
        let result = isomap.fit(&x);

        // Should fail due to the disconnected graph
        assert!(result.is_err());
        if let Err(e) = result {
            // Verify it's specifically a connectivity error
            match e {
                TransformError::InvalidInput(msg) => {
                    assert!(msg.contains("Graph is not connected"));
                }
                _ => panic!("Expected InvalidInput error for disconnected graph"),
            }
        }
    }
}
617}