Skip to main content

scirs2_transform/reduction/
lle.rs

1//! Locally Linear Embedding (LLE) for non-linear dimensionality reduction
2//!
3//! LLE is a non-linear dimensionality reduction method that assumes the data lies
4//! on a low-dimensional manifold that is locally linear. It preserves local
5//! neighborhood structure in the embedding space.
6//!
7//! ## Algorithm Overview
8//!
9//! 1. **k-NN computation**: Find k nearest neighbors for each point
10//! 2. **Weight computation**: Solve least-squares for reconstruction weights
//! 3. **Embedding**: Take the eigenvectors of (I-W)^T(I-W) with the smallest eigenvalues, skipping the bottom one (eigenvalue ~0, the constant vector, since every row of W sums to 1)
12//!
13//! ## Variants
14//!
15//! - **Standard LLE**: Classic locally linear embedding
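//! - **Modified LLE (MLLE)**: Derives each point's weights from several local singular vectors for robustness (simplified variant)
//! - **Hessian LLE (HLLE)**: Derives weights from a local quadratic (Hessian) model of the neighborhood (simplified variant)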
18
19use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
20use scirs2_core::numeric::{Float, NumCast};
21use scirs2_core::validation::{check_positive, checkshape};
22use scirs2_linalg::{eigh, solve, svd};
23use std::collections::BinaryHeap;
24
25use crate::error::{Result, TransformError};
26
/// Variant of Locally Linear Embedding used to compute reconstruction weights.
///
/// The enum carries no data, so it is `Copy` and can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LLEMethod {
    /// Standard LLE: regularized least-squares barycentric weights
    Standard,
    /// Modified LLE: weights derived from several local singular vectors
    Modified,
    /// Hessian LLE: weights derived from a local quadratic (Hessian) model
    Hessian,
}
37
38/// Locally Linear Embedding (LLE) dimensionality reduction
39///
40/// LLE finds a low-dimensional embedding that preserves local linear structure.
41/// Each point is reconstructed from its neighbors with fixed weights, and the
42/// embedding preserves these reconstruction weights.
43///
44/// # Example
45///
46/// ```rust,no_run
47/// use scirs2_transform::LLE;
48/// use scirs2_core::ndarray::Array2;
49///
50/// let data = Array2::<f64>::zeros((50, 10));
51/// let mut lle = LLE::new(10, 2);
52/// let embedding = lle.fit_transform(&data).expect("should succeed");
53/// assert_eq!(embedding.shape(), &[50, 2]);
54/// ```
55#[derive(Debug, Clone)]
56pub struct LLE {
57    /// Number of neighbors to use
58    n_neighbors: usize,
59    /// Number of components in the embedding
60    n_components: usize,
61    /// Regularization parameter
62    reg: f64,
63    /// Method variant
64    method: LLEMethod,
65    /// The embedding
66    embedding: Option<Array2<f64>>,
67    /// Reconstruction weights
68    weights: Option<Array2<f64>>,
69    /// Training data for out-of-sample extension
70    training_data: Option<Array2<f64>>,
71    /// Reconstruction error
72    reconstruction_error: Option<f64>,
73}
74
75impl LLE {
76    /// Creates a new LLE instance
77    ///
78    /// # Arguments
79    /// * `n_neighbors` - Number of neighbors to use
80    /// * `n_components` - Number of dimensions in the embedding
81    pub fn new(n_neighbors: usize, n_components: usize) -> Self {
82        LLE {
83            n_neighbors,
84            n_components,
85            reg: 1e-3,
86            method: LLEMethod::Standard,
87            embedding: None,
88            weights: None,
89            training_data: None,
90            reconstruction_error: None,
91        }
92    }
93
94    /// Set the regularization parameter
95    pub fn with_regularization(mut self, reg: f64) -> Self {
96        self.reg = reg;
97        self
98    }
99
100    /// Set the LLE method variant
101    pub fn with_method(mut self, method: &str) -> Self {
102        self.method = match method {
103            "modified" | "mlle" => LLEMethod::Modified,
104            "hessian" | "hlle" => LLEMethod::Hessian,
105            _ => LLEMethod::Standard,
106        };
107        self
108    }
109
110    /// Set the LLE method variant (typed)
111    pub fn with_method_type(mut self, method: LLEMethod) -> Self {
112        self.method = method;
113        self
114    }
115
116    /// Find k nearest neighbors for each point
117    fn find_neighbors<S>(&self, x: &ArrayBase<S, Ix2>) -> (Array2<usize>, Array2<f64>)
118    where
119        S: Data,
120        S::Elem: Float + NumCast,
121    {
122        let n_samples = x.shape()[0];
123        let mut indices = Array2::zeros((n_samples, self.n_neighbors));
124        let mut distances = Array2::zeros((n_samples, self.n_neighbors));
125
126        for i in 0..n_samples {
127            let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();
128
129            for j in 0..n_samples {
130                if i != j {
131                    let mut dist = 0.0;
132                    for k in 0..x.shape()[1] {
133                        let diff: f64 = NumCast::from(x[[i, k]]).unwrap_or(0.0)
134                            - NumCast::from(x[[j, k]]).unwrap_or(0.0);
135                        dist += diff * diff;
136                    }
137                    dist = dist.sqrt();
138
139                    let dist_fixed = (dist * 1e9) as i64;
140                    heap.push((std::cmp::Reverse(dist_fixed), j));
141                }
142            }
143
144            for j in 0..self.n_neighbors {
145                if let Some((std::cmp::Reverse(dist_fixed), idx)) = heap.pop() {
146                    indices[[i, j]] = idx;
147                    distances[[i, j]] = dist_fixed as f64 / 1e9;
148                }
149            }
150        }
151
152        (indices, distances)
153    }
154
155    /// Compute reconstruction weights (standard LLE)
156    fn compute_weights<S>(
157        &self,
158        x: &ArrayBase<S, Ix2>,
159        neighbors: &Array2<usize>,
160    ) -> Result<Array2<f64>>
161    where
162        S: Data,
163        S::Elem: Float + NumCast,
164    {
165        let n_samples = x.shape()[0];
166        let n_features = x.shape()[1];
167        let k = self.n_neighbors;
168
169        let mut weights = Array2::zeros((n_samples, n_samples));
170
171        for i in 0..n_samples {
172            // Create local covariance matrix
173            let mut c = Array2::zeros((k, k));
174            let xi = x.index_axis(Axis(0), i);
175
176            for j in 0..k {
177                let neighbor_j = neighbors[[i, j]];
178                let xj = x.index_axis(Axis(0), neighbor_j);
179
180                for l in 0..k {
181                    let neighbor_l = neighbors[[i, l]];
182                    let xl = x.index_axis(Axis(0), neighbor_l);
183
184                    let mut dot = 0.0;
185                    for m in 0..n_features {
186                        let diff_j: f64 = NumCast::from(xi[m]).unwrap_or(0.0)
187                            - NumCast::from(xj[m]).unwrap_or(0.0);
188                        let diff_l: f64 = NumCast::from(xi[m]).unwrap_or(0.0)
189                            - NumCast::from(xl[m]).unwrap_or(0.0);
190                        dot += diff_j * diff_l;
191                    }
192                    c[[j, l]] = dot;
193                }
194            }
195
196            // Add regularization to diagonal
197            let trace: f64 = (0..k).map(|j| c[[j, j]]).sum();
198            let reg_value = self.reg * trace / k as f64;
199            for j in 0..k {
200                c[[j, j]] += reg_value;
201            }
202
203            // Solve C * w = 1 for weights
204            let ones = Array1::ones(k);
205            let w = match solve(&c.view(), &ones.view(), None) {
206                Ok(solution) => solution,
207                Err(_) => Array1::from_elem(k, 1.0 / k as f64),
208            };
209
210            // Normalize weights to sum to 1
211            let w_sum = w.sum();
212            let w_normalized = if w_sum.abs() > 1e-10 {
213                w / w_sum
214            } else {
215                Array1::from_elem(k, 1.0 / k as f64)
216            };
217
218            for j in 0..k {
219                let neighbor = neighbors[[i, j]];
220                weights[[i, neighbor]] = w_normalized[j];
221            }
222        }
223
224        Ok(weights)
225    }
226
227    /// Compute weights for Modified LLE (MLLE)
228    ///
229    /// Modified LLE uses multiple weight vectors to provide more robust embeddings.
230    /// For each point, it computes the SVD of the local neighborhood and uses
231    /// the bottom singular vectors to form a more stable weight matrix.
232    fn compute_weights_modified<S>(
233        &self,
234        x: &ArrayBase<S, Ix2>,
235        neighbors: &Array2<usize>,
236    ) -> Result<Array2<f64>>
237    where
238        S: Data,
239        S::Elem: Float + NumCast,
240    {
241        let n_samples = x.shape()[0];
242        let n_features = x.shape()[1];
243        let k = self.n_neighbors;
244        let d = self.n_components;
245
246        // Number of extra weight vectors for MLLE
247        let n_extra = k.saturating_sub(d + 1).min(d);
248        if n_extra == 0 {
249            // Fall back to standard LLE
250            return self.compute_weights(x, neighbors);
251        }
252
253        let mut weights = Array2::zeros((n_samples, n_samples));
254
255        for i in 0..n_samples {
256            // Get the local neighborhood centered at xi
257            let mut local_data = Array2::zeros((k, n_features));
258            let xi = x.index_axis(Axis(0), i);
259
260            for j in 0..k {
261                let neighbor_j = neighbors[[i, j]];
262                for m in 0..n_features {
263                    let val_i: f64 = NumCast::from(xi[m]).unwrap_or(0.0);
264                    let val_j: f64 = NumCast::from(x[[neighbor_j, m]]).unwrap_or(0.0);
265                    local_data[[j, m]] = val_j - val_i;
266                }
267            }
268
269            // SVD of the centered local neighborhood
270            let svd_result = svd::<f64>(&local_data.view(), true, None);
271            let (u, _s, _vt) = match svd_result {
272                Ok(result) => result,
273                Err(_) => {
274                    // Fallback to standard weights for this point
275                    let ones = Array1::from_elem(k, 1.0 / k as f64);
276                    for j in 0..k {
277                        weights[[i, neighbors[[i, j]]]] = ones[j];
278                    }
279                    continue;
280                }
281            };
282
283            // Use the last (k - d) singular vectors of U
284            // These span the "null space" and give the weight vectors
285            let start_col = d.min(u.shape()[1].saturating_sub(1));
286            let n_weight_vecs = (u.shape()[1] - start_col).min(n_extra + 1).max(1);
287
288            // Average the squared singular vectors to form the weight vector
289            let mut w = Array1::zeros(k);
290            for j in 0..n_weight_vecs {
291                let col_idx = start_col + j;
292                if col_idx < u.shape()[1] {
293                    for r in 0..k {
294                        w[r] += u[[r, col_idx]] * u[[r, col_idx]];
295                    }
296                }
297            }
298
299            // Normalize
300            let w_sum = w.sum();
301            if w_sum > 1e-10 {
302                w.mapv_inplace(|v| v / w_sum);
303            } else {
304                w = Array1::from_elem(k, 1.0 / k as f64);
305            }
306
307            for j in 0..k {
308                weights[[i, neighbors[[i, j]]]] = w[j];
309            }
310        }
311
312        Ok(weights)
313    }
314
315    /// Compute weights for Hessian LLE (HLLE)
316    ///
317    /// Hessian LLE estimates the local Hessian of the manifold and uses it
318    /// to construct a weight matrix that better captures local curvature.
319    fn compute_weights_hessian<S>(
320        &self,
321        x: &ArrayBase<S, Ix2>,
322        neighbors: &Array2<usize>,
323    ) -> Result<Array2<f64>>
324    where
325        S: Data,
326        S::Elem: Float + NumCast,
327    {
328        let n_samples = x.shape()[0];
329        let n_features = x.shape()[1];
330        let k = self.n_neighbors;
331        let d = self.n_components;
332
333        // HLLE requires k > d * (d + 3) / 2
334        let min_k = d * (d + 3) / 2 + 1;
335        if k < min_k {
336            return Err(TransformError::InvalidInput(format!(
337                "Hessian LLE requires n_neighbors >= {} for n_components={}, got {}",
338                min_k, d, k
339            )));
340        }
341
342        // Number of Hessian components
343        let dp = d * (d + 1) / 2;
344
345        let mut weights = Array2::zeros((n_samples, n_samples));
346
347        for i in 0..n_samples {
348            // Get centered local neighborhood
349            let mut local_data = Array2::zeros((k, n_features));
350            let xi = x.index_axis(Axis(0), i);
351
352            for j in 0..k {
353                let neighbor_j = neighbors[[i, j]];
354                for m in 0..n_features {
355                    let val_i: f64 = NumCast::from(xi[m]).unwrap_or(0.0);
356                    let val_j: f64 = NumCast::from(x[[neighbor_j, m]]).unwrap_or(0.0);
357                    local_data[[j, m]] = val_j - val_i;
358                }
359            }
360
361            // SVD of local neighborhood to get the tangent space
362            let (u, _s, _vt) = match svd::<f64>(&local_data.view(), true, None) {
363                Ok(result) => result,
364                Err(_) => {
365                    let ones = Array1::from_elem(k, 1.0 / k as f64);
366                    for j in 0..k {
367                        weights[[i, neighbors[[i, j]]]] = ones[j];
368                    }
369                    continue;
370                }
371            };
372
373            // Take the first d columns of U as the tangent coordinates
374            let mut tangent = Array2::zeros((k, d));
375            let max_d = d.min(u.shape()[1]);
376            for j in 0..max_d {
377                for r in 0..k {
378                    tangent[[r, j]] = u[[r, j]];
379                }
380            }
381
382            // Build the Hessian estimator matrix
383            // Columns: [1, t1, t2, ..., td, t1*t1, t1*t2, ..., td*td]
384            let n_cols = 1 + d + dp;
385            let mut h_mat = Array2::zeros((k, n_cols));
386
387            for r in 0..k {
388                h_mat[[r, 0]] = 1.0; // constant
389                for j in 0..max_d {
390                    h_mat[[r, 1 + j]] = tangent[[r, j]]; // linear
391                }
392
393                // Quadratic terms
394                let mut col = 1 + d;
395                for j in 0..max_d {
396                    for l in j..max_d {
397                        h_mat[[r, col]] = tangent[[r, j]] * tangent[[r, l]];
398                        col += 1;
399                    }
400                }
401            }
402
403            // QR decomposition via Gram-Schmidt to get the null space
404            // We want the projection onto the null space of H^T
405            let (q, _r) = self.qr_decomposition(&h_mat)?;
406
407            // The weight vector comes from the columns of Q beyond the first n_cols
408            let mut w = Array1::zeros(k);
409            let start_col = n_cols.min(q.shape()[1]);
410            let mut count = 0;
411            for col in start_col..q.shape()[1] {
412                for r in 0..k {
413                    w[r] += q[[r, col]] * q[[r, col]];
414                }
415                count += 1;
416            }
417
418            if count == 0 {
419                // Fallback
420                w = Array1::from_elem(k, 1.0 / k as f64);
421            } else {
422                let w_sum = w.sum();
423                if w_sum > 1e-10 {
424                    w.mapv_inplace(|v| v / w_sum);
425                } else {
426                    w = Array1::from_elem(k, 1.0 / k as f64);
427                }
428            }
429
430            for j in 0..k {
431                weights[[i, neighbors[[i, j]]]] = w[j];
432            }
433        }
434
435        Ok(weights)
436    }
437
438    /// Simple QR decomposition using modified Gram-Schmidt
439    fn qr_decomposition(&self, a: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
440        let (m, n) = a.dim();
441        let mut q = a.clone();
442        let mut r = Array2::zeros((n, n));
443
444        for j in 0..n {
445            // Normalize the j-th column
446            let mut norm = 0.0;
447            for i in 0..m {
448                norm += q[[i, j]] * q[[i, j]];
449            }
450            norm = norm.sqrt();
451
452            r[[j, j]] = norm;
453            if norm > 1e-14 {
454                for i in 0..m {
455                    q[[i, j]] /= norm;
456                }
457            }
458
459            // Orthogonalize remaining columns against q_j
460            for k in (j + 1)..n {
461                let mut dot = 0.0;
462                for i in 0..m {
463                    dot += q[[i, j]] * q[[i, k]];
464                }
465                r[[j, k]] = dot;
466                for i in 0..m {
467                    q[[i, k]] -= dot * q[[i, j]];
468                }
469            }
470        }
471
472        Ok((q, r))
473    }
474
475    /// Compute the embedding from reconstruction weights
476    fn compute_embedding(&self, weights: &Array2<f64>) -> Result<Array2<f64>> {
477        let n_samples = weights.shape()[0];
478
479        // Construct the cost matrix M = (I - W)^T (I - W)
480        let mut m = Array2::zeros((n_samples, n_samples));
481
482        for i in 0..n_samples {
483            for j in 0..n_samples {
484                let mut sum = 0.0;
485
486                if i == j {
487                    sum += 1.0 - 2.0 * weights[[i, j]] + weights.column(j).dot(&weights.column(j));
488                } else {
489                    sum += -weights[[i, j]] - weights[[j, i]]
490                        + weights.column(i).dot(&weights.column(j));
491                }
492
493                m[[i, j]] = sum;
494            }
495        }
496
497        // Find the eigenvectors corresponding to the smallest eigenvalues
498        let (eigenvalues, eigenvectors) =
499            eigh(&m.view(), None).map_err(|e| TransformError::LinalgError(e))?;
500
501        // Sort eigenvalues and eigenvectors
502        let mut indices: Vec<usize> = (0..n_samples).collect();
503        indices.sort_by(|&i, &j| {
504            eigenvalues[i]
505                .partial_cmp(&eigenvalues[j])
506                .unwrap_or(std::cmp::Ordering::Equal)
507        });
508
509        // Skip the first eigenvector (corresponding to eigenvalue ~0)
510        // and take the next n_components eigenvectors
511        let mut embedding = Array2::zeros((n_samples, self.n_components));
512        for j in 0..self.n_components {
513            let idx = indices[j + 1]; // Skip first eigenvector
514            for i in 0..n_samples {
515                embedding[[i, j]] = eigenvectors[[i, idx]];
516            }
517        }
518
519        // Compute reconstruction error
520        let recon_error: f64 = (0..self.n_components)
521            .map(|j| {
522                let idx = indices[j + 1];
523                eigenvalues[idx].max(0.0)
524            })
525            .sum();
526
527        // Store as side effect via return
528        // (We'll compute it in fit())
529        let _ = recon_error;
530
531        Ok(embedding)
532    }
533
534    /// Compute reconstruction error
535    fn compute_reconstruction_error(&self, weights: &Array2<f64>, embedding: &Array2<f64>) -> f64 {
536        let n_samples = weights.shape()[0];
537        let n_components = embedding.shape()[1];
538
539        let mut total_error = 0.0;
540
541        for i in 0..n_samples {
542            for d in 0..n_components {
543                let mut reconstructed = 0.0;
544                for j in 0..n_samples {
545                    reconstructed += weights[[i, j]] * embedding[[j, d]];
546                }
547                let diff = embedding[[i, d]] - reconstructed;
548                total_error += diff * diff;
549            }
550        }
551
552        total_error / n_samples as f64
553    }
554
555    /// Fits the LLE model to the input data
556    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
557    where
558        S: Data,
559        S::Elem: Float + NumCast,
560    {
561        let (n_samples, n_features) = x.dim();
562
563        check_positive(self.n_neighbors, "n_neighbors")?;
564        check_positive(self.n_components, "n_components")?;
565        checkshape(x, &[n_samples, n_features], "x")?;
566
567        if n_samples <= self.n_neighbors {
568            return Err(TransformError::InvalidInput(format!(
569                "n_neighbors={} must be < n_samples={}",
570                self.n_neighbors, n_samples
571            )));
572        }
573
574        if self.n_components >= n_samples {
575            return Err(TransformError::InvalidInput(format!(
576                "n_components={} must be < n_samples={}",
577                self.n_components, n_samples
578            )));
579        }
580
581        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
582
583        // Step 1: Find k nearest neighbors
584        let (neighbors, _distances) = self.find_neighbors(&x_f64.view());
585
586        // Step 2: Compute reconstruction weights based on method
587        let weights = match &self.method {
588            LLEMethod::Standard => self.compute_weights(&x_f64.view(), &neighbors)?,
589            LLEMethod::Modified => self.compute_weights_modified(&x_f64.view(), &neighbors)?,
590            LLEMethod::Hessian => self.compute_weights_hessian(&x_f64.view(), &neighbors)?,
591        };
592
593        // Step 3: Compute embedding from weights
594        let embedding = self.compute_embedding(&weights)?;
595
596        // Compute reconstruction error
597        let recon_error = self.compute_reconstruction_error(&weights, &embedding);
598
599        self.embedding = Some(embedding);
600        self.weights = Some(weights);
601        self.training_data = Some(x_f64);
602        self.reconstruction_error = Some(recon_error);
603
604        Ok(())
605    }
606
607    /// Transforms the input data using the fitted LLE model
608    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
609    where
610        S: Data,
611        S::Elem: Float + NumCast,
612    {
613        if self.embedding.is_none() {
614            return Err(TransformError::NotFitted(
615                "LLE model has not been fitted".to_string(),
616            ));
617        }
618
619        let training_data = self
620            .training_data
621            .as_ref()
622            .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
623
624        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
625
626        if self.is_same_data(&x_f64, training_data) {
627            return self
628                .embedding
629                .as_ref()
630                .cloned()
631                .ok_or_else(|| TransformError::NotFitted("Embedding not available".to_string()));
632        }
633
634        self.transform_new_data(&x_f64)
635    }
636
637    /// Fits the LLE model and transforms the data
638    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
639    where
640        S: Data,
641        S::Elem: Float + NumCast,
642    {
643        self.fit(x)?;
644        self.transform(x)
645    }
646
647    /// Returns the embedding
648    pub fn embedding(&self) -> Option<&Array2<f64>> {
649        self.embedding.as_ref()
650    }
651
652    /// Returns the reconstruction weights
653    pub fn reconstruction_weights(&self) -> Option<&Array2<f64>> {
654        self.weights.as_ref()
655    }
656
657    /// Returns the reconstruction error
658    pub fn reconstruction_error(&self) -> Option<f64> {
659        self.reconstruction_error
660    }
661
662    /// Check if the input data is the same as training data
663    fn is_same_data(&self, x: &Array2<f64>, training_data: &Array2<f64>) -> bool {
664        if x.dim() != training_data.dim() {
665            return false;
666        }
667        let (n_samples, n_features) = x.dim();
668        for i in 0..n_samples {
669            for j in 0..n_features {
670                if (x[[i, j]] - training_data[[i, j]]).abs() > 1e-10 {
671                    return false;
672                }
673            }
674        }
675        true
676    }
677
678    /// Transform new data using out-of-sample extension
679    fn transform_new_data(&self, x_new: &Array2<f64>) -> Result<Array2<f64>> {
680        let training_data = self
681            .training_data
682            .as_ref()
683            .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
684        let training_embedding = self
685            .embedding
686            .as_ref()
687            .ok_or_else(|| TransformError::NotFitted("Embedding not available".to_string()))?;
688
689        let (n_new, n_features) = x_new.dim();
690
691        if n_features != training_data.ncols() {
692            return Err(TransformError::InvalidInput(format!(
693                "Input features {} must match training features {}",
694                n_features,
695                training_data.ncols()
696            )));
697        }
698
699        let mut new_embedding = Array2::zeros((n_new, self.n_components));
700
701        for i in 0..n_new {
702            let new_coords =
703                self.compute_new_point_embedding(&x_new.row(i), training_data, training_embedding)?;
704
705            for j in 0..self.n_components {
706                new_embedding[[i, j]] = new_coords[j];
707            }
708        }
709
710        Ok(new_embedding)
711    }
712
713    /// Compute embedding coordinates for a single new point
714    fn compute_new_point_embedding(
715        &self,
716        x_new: &scirs2_core::ndarray::ArrayView1<f64>,
717        training_data: &Array2<f64>,
718        training_embedding: &Array2<f64>,
719    ) -> Result<Array1<f64>> {
720        let n_training = training_data.nrows();
721        let n_features = training_data.ncols();
722
723        // Find k nearest neighbors in training data
724        let mut distances: Vec<(f64, usize)> = Vec::with_capacity(n_training);
725        for j in 0..n_training {
726            let mut dist_sq = 0.0;
727            for k in 0..n_features {
728                let diff = x_new[k] - training_data[[j, k]];
729                dist_sq += diff * diff;
730            }
731            distances.push((dist_sq.sqrt(), j));
732        }
733
734        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
735        let k = self.n_neighbors.min(n_training);
736        let neighbor_indices: Vec<usize> =
737            distances.into_iter().take(k).map(|(_, idx)| idx).collect();
738
739        // Compute reconstruction weights
740        let weights =
741            self.compute_reconstruction_weights_for_point(x_new, training_data, &neighbor_indices)?;
742
743        // Compute embedding as weighted combination
744        let mut new_coords = Array1::zeros(self.n_components);
745        for (i, &neighbor_idx) in neighbor_indices.iter().enumerate() {
746            for dim in 0..self.n_components {
747                new_coords[dim] += weights[i] * training_embedding[[neighbor_idx, dim]];
748            }
749        }
750
751        Ok(new_coords)
752    }
753
754    /// Compute reconstruction weights for a single point given its neighbors
755    fn compute_reconstruction_weights_for_point(
756        &self,
757        x_point: &scirs2_core::ndarray::ArrayView1<f64>,
758        training_data: &Array2<f64>,
759        neighbor_indices: &[usize],
760    ) -> Result<Array1<f64>> {
761        let k = neighbor_indices.len();
762        let n_features = training_data.ncols();
763
764        let mut c = Array2::zeros((k, k));
765
766        for i in 0..k {
767            let neighbor_i = neighbor_indices[i];
768            for j in 0..k {
769                let neighbor_j = neighbor_indices[j];
770
771                let mut dot = 0.0;
772                for m in 0..n_features {
773                    let diff_i = x_point[m] - training_data[[neighbor_i, m]];
774                    let diff_j = x_point[m] - training_data[[neighbor_j, m]];
775                    dot += diff_i * diff_j;
776                }
777                c[[i, j]] = dot;
778            }
779        }
780
781        let trace: f64 = (0..k).map(|i| c[[i, i]]).sum();
782        let reg_value = self.reg * trace / k as f64;
783        for i in 0..k {
784            c[[i, i]] += reg_value;
785        }
786
787        let ones = Array1::ones(k);
788        let w = match solve(&c.view(), &ones.view(), None) {
789            Ok(solution) => solution,
790            Err(_) => Array1::from_elem(k, 1.0 / k as f64),
791        };
792
793        let w_sum = w.sum();
794        let w_normalized = if w_sum.abs() > 1e-10 {
795            w / w_sum
796        } else {
797            Array1::from_elem(k, 1.0 / k as f64)
798        };
799
800        Ok(w_normalized)
801    }
802}
803
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Stacks `n_rows` rows of width `n_cols`; row `i` is produced by `row(i)`.
    fn rows_to_array<F>(n_rows: usize, n_cols: usize, row: F) -> Array2<f64>
    where
        F: Fn(usize) -> Vec<f64>,
    {
        let flat: Vec<f64> = (0..n_rows).flat_map(row).collect();
        Array::from_shape_vec((n_rows, n_cols), flat).expect("Failed to create array")
    }

    /// Points on the line `[t, 2t, 3t]` with `t = i / n_points`.
    fn linear_ramp(n_points: usize) -> Array2<f64> {
        rows_to_array(n_points, 3, |i| {
            let t = i as f64 / n_points as f64;
            vec![t, t * 2.0, t * 3.0]
        })
    }

    /// Asserts that every entry of `values` is finite.
    fn assert_finite(values: &Array2<f64>) {
        for v in values {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_lle_basic() {
        let n_points = 20;
        let x = rows_to_array(n_points, 3, |i| {
            let t = 1.5 * std::f64::consts::PI * (1.0 + 2.0 * i as f64 / n_points as f64);
            vec![t * t.cos(), 10.0 * i as f64 / n_points as f64, t * t.sin()]
        });

        let mut model = LLE::new(5, 2);
        let embedding = model.fit_transform(&x).expect("LLE fit_transform failed");

        assert_eq!(embedding.shape(), &[n_points, 2]);
        assert_finite(&embedding);
    }

    #[test]
    fn test_lle_regularization() {
        let x: Array2<f64> = Array::eye(10) * 2.0;

        let outcome = LLE::new(3, 2).with_regularization(0.01).fit_transform(&x);

        assert!(outcome.is_ok());
        let embedding = outcome.expect("LLE fit_transform failed");
        assert_eq!(embedding.shape(), &[10, 2]);
    }

    #[test]
    fn test_lle_modified() {
        let n_points = 20;
        let x = rows_to_array(n_points, 3, |i| {
            let t = i as f64 / n_points as f64 * 2.0 * std::f64::consts::PI;
            vec![t.cos(), t.sin(), t * 0.1]
        });

        let mut model = LLE::new(5, 2).with_method("modified");
        let embedding = model.fit_transform(&x).expect("MLLE fit_transform failed");

        assert_eq!(embedding.shape(), &[n_points, 2]);
        assert_finite(&embedding);
    }

    #[test]
    fn test_lle_hessian() {
        let n_points = 25;
        let x = rows_to_array(n_points, 4, |i| {
            let t = i as f64 / n_points as f64;
            vec![t, t * 2.0, t * 3.0, t * t]
        });

        // HLLE needs k >= d*(d+3)/2 + 1, which is 6 for d = 2
        let mut model = LLE::new(7, 2).with_method("hessian");
        let embedding = model.fit_transform(&x).expect("HLLE fit_transform failed");

        assert_eq!(embedding.shape(), &[n_points, 2]);
        assert_finite(&embedding);
    }

    #[test]
    fn test_lle_invalid_params() {
        let x: Array2<f64> = Array::eye(5);

        // Too many neighbors for 5 samples
        assert!(LLE::new(10, 2).fit(&x).is_err());
        // Too many components for 5 samples
        assert!(LLE::new(2, 10).fit(&x).is_err());
    }

    #[test]
    fn test_lle_reconstruction_error() {
        let mut model = LLE::new(5, 2);
        let _embedding = model
            .fit_transform(&linear_ramp(20))
            .expect("LLE fit_transform failed");

        let reported = model.reconstruction_error();
        assert!(reported.is_some());
        let value = reported.expect("Error should exist");
        assert!(value >= 0.0);
        assert!(value.is_finite());
    }

    #[test]
    fn test_lle_out_of_sample() {
        let mut model = LLE::new(5, 2);
        model.fit(&linear_ramp(20)).expect("LLE fit failed");

        let unseen = Array::from_shape_vec((2, 3), vec![0.25, 0.5, 0.75, 0.75, 1.5, 2.25])
            .expect("Failed to create test array");

        let projected = model.transform(&unseen).expect("LLE transform failed");
        assert_eq!(projected.shape(), &[2, 2]);
        assert_finite(&projected);
    }
}