1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
20use scirs2_core::numeric::{Float, NumCast};
21use scirs2_core::validation::{check_positive, checkshape};
22use scirs2_linalg::{eigh, solve, svd};
23use std::collections::BinaryHeap;
24
25use crate::error::{Result, TransformError};
26
/// Weight-construction variant used by LLE.
///
/// All variants are unit-like, so the enum is `Copy`; `Eq` follows from the
/// derived `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LLEMethod {
    /// Classic LLE: constrained least-squares reconstruction weights.
    Standard,
    /// Modified LLE (MLLE): weights derived from several local SVD vectors.
    Modified,
    /// Hessian LLE (HLLE): weights from a local Hessian estimator.
    Hessian,
}
37
/// Locally Linear Embedding (LLE) for non-linear dimensionality reduction.
///
/// Configure with the `with_*` builder methods, then call `fit` or
/// `fit_transform`. Fitted state lives in the `Option` fields, which are
/// `None` until `fit` succeeds.
#[derive(Debug, Clone)]
pub struct LLE {
    /// Number of nearest neighbours used for each local reconstruction.
    n_neighbors: usize,
    /// Dimensionality of the output embedding.
    n_components: usize,
    /// Regularization strength for the local Gram systems (default 1e-3).
    reg: f64,
    /// Which LLE variant to run (standard / modified / hessian).
    method: LLEMethod,
    /// Learned embedding, `(n_samples, n_components)`; set by `fit`.
    embedding: Option<Array2<f64>>,
    /// Learned reconstruction weights, `(n_samples, n_samples)`; set by `fit`.
    weights: Option<Array2<f64>>,
    /// f64 copy of the training data, kept for out-of-sample transforms.
    training_data: Option<Array2<f64>>,
    /// Mean squared reconstruction error of the embedding; set by `fit`.
    reconstruction_error: Option<f64>,
}
74
75impl LLE {
76 pub fn new(n_neighbors: usize, n_components: usize) -> Self {
82 LLE {
83 n_neighbors,
84 n_components,
85 reg: 1e-3,
86 method: LLEMethod::Standard,
87 embedding: None,
88 weights: None,
89 training_data: None,
90 reconstruction_error: None,
91 }
92 }
93
94 pub fn with_regularization(mut self, reg: f64) -> Self {
96 self.reg = reg;
97 self
98 }
99
100 pub fn with_method(mut self, method: &str) -> Self {
102 self.method = match method {
103 "modified" | "mlle" => LLEMethod::Modified,
104 "hessian" | "hlle" => LLEMethod::Hessian,
105 _ => LLEMethod::Standard,
106 };
107 self
108 }
109
110 pub fn with_method_type(mut self, method: LLEMethod) -> Self {
112 self.method = method;
113 self
114 }
115
116 fn find_neighbors<S>(&self, x: &ArrayBase<S, Ix2>) -> (Array2<usize>, Array2<f64>)
118 where
119 S: Data,
120 S::Elem: Float + NumCast,
121 {
122 let n_samples = x.shape()[0];
123 let mut indices = Array2::zeros((n_samples, self.n_neighbors));
124 let mut distances = Array2::zeros((n_samples, self.n_neighbors));
125
126 for i in 0..n_samples {
127 let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();
128
129 for j in 0..n_samples {
130 if i != j {
131 let mut dist = 0.0;
132 for k in 0..x.shape()[1] {
133 let diff: f64 = NumCast::from(x[[i, k]]).unwrap_or(0.0)
134 - NumCast::from(x[[j, k]]).unwrap_or(0.0);
135 dist += diff * diff;
136 }
137 dist = dist.sqrt();
138
139 let dist_fixed = (dist * 1e9) as i64;
140 heap.push((std::cmp::Reverse(dist_fixed), j));
141 }
142 }
143
144 for j in 0..self.n_neighbors {
145 if let Some((std::cmp::Reverse(dist_fixed), idx)) = heap.pop() {
146 indices[[i, j]] = idx;
147 distances[[i, j]] = dist_fixed as f64 / 1e9;
148 }
149 }
150 }
151
152 (indices, distances)
153 }
154
155 fn compute_weights<S>(
157 &self,
158 x: &ArrayBase<S, Ix2>,
159 neighbors: &Array2<usize>,
160 ) -> Result<Array2<f64>>
161 where
162 S: Data,
163 S::Elem: Float + NumCast,
164 {
165 let n_samples = x.shape()[0];
166 let n_features = x.shape()[1];
167 let k = self.n_neighbors;
168
169 let mut weights = Array2::zeros((n_samples, n_samples));
170
171 for i in 0..n_samples {
172 let mut c = Array2::zeros((k, k));
174 let xi = x.index_axis(Axis(0), i);
175
176 for j in 0..k {
177 let neighbor_j = neighbors[[i, j]];
178 let xj = x.index_axis(Axis(0), neighbor_j);
179
180 for l in 0..k {
181 let neighbor_l = neighbors[[i, l]];
182 let xl = x.index_axis(Axis(0), neighbor_l);
183
184 let mut dot = 0.0;
185 for m in 0..n_features {
186 let diff_j: f64 = NumCast::from(xi[m]).unwrap_or(0.0)
187 - NumCast::from(xj[m]).unwrap_or(0.0);
188 let diff_l: f64 = NumCast::from(xi[m]).unwrap_or(0.0)
189 - NumCast::from(xl[m]).unwrap_or(0.0);
190 dot += diff_j * diff_l;
191 }
192 c[[j, l]] = dot;
193 }
194 }
195
196 let trace: f64 = (0..k).map(|j| c[[j, j]]).sum();
198 let reg_value = self.reg * trace / k as f64;
199 for j in 0..k {
200 c[[j, j]] += reg_value;
201 }
202
203 let ones = Array1::ones(k);
205 let w = match solve(&c.view(), &ones.view(), None) {
206 Ok(solution) => solution,
207 Err(_) => Array1::from_elem(k, 1.0 / k as f64),
208 };
209
210 let w_sum = w.sum();
212 let w_normalized = if w_sum.abs() > 1e-10 {
213 w / w_sum
214 } else {
215 Array1::from_elem(k, 1.0 / k as f64)
216 };
217
218 for j in 0..k {
219 let neighbor = neighbors[[i, j]];
220 weights[[i, neighbor]] = w_normalized[j];
221 }
222 }
223
224 Ok(weights)
225 }
226
    /// Modified-LLE (MLLE) weights.
    ///
    /// Derives neighbour weights from the local SVD — summing squared
    /// entries of trailing left-singular vectors and normalizing — instead
    /// of the single constrained least-squares solve of standard LLE.
    ///
    /// NOTE(review): this looks like a simplified heuristic inspired by
    /// MLLE (Zhang & Wang) rather than the full multi-weight formulation —
    /// confirm against the reference if exact MLLE behavior is required.
    fn compute_weights_modified<S>(
        &self,
        x: &ArrayBase<S, Ix2>,
        neighbors: &Array2<usize>,
    ) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let n_samples = x.shape()[0];
        let n_features = x.shape()[1];
        let k = self.n_neighbors;
        let d = self.n_components;

        // Number of extra weight vectors available; with none, the method
        // degenerates to standard LLE.
        let n_extra = k.saturating_sub(d + 1).min(d);
        if n_extra == 0 {
            return self.compute_weights(x, neighbors);
        }

        let mut weights = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            // Local coordinates: neighbours centered on x_i.
            let mut local_data = Array2::zeros((k, n_features));
            let xi = x.index_axis(Axis(0), i);

            for j in 0..k {
                let neighbor_j = neighbors[[i, j]];
                for m in 0..n_features {
                    let val_i: f64 = NumCast::from(xi[m]).unwrap_or(0.0);
                    let val_j: f64 = NumCast::from(x[[neighbor_j, m]]).unwrap_or(0.0);
                    local_data[[j, m]] = val_j - val_i;
                }
            }

            let svd_result = svd::<f64>(&local_data.view(), true, None);
            let (u, _s, _vt) = match svd_result {
                Ok(result) => result,
                Err(_) => {
                    // Degenerate neighbourhood: fall back to uniform weights.
                    let ones = Array1::from_elem(k, 1.0 / k as f64);
                    for j in 0..k {
                        weights[[i, neighbors[[i, j]]]] = ones[j];
                    }
                    continue;
                }
            };

            // Columns of U beyond the first d (clamped to what the SVD
            // actually produced) supply the "extra" directions.
            let start_col = d.min(u.shape()[1].saturating_sub(1));
            let n_weight_vecs = (u.shape()[1] - start_col).min(n_extra + 1).max(1);

            // Accumulate squared entries across the selected columns.
            let mut w = Array1::zeros(k);
            for j in 0..n_weight_vecs {
                let col_idx = start_col + j;
                if col_idx < u.shape()[1] {
                    for r in 0..k {
                        w[r] += u[[r, col_idx]] * u[[r, col_idx]];
                    }
                }
            }

            // Normalize to sum to one, guarding against an all-zero vector.
            let w_sum = w.sum();
            if w_sum > 1e-10 {
                w.mapv_inplace(|v| v / w_sum);
            } else {
                w = Array1::from_elem(k, 1.0 / k as f64);
            }

            for j in 0..k {
                weights[[i, neighbors[[i, j]]]] = w[j];
            }
        }

        Ok(weights)
    }
314
315 fn compute_weights_hessian<S>(
320 &self,
321 x: &ArrayBase<S, Ix2>,
322 neighbors: &Array2<usize>,
323 ) -> Result<Array2<f64>>
324 where
325 S: Data,
326 S::Elem: Float + NumCast,
327 {
328 let n_samples = x.shape()[0];
329 let n_features = x.shape()[1];
330 let k = self.n_neighbors;
331 let d = self.n_components;
332
333 let min_k = d * (d + 3) / 2 + 1;
335 if k < min_k {
336 return Err(TransformError::InvalidInput(format!(
337 "Hessian LLE requires n_neighbors >= {} for n_components={}, got {}",
338 min_k, d, k
339 )));
340 }
341
342 let dp = d * (d + 1) / 2;
344
345 let mut weights = Array2::zeros((n_samples, n_samples));
346
347 for i in 0..n_samples {
348 let mut local_data = Array2::zeros((k, n_features));
350 let xi = x.index_axis(Axis(0), i);
351
352 for j in 0..k {
353 let neighbor_j = neighbors[[i, j]];
354 for m in 0..n_features {
355 let val_i: f64 = NumCast::from(xi[m]).unwrap_or(0.0);
356 let val_j: f64 = NumCast::from(x[[neighbor_j, m]]).unwrap_or(0.0);
357 local_data[[j, m]] = val_j - val_i;
358 }
359 }
360
361 let (u, _s, _vt) = match svd::<f64>(&local_data.view(), true, None) {
363 Ok(result) => result,
364 Err(_) => {
365 let ones = Array1::from_elem(k, 1.0 / k as f64);
366 for j in 0..k {
367 weights[[i, neighbors[[i, j]]]] = ones[j];
368 }
369 continue;
370 }
371 };
372
373 let mut tangent = Array2::zeros((k, d));
375 let max_d = d.min(u.shape()[1]);
376 for j in 0..max_d {
377 for r in 0..k {
378 tangent[[r, j]] = u[[r, j]];
379 }
380 }
381
382 let n_cols = 1 + d + dp;
385 let mut h_mat = Array2::zeros((k, n_cols));
386
387 for r in 0..k {
388 h_mat[[r, 0]] = 1.0; for j in 0..max_d {
390 h_mat[[r, 1 + j]] = tangent[[r, j]]; }
392
393 let mut col = 1 + d;
395 for j in 0..max_d {
396 for l in j..max_d {
397 h_mat[[r, col]] = tangent[[r, j]] * tangent[[r, l]];
398 col += 1;
399 }
400 }
401 }
402
403 let (q, _r) = self.qr_decomposition(&h_mat)?;
406
407 let mut w = Array1::zeros(k);
409 let start_col = n_cols.min(q.shape()[1]);
410 let mut count = 0;
411 for col in start_col..q.shape()[1] {
412 for r in 0..k {
413 w[r] += q[[r, col]] * q[[r, col]];
414 }
415 count += 1;
416 }
417
418 if count == 0 {
419 w = Array1::from_elem(k, 1.0 / k as f64);
421 } else {
422 let w_sum = w.sum();
423 if w_sum > 1e-10 {
424 w.mapv_inplace(|v| v / w_sum);
425 } else {
426 w = Array1::from_elem(k, 1.0 / k as f64);
427 }
428 }
429
430 for j in 0..k {
431 weights[[i, neighbors[[i, j]]]] = w[j];
432 }
433 }
434
435 Ok(weights)
436 }
437
    /// Thin QR factorization of `a` (m × n) via modified Gram–Schmidt:
    /// returns `(q, r)` with `q` m × n (orthonormal columns) and `r` n × n
    /// upper triangular such that a ≈ q·r.
    ///
    /// Rank deficiency is tolerated: a column whose norm falls below 1e-14
    /// is left unnormalized (near zero) with `r[[j, j]]` set to that tiny
    /// norm, so callers should treat such columns as absent from the basis.
    fn qr_decomposition(&self, a: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
        let (m, n) = a.dim();
        let mut q = a.clone();
        let mut r = Array2::zeros((n, n));

        for j in 0..n {
            // Normalize column j (already orthogonal to columns < j from
            // earlier iterations).
            let mut norm = 0.0;
            for i in 0..m {
                norm += q[[i, j]] * q[[i, j]];
            }
            norm = norm.sqrt();

            r[[j, j]] = norm;
            if norm > 1e-14 {
                for i in 0..m {
                    q[[i, j]] /= norm;
                }
            }

            // Project column j out of every later column (the "modified"
            // Gram–Schmidt update, numerically preferable to classical GS).
            for k in (j + 1)..n {
                let mut dot = 0.0;
                for i in 0..m {
                    dot += q[[i, j]] * q[[i, k]];
                }
                r[[j, k]] = dot;
                for i in 0..m {
                    q[[i, k]] -= dot * q[[i, j]];
                }
            }
        }

        Ok((q, r))
    }
474
475 fn compute_embedding(&self, weights: &Array2<f64>) -> Result<Array2<f64>> {
477 let n_samples = weights.shape()[0];
478
479 let mut m = Array2::zeros((n_samples, n_samples));
481
482 for i in 0..n_samples {
483 for j in 0..n_samples {
484 let mut sum = 0.0;
485
486 if i == j {
487 sum += 1.0 - 2.0 * weights[[i, j]] + weights.column(j).dot(&weights.column(j));
488 } else {
489 sum += -weights[[i, j]] - weights[[j, i]]
490 + weights.column(i).dot(&weights.column(j));
491 }
492
493 m[[i, j]] = sum;
494 }
495 }
496
497 let (eigenvalues, eigenvectors) =
499 eigh(&m.view(), None).map_err(|e| TransformError::LinalgError(e))?;
500
501 let mut indices: Vec<usize> = (0..n_samples).collect();
503 indices.sort_by(|&i, &j| {
504 eigenvalues[i]
505 .partial_cmp(&eigenvalues[j])
506 .unwrap_or(std::cmp::Ordering::Equal)
507 });
508
509 let mut embedding = Array2::zeros((n_samples, self.n_components));
512 for j in 0..self.n_components {
513 let idx = indices[j + 1]; for i in 0..n_samples {
515 embedding[[i, j]] = eigenvectors[[i, idx]];
516 }
517 }
518
519 let recon_error: f64 = (0..self.n_components)
521 .map(|j| {
522 let idx = indices[j + 1];
523 eigenvalues[idx].max(0.0)
524 })
525 .sum();
526
527 let _ = recon_error;
530
531 Ok(embedding)
532 }
533
534 fn compute_reconstruction_error(&self, weights: &Array2<f64>, embedding: &Array2<f64>) -> f64 {
536 let n_samples = weights.shape()[0];
537 let n_components = embedding.shape()[1];
538
539 let mut total_error = 0.0;
540
541 for i in 0..n_samples {
542 for d in 0..n_components {
543 let mut reconstructed = 0.0;
544 for j in 0..n_samples {
545 reconstructed += weights[[i, j]] * embedding[[j, d]];
546 }
547 let diff = embedding[[i, d]] - reconstructed;
548 total_error += diff * diff;
549 }
550 }
551
552 total_error / n_samples as f64
553 }
554
555 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
557 where
558 S: Data,
559 S::Elem: Float + NumCast,
560 {
561 let (n_samples, n_features) = x.dim();
562
563 check_positive(self.n_neighbors, "n_neighbors")?;
564 check_positive(self.n_components, "n_components")?;
565 checkshape(x, &[n_samples, n_features], "x")?;
566
567 if n_samples <= self.n_neighbors {
568 return Err(TransformError::InvalidInput(format!(
569 "n_neighbors={} must be < n_samples={}",
570 self.n_neighbors, n_samples
571 )));
572 }
573
574 if self.n_components >= n_samples {
575 return Err(TransformError::InvalidInput(format!(
576 "n_components={} must be < n_samples={}",
577 self.n_components, n_samples
578 )));
579 }
580
581 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
582
583 let (neighbors, _distances) = self.find_neighbors(&x_f64.view());
585
586 let weights = match &self.method {
588 LLEMethod::Standard => self.compute_weights(&x_f64.view(), &neighbors)?,
589 LLEMethod::Modified => self.compute_weights_modified(&x_f64.view(), &neighbors)?,
590 LLEMethod::Hessian => self.compute_weights_hessian(&x_f64.view(), &neighbors)?,
591 };
592
593 let embedding = self.compute_embedding(&weights)?;
595
596 let recon_error = self.compute_reconstruction_error(&weights, &embedding);
598
599 self.embedding = Some(embedding);
600 self.weights = Some(weights);
601 self.training_data = Some(x_f64);
602 self.reconstruction_error = Some(recon_error);
603
604 Ok(())
605 }
606
607 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
609 where
610 S: Data,
611 S::Elem: Float + NumCast,
612 {
613 if self.embedding.is_none() {
614 return Err(TransformError::NotFitted(
615 "LLE model has not been fitted".to_string(),
616 ));
617 }
618
619 let training_data = self
620 .training_data
621 .as_ref()
622 .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
623
624 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
625
626 if self.is_same_data(&x_f64, training_data) {
627 return self
628 .embedding
629 .as_ref()
630 .cloned()
631 .ok_or_else(|| TransformError::NotFitted("Embedding not available".to_string()));
632 }
633
634 self.transform_new_data(&x_f64)
635 }
636
637 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
639 where
640 S: Data,
641 S::Elem: Float + NumCast,
642 {
643 self.fit(x)?;
644 self.transform(x)
645 }
646
    /// The fitted embedding, `(n_samples, n_components)`, or `None` before
    /// a successful `fit`.
    pub fn embedding(&self) -> Option<&Array2<f64>> {
        self.embedding.as_ref()
    }
651
    /// The fitted reconstruction-weight matrix, `(n_samples, n_samples)`,
    /// or `None` before a successful `fit`.
    pub fn reconstruction_weights(&self) -> Option<&Array2<f64>> {
        self.weights.as_ref()
    }
656
    /// Mean squared reconstruction error of the embedding (as computed by
    /// `compute_reconstruction_error`), or `None` before a successful `fit`.
    pub fn reconstruction_error(&self) -> Option<f64> {
        self.reconstruction_error
    }
661
662 fn is_same_data(&self, x: &Array2<f64>, training_data: &Array2<f64>) -> bool {
664 if x.dim() != training_data.dim() {
665 return false;
666 }
667 let (n_samples, n_features) = x.dim();
668 for i in 0..n_samples {
669 for j in 0..n_features {
670 if (x[[i, j]] - training_data[[i, j]]).abs() > 1e-10 {
671 return false;
672 }
673 }
674 }
675 true
676 }
677
678 fn transform_new_data(&self, x_new: &Array2<f64>) -> Result<Array2<f64>> {
680 let training_data = self
681 .training_data
682 .as_ref()
683 .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
684 let training_embedding = self
685 .embedding
686 .as_ref()
687 .ok_or_else(|| TransformError::NotFitted("Embedding not available".to_string()))?;
688
689 let (n_new, n_features) = x_new.dim();
690
691 if n_features != training_data.ncols() {
692 return Err(TransformError::InvalidInput(format!(
693 "Input features {} must match training features {}",
694 n_features,
695 training_data.ncols()
696 )));
697 }
698
699 let mut new_embedding = Array2::zeros((n_new, self.n_components));
700
701 for i in 0..n_new {
702 let new_coords =
703 self.compute_new_point_embedding(&x_new.row(i), training_data, training_embedding)?;
704
705 for j in 0..self.n_components {
706 new_embedding[[i, j]] = new_coords[j];
707 }
708 }
709
710 Ok(new_embedding)
711 }
712
713 fn compute_new_point_embedding(
715 &self,
716 x_new: &scirs2_core::ndarray::ArrayView1<f64>,
717 training_data: &Array2<f64>,
718 training_embedding: &Array2<f64>,
719 ) -> Result<Array1<f64>> {
720 let n_training = training_data.nrows();
721 let n_features = training_data.ncols();
722
723 let mut distances: Vec<(f64, usize)> = Vec::with_capacity(n_training);
725 for j in 0..n_training {
726 let mut dist_sq = 0.0;
727 for k in 0..n_features {
728 let diff = x_new[k] - training_data[[j, k]];
729 dist_sq += diff * diff;
730 }
731 distances.push((dist_sq.sqrt(), j));
732 }
733
734 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
735 let k = self.n_neighbors.min(n_training);
736 let neighbor_indices: Vec<usize> =
737 distances.into_iter().take(k).map(|(_, idx)| idx).collect();
738
739 let weights =
741 self.compute_reconstruction_weights_for_point(x_new, training_data, &neighbor_indices)?;
742
743 let mut new_coords = Array1::zeros(self.n_components);
745 for (i, &neighbor_idx) in neighbor_indices.iter().enumerate() {
746 for dim in 0..self.n_components {
747 new_coords[dim] += weights[i] * training_embedding[[neighbor_idx, dim]];
748 }
749 }
750
751 Ok(new_coords)
752 }
753
754 fn compute_reconstruction_weights_for_point(
756 &self,
757 x_point: &scirs2_core::ndarray::ArrayView1<f64>,
758 training_data: &Array2<f64>,
759 neighbor_indices: &[usize],
760 ) -> Result<Array1<f64>> {
761 let k = neighbor_indices.len();
762 let n_features = training_data.ncols();
763
764 let mut c = Array2::zeros((k, k));
765
766 for i in 0..k {
767 let neighbor_i = neighbor_indices[i];
768 for j in 0..k {
769 let neighbor_j = neighbor_indices[j];
770
771 let mut dot = 0.0;
772 for m in 0..n_features {
773 let diff_i = x_point[m] - training_data[[neighbor_i, m]];
774 let diff_j = x_point[m] - training_data[[neighbor_j, m]];
775 dot += diff_i * diff_j;
776 }
777 c[[i, j]] = dot;
778 }
779 }
780
781 let trace: f64 = (0..k).map(|i| c[[i, i]]).sum();
782 let reg_value = self.reg * trace / k as f64;
783 for i in 0..k {
784 c[[i, i]] += reg_value;
785 }
786
787 let ones = Array1::ones(k);
788 let w = match solve(&c.view(), &ones.view(), None) {
789 Ok(solution) => solution,
790 Err(_) => Array1::from_elem(k, 1.0 / k as f64),
791 };
792
793 let w_sum = w.sum();
794 let w_normalized = if w_sum.abs() > 1e-10 {
795 w / w_sum
796 } else {
797 Array1::from_elem(k, 1.0 / k as f64)
798 };
799
800 Ok(w_normalized)
801 }
802}
803
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Swiss-roll-like curve: standard LLE yields a finite 2-D embedding.
    #[test]
    fn test_lle_basic() {
        let n_points = 20;
        let data: Vec<f64> = (0..n_points)
            .flat_map(|i| {
                let t = 1.5 * std::f64::consts::PI * (1.0 + 2.0 * i as f64 / n_points as f64);
                [t * t.cos(), 10.0 * i as f64 / n_points as f64, t * t.sin()]
            })
            .collect();

        let x = Array::from_shape_vec((n_points, 3), data).expect("Failed to create array");

        let mut lle = LLE::new(5, 2);
        let embedding = lle.fit_transform(&x).expect("LLE fit_transform failed");

        assert_eq!(embedding.shape(), &[n_points, 2]);
        assert!(embedding.iter().all(|v| v.is_finite()));
    }

    /// Custom regularization on degenerate (scaled-identity) input still fits.
    #[test]
    fn test_lle_regularization() {
        let x: Array2<f64> = Array::eye(10) * 2.0;

        let mut lle = LLE::new(3, 2).with_regularization(0.01);
        let embedding = lle.fit_transform(&x).expect("LLE fit_transform failed");
        assert_eq!(embedding.shape(), &[10, 2]);
    }

    /// The "modified" string selector runs MLLE to a finite embedding.
    #[test]
    fn test_lle_modified() {
        let n_points = 20;
        let data: Vec<f64> = (0..n_points)
            .flat_map(|i| {
                let t = i as f64 / n_points as f64 * 2.0 * std::f64::consts::PI;
                [t.cos(), t.sin(), t * 0.1]
            })
            .collect();

        let x = Array::from_shape_vec((n_points, 3), data).expect("Failed to create array");

        let mut lle = LLE::new(5, 2).with_method("modified");
        let embedding = lle.fit_transform(&x).expect("MLLE fit_transform failed");

        assert_eq!(embedding.shape(), &[n_points, 2]);
        assert!(embedding.iter().all(|v| v.is_finite()));
    }

    /// The "hessian" selector with enough neighbours produces a finite embedding.
    #[test]
    fn test_lle_hessian() {
        let n_points = 25;
        let data: Vec<f64> = (0..n_points)
            .flat_map(|i| {
                let t = i as f64 / n_points as f64;
                [t, t * 2.0, t * 3.0, t * t]
            })
            .collect();

        let x = Array::from_shape_vec((n_points, 4), data).expect("Failed to create array");

        let mut lle = LLE::new(7, 2).with_method("hessian");
        let embedding = lle.fit_transform(&x).expect("HLLE fit_transform failed");

        assert_eq!(embedding.shape(), &[n_points, 2]);
        assert!(embedding.iter().all(|v| v.is_finite()));
    }

    /// n_neighbors >= n_samples and n_components >= n_samples are both rejected.
    #[test]
    fn test_lle_invalid_params() {
        let x: Array2<f64> = Array::eye(5);

        assert!(LLE::new(10, 2).fit(&x).is_err());
        assert!(LLE::new(2, 10).fit(&x).is_err());
    }

    /// After fitting, a finite, non-negative reconstruction error is exposed.
    #[test]
    fn test_lle_reconstruction_error() {
        let n_points = 20;
        let data: Vec<f64> = (0..n_points)
            .flat_map(|i| {
                let t = i as f64 / n_points as f64;
                [t, t * 2.0, t * 3.0]
            })
            .collect();

        let x = Array::from_shape_vec((n_points, 3), data).expect("Failed to create array");

        let mut lle = LLE::new(5, 2);
        lle.fit_transform(&x).expect("LLE fit_transform failed");

        let error_val = lle.reconstruction_error().expect("Error should exist");
        assert!(error_val >= 0.0);
        assert!(error_val.is_finite());
    }

    /// Points unseen during fit are embedded via the out-of-sample path.
    #[test]
    fn test_lle_out_of_sample() {
        let n_points = 20;
        let data: Vec<f64> = (0..n_points)
            .flat_map(|i| {
                let t = i as f64 / n_points as f64;
                [t, t * 2.0, t * 3.0]
            })
            .collect();

        let x = Array::from_shape_vec((n_points, 3), data).expect("Failed to create array");

        let mut lle = LLE::new(5, 2);
        lle.fit(&x).expect("LLE fit failed");

        let x_new = Array::from_shape_vec((2, 3), vec![0.25, 0.5, 0.75, 0.75, 1.5, 2.25])
            .expect("Failed to create test array");

        let new_embedding = lle.transform(&x_new).expect("LLE transform failed");
        assert_eq!(new_embedding.shape(), &[2, 2]);
        assert!(new_embedding.iter().all(|v| v.is_finite()));
    }
}