scirs2_transform/reduction/
lle.rs1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9use scirs2_core::validation::{check_positive, checkshape};
10use scirs2_linalg::{eigh, solve};
11use std::collections::BinaryHeap;
12
13use crate::error::{Result, TransformError};
14
15#[derive(Debug, Clone)]
21pub struct LLE {
22 n_neighbors: usize,
24 n_components: usize,
26 reg: f64,
28 method: String,
30 embedding: Option<Array2<f64>>,
32 weights: Option<Array2<f64>>,
34 training_data: Option<Array2<f64>>,
36}
37
38impl LLE {
39 pub fn new(n_neighbors: usize, ncomponents: usize) -> Self {
45 LLE {
46 n_neighbors,
47 n_components: ncomponents,
48 reg: 1e-3,
49 method: "standard".to_string(),
50 embedding: None,
51 weights: None,
52 training_data: None,
53 }
54 }
55
56 pub fn with_regularization(mut self, reg: f64) -> Self {
58 self.reg = reg;
59 self
60 }
61
62 pub fn with_method(mut self, method: &str) -> Self {
64 self.method = method.to_string();
65 self
66 }
67
68 fn find_neighbors<S>(&self, x: &ArrayBase<S, Ix2>) -> (Array2<usize>, Array2<f64>)
70 where
71 S: Data,
72 S::Elem: Float + NumCast,
73 {
74 let n_samples = x.shape()[0];
75 let mut indices = Array2::zeros((n_samples, self.n_neighbors));
76 let mut distances = Array2::zeros((n_samples, self.n_neighbors));
77
78 for i in 0..n_samples {
79 let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, usize)> = BinaryHeap::new();
81
82 for j in 0..n_samples {
83 if i != j {
84 let mut dist = 0.0;
85 for k in 0..x.shape()[1] {
86 let diff = NumCast::from(x[[i, k]]).unwrap_or(0.0)
87 - NumCast::from(x[[j, k]]).unwrap_or(0.0);
88 dist += diff * diff;
89 }
90 dist = dist.sqrt();
91
92 let dist_fixed = (dist * 1e9) as i64;
93 heap.push((std::cmp::Reverse(dist_fixed), j));
94 }
95 }
96
97 for j in 0..self.n_neighbors {
99 if let Some((std::cmp::Reverse(dist_fixed), idx)) = heap.pop() {
100 indices[[i, j]] = idx;
101 distances[[i, j]] = dist_fixed as f64 / 1e9;
102 }
103 }
104 }
105
106 (indices, distances)
107 }
108
109 fn compute_weights<S>(
111 &self,
112 x: &ArrayBase<S, Ix2>,
113 neighbors: &Array2<usize>,
114 ) -> Result<Array2<f64>>
115 where
116 S: Data,
117 S::Elem: Float + NumCast,
118 {
119 let n_samples = x.shape()[0];
120 let n_features = x.shape()[1];
121 let k = self.n_neighbors;
122
123 let mut weights = Array2::zeros((n_samples, n_samples));
124
125 for i in 0..n_samples {
126 let mut c = Array2::zeros((k, k));
128 let xi = x.index_axis(Axis(0), i);
129
130 for j in 0..k {
132 let neighbor_j = neighbors[[i, j]];
133 let xj = x.index_axis(Axis(0), neighbor_j);
134
135 for l in 0..k {
136 let neighbor_l = neighbors[[i, l]];
137 let xl = x.index_axis(Axis(0), neighbor_l);
138
139 let mut dot = 0.0;
140 for m in 0..n_features {
141 let diff_j = NumCast::from(xi[m] - xj[m]).unwrap_or(0.0);
142 let diff_l = NumCast::from(xi[m] - xl[m]).unwrap_or(0.0);
143 dot += diff_j * diff_l;
144 }
145 c[[j, l]] = dot;
146 }
147 }
148
149 let trace = (0..k).map(|j| c[[j, j]]).sum::<f64>();
151 let reg_value = self.reg * trace / k as f64;
152 for j in 0..k {
153 c[[j, j]] += reg_value;
154 }
155
156 let ones = Array1::ones(k);
158 let w = match solve(&c.view(), &ones.view(), None) {
159 Ok(solution) => solution,
160 Err(_) => {
161 Array1::from_elem(k, 1.0 / k as f64)
163 }
164 };
165
166 let w_sum = w.sum();
168 let w_normalized = if w_sum.abs() > 1e-10 {
169 w / w_sum
170 } else {
171 Array1::from_elem(k, 1.0 / k as f64)
172 };
173
174 for j in 0..k {
176 let neighbor = neighbors[[i, j]];
177 weights[[i, neighbor]] = w_normalized[j];
178 }
179 }
180
181 Ok(weights)
182 }
183
184 fn compute_embedding(&self, weights: &Array2<f64>) -> Result<Array2<f64>> {
186 let n_samples = weights.shape()[0];
187
188 let mut m = Array2::zeros((n_samples, n_samples));
190
191 for i in 0..n_samples {
192 for j in 0..n_samples {
193 let mut sum = 0.0;
194
195 if i == j {
196 sum += 1.0 - 2.0 * weights[[i, j]] + weights.column(j).dot(&weights.column(j));
198 } else {
199 sum += -weights[[i, j]] - weights[[j, i]]
201 + weights.column(i).dot(&weights.column(j));
202 }
203
204 m[[i, j]] = sum;
205 }
206 }
207
208 let (eigenvalues, eigenvectors) = match eigh(&m.view(), None) {
210 Ok(result) => result,
211 Err(e) => return Err(TransformError::LinalgError(e)),
212 };
213
214 let mut indices: Vec<usize> = (0..n_samples).collect();
216 indices.sort_by(|&i, &j| eigenvalues[i].partial_cmp(&eigenvalues[j]).unwrap());
217
218 let mut embedding = Array2::zeros((n_samples, self.n_components));
221 for j in 0..self.n_components {
222 let idx = indices[j + 1]; for i in 0..n_samples {
224 embedding[[i, j]] = eigenvectors[[i, idx]];
225 }
226 }
227
228 Ok(embedding)
229 }
230
231 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
239 where
240 S: Data,
241 S::Elem: Float + NumCast,
242 {
243 let (n_samples, n_features) = x.dim();
244
245 check_positive(self.n_neighbors, "n_neighbors")?;
247 check_positive(self.n_components, "n_components")?;
248 checkshape(x, &[n_samples, n_features], "x")?;
249
250 if n_samples <= self.n_neighbors {
251 return Err(TransformError::InvalidInput(format!(
252 "n_neighbors={} must be < n_samples={}",
253 self.n_neighbors, n_samples
254 )));
255 }
256
257 if self.n_components >= n_samples {
258 return Err(TransformError::InvalidInput(format!(
259 "n_components={} must be < n_samples={}",
260 self.n_components, n_samples
261 )));
262 }
263
264 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
266
267 let (neighbors, distances) = self.find_neighbors(&x_f64.view());
269
270 let weights = self.compute_weights(&x_f64.view(), &neighbors)?;
272
273 let embedding = self.compute_embedding(&weights)?;
275
276 self.embedding = Some(embedding);
277 self.weights = Some(weights);
278 self.training_data = Some(x_f64);
279
280 Ok(())
281 }
282
283 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
291 where
292 S: Data,
293 S::Elem: Float + NumCast,
294 {
295 if self.embedding.is_none() {
296 return Err(TransformError::NotFitted(
297 "LLE model has not been fitted".to_string(),
298 ));
299 }
300
301 let training_data = self
302 .training_data
303 .as_ref()
304 .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
305
306 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
307
308 if self.is_same_data(&x_f64, training_data) {
310 return Ok(self.embedding.as_ref().unwrap().clone());
311 }
312
313 self.transform_new_data(&x_f64)
315 }
316
317 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
325 where
326 S: Data,
327 S::Elem: Float + NumCast,
328 {
329 self.fit(x)?;
330 self.transform(x)
331 }
332
333 pub fn embedding(&self) -> Option<&Array2<f64>> {
335 self.embedding.as_ref()
336 }
337
338 pub fn reconstruction_weights(&self) -> Option<&Array2<f64>> {
340 self.weights.as_ref()
341 }
342
343 fn is_same_data(&self, x: &Array2<f64>, trainingdata: &Array2<f64>) -> bool {
345 if x.dim() != trainingdata.dim() {
346 return false;
347 }
348
349 let (n_samples, n_features) = x.dim();
350 for i in 0..n_samples {
351 for j in 0..n_features {
352 if (x[[i, j]] - trainingdata[[i, j]]).abs() > 1e-10 {
353 return false;
354 }
355 }
356 }
357 true
358 }
359
360 fn transform_new_data(&self, xnew: &Array2<f64>) -> Result<Array2<f64>> {
362 let training_data = self.training_data.as_ref().unwrap();
363 let training_embedding = self.embedding.as_ref().unwrap();
364
365 let (n_new, n_features) = xnew.dim();
366 let (_n_training_, _) = training_data.dim();
367
368 if n_features != training_data.ncols() {
369 return Err(TransformError::InvalidInput(format!(
370 "Input features {} must match training features {}",
371 n_features,
372 training_data.ncols()
373 )));
374 }
375
376 let mut new_embedding = Array2::zeros((n_new, self.n_components));
377
378 for i in 0..n_new {
381 let new_coords =
382 self.compute_new_point_embedding(&xnew.row(i), training_data, training_embedding)?;
383
384 for j in 0..self.n_components {
385 new_embedding[[i, j]] = new_coords[j];
386 }
387 }
388
389 Ok(new_embedding)
390 }
391
392 fn compute_new_point_embedding(
394 &self,
395 x_new: &scirs2_core::ndarray::ArrayView1<f64>,
396 training_data: &Array2<f64>,
397 training_embedding: &Array2<f64>,
398 ) -> Result<Array1<f64>> {
399 let n_training = training_data.nrows();
400 let n_features = training_data.ncols();
401
402 let mut distances: Vec<(f64, usize)> = Vec::new();
404 for j in 0..n_training {
405 let mut dist_sq = 0.0;
406 for k in 0..n_features {
407 let diff = x_new[k] - training_data[[j, k]];
408 dist_sq += diff * diff;
409 }
410 distances.push((dist_sq.sqrt(), j));
411 }
412
413 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
415 let k = self.n_neighbors.min(n_training);
416 let neighbor_indices: Vec<usize> =
417 distances.into_iter().take(k).map(|(_, idx)| idx).collect();
418
419 let weights =
421 self.compute_reconstruction_weights_for_point(x_new, training_data, &neighbor_indices)?;
422
423 let mut new_coords = Array1::zeros(self.n_components);
425 for (i, &neighbor_idx) in neighbor_indices.iter().enumerate() {
426 for dim in 0..self.n_components {
427 new_coords[dim] += weights[i] * training_embedding[[neighbor_idx, dim]];
428 }
429 }
430
431 Ok(new_coords)
432 }
433
434 fn compute_reconstruction_weights_for_point(
436 &self,
437 x_point: &scirs2_core::ndarray::ArrayView1<f64>,
438 training_data: &Array2<f64>,
439 neighbor_indices: &[usize],
440 ) -> Result<Array1<f64>> {
441 let k = neighbor_indices.len();
442 let n_features = training_data.ncols();
443
444 let mut c = Array2::zeros((k, k));
446
447 for i in 0..k {
448 let neighbor_i = neighbor_indices[i];
449 for j in 0..k {
450 let neighbor_j = neighbor_indices[j];
451
452 let mut dot = 0.0;
453 for m in 0..n_features {
454 let diff_i = x_point[m] - training_data[[neighbor_i, m]];
455 let diff_j = x_point[m] - training_data[[neighbor_j, m]];
456 dot += diff_i * diff_j;
457 }
458 c[[i, j]] = dot;
459 }
460 }
461
462 let trace = (0..k).map(|i| c[[i, i]]).sum::<f64>();
464 let reg_value = self.reg * trace / k as f64;
465 for i in 0..k {
466 c[[i, i]] += reg_value;
467 }
468
469 let ones = Array1::ones(k);
471 let w = match solve(&c.view(), &ones.view(), None) {
472 Ok(solution) => solution,
473 Err(_) => {
474 Array1::from_elem(k, 1.0 / k as f64)
476 }
477 };
478
479 let w_sum = w.sum();
481 let w_normalized = if w_sum.abs() > 1e-10 {
482 w / w_sum
483 } else {
484 Array1::from_elem(k, 1.0 / k as f64)
485 };
486
487 Ok(w_normalized)
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Embeds 30 points sampled along a 3-D spiral and checks that the result
    /// has the requested shape and contains only finite values.
    #[test]
    fn test_lle_basic() {
        use std::f64::consts::PI;

        let n_points = 30;
        let data: Vec<f64> = (0..n_points)
            .flat_map(|i| {
                let t = 1.5 * PI * (1.0 + 2.0 * i as f64 / n_points as f64);
                [
                    t * t.cos(),
                    10.0 * i as f64 / n_points as f64,
                    t * t.sin(),
                ]
            })
            .collect();

        let x = Array::from_shape_vec((n_points, 3), data).unwrap();

        let mut lle = LLE::new(10, 2);
        let embedding = lle.fit_transform(&x).unwrap();

        assert_eq!(embedding.shape(), &[n_points, 2]);

        for val in embedding.iter() {
            assert!(val.is_finite());
        }
    }

    /// Fits a scaled identity matrix with a custom regularization factor and
    /// checks that fitting succeeds with the expected output shape.
    #[test]
    fn test_lle_regularization() {
        let x: Array2<f64> = Array::eye(10) * 2.0;

        let mut lle = LLE::new(3, 2).with_regularization(0.01);
        let result = lle.fit_transform(&x);

        assert!(result.is_ok());
        let embedding = result.unwrap();
        assert_eq!(embedding.shape(), &[10, 2]);
    }

    /// Fitting must fail when either `n_neighbors` or `n_components` is larger
    /// than the number of samples.
    #[test]
    fn test_lle_invalid_params() {
        let x: Array2<f64> = Array::eye(5);

        let mut too_many_neighbors = LLE::new(10, 2);
        assert!(too_many_neighbors.fit(&x).is_err());

        let mut too_many_components = LLE::new(2, 10);
        assert!(too_many_components.fit(&x).is_err());
    }
}