1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
11use scirs2_core::numeric::{Float, NumCast};
12use scirs2_core::parallel_ops::*;
13use scirs2_core::random::Normal;
14use scirs2_core::random::RandomExt;
15
16use crate::error::{Result, TransformError};
17use crate::reduction::PCA;
18
/// Floor used to keep probabilities and divisors strictly positive, avoiding
/// log(0) and division-by-zero in affinity and gradient computations.
const MACHINE_EPSILON: f64 = 1e-14;
/// Convergence tolerance for the perplexity binary search and for
/// small-magnitude comparisons (e.g. near-zero row norms).
const EPSILON: f64 = 1e-7;
22
/// Spatial partitioning tree used by the Barnes-Hut gradient approximation:
/// a quadtree for 2-D embeddings or an octree for 3-D embeddings.
#[derive(Debug, Clone)]
enum SpatialTree {
    /// 2-D embedding case.
    QuadTree(QuadTreeNode),
    /// 3-D embedding case.
    OctTree(OctTreeNode),
}
29
/// Node of a Barnes-Hut quadtree over a 2-D embedding.
#[derive(Debug, Clone)]
struct QuadTreeNode {
    // Axis-aligned bounding box of this cell.
    x_min: f64,
    x_max: f64,
    y_min: f64,
    y_max: f64,
    /// Mass-weighted centroid of every point in this subtree;
    /// `None` when the subtree is empty.
    center_of_mass: Option<Array1<f64>>,
    /// Number of points in this subtree (each point has unit mass).
    total_mass: f64,
    /// Indices (rows of the embedding) held directly by this node; cleared
    /// when the node is subdivided.
    point_indices: Vec<usize>,
    /// The four sub-quadrants; `None` for leaves.
    children: Option<[Box<QuadTreeNode>; 4]>,
    /// Whether this node currently has no children.
    is_leaf: bool,
}
49
/// Node of a Barnes-Hut octree over a 3-D embedding.
#[derive(Debug, Clone)]
struct OctTreeNode {
    // Axis-aligned bounding box of this cell.
    x_min: f64,
    x_max: f64,
    y_min: f64,
    y_max: f64,
    z_min: f64,
    z_max: f64,
    /// Mass-weighted centroid of every point in this subtree;
    /// `None` when the subtree is empty.
    center_of_mass: Option<Array1<f64>>,
    /// Number of points in this subtree (each point has unit mass).
    total_mass: f64,
    /// Indices (rows of the embedding) held directly by this node; cleared
    /// when the node is subdivided.
    point_indices: Vec<usize>,
    /// The eight sub-octants; `None` for leaves.
    children: Option<[Box<OctTreeNode>; 8]>,
    /// Whether this node currently has no children.
    is_leaf: bool,
}
71
72impl SpatialTree {
73 fn new_quadtree(embedding: &Array2<f64>) -> Result<Self> {
75 let n_samples = embedding.shape()[0];
76
77 if embedding.shape()[1] != 2 {
78 return Err(TransformError::InvalidInput(
79 "QuadTree requires 2D _embedding".to_string(),
80 ));
81 }
82
83 let mut x_min = f64::INFINITY;
85 let mut x_max = f64::NEG_INFINITY;
86 let mut y_min = f64::INFINITY;
87 let mut y_max = f64::NEG_INFINITY;
88
89 for i in 0..n_samples {
90 let x = embedding[[i, 0]];
91 let y = embedding[[i, 1]];
92 x_min = x_min.min(x);
93 x_max = x_max.max(x);
94 y_min = y_min.min(y);
95 y_max = y_max.max(y);
96 }
97
98 let margin = 0.01 * ((x_max - x_min) + (y_max - y_min));
100 x_min -= margin;
101 x_max += margin;
102 y_min -= margin;
103 y_max += margin;
104
105 let point_indices: Vec<usize> = (0..n_samples).collect();
107
108 let mut root = QuadTreeNode {
110 x_min,
111 x_max,
112 y_min,
113 y_max,
114 center_of_mass: None,
115 total_mass: 0.0,
116 point_indices,
117 children: None,
118 is_leaf: true,
119 };
120
121 root.build_tree(embedding)?;
123
124 Ok(SpatialTree::QuadTree(root))
125 }
126
127 fn new_octree(embedding: &Array2<f64>) -> Result<Self> {
129 let n_samples = embedding.shape()[0];
130
131 if embedding.shape()[1] != 3 {
132 return Err(TransformError::InvalidInput(
133 "OctTree requires 3D _embedding".to_string(),
134 ));
135 }
136
137 let mut x_min = f64::INFINITY;
139 let mut x_max = f64::NEG_INFINITY;
140 let mut y_min = f64::INFINITY;
141 let mut y_max = f64::NEG_INFINITY;
142 let mut z_min = f64::INFINITY;
143 let mut z_max = f64::NEG_INFINITY;
144
145 for i in 0..n_samples {
146 let x = embedding[[i, 0]];
147 let y = embedding[[i, 1]];
148 let z = embedding[[i, 2]];
149 x_min = x_min.min(x);
150 x_max = x_max.max(x);
151 y_min = y_min.min(y);
152 y_max = y_max.max(y);
153 z_min = z_min.min(z);
154 z_max = z_max.max(z);
155 }
156
157 let margin = 0.01 * ((x_max - x_min) + (y_max - y_min) + (z_max - z_min));
159 x_min -= margin;
160 x_max += margin;
161 y_min -= margin;
162 y_max += margin;
163 z_min -= margin;
164 z_max += margin;
165
166 let point_indices: Vec<usize> = (0..n_samples).collect();
168
169 let mut root = OctTreeNode {
171 x_min,
172 x_max,
173 y_min,
174 y_max,
175 z_min,
176 z_max,
177 center_of_mass: None,
178 total_mass: 0.0,
179 point_indices,
180 children: None,
181 is_leaf: true,
182 };
183
184 root.build_tree(embedding)?;
186
187 Ok(SpatialTree::OctTree(root))
188 }
189
190 #[allow(clippy::too_many_arguments)]
192 fn compute_forces(
193 &self,
194 point: &Array1<f64>,
195 point_idx: usize,
196 angle: f64,
197 degrees_of_freedom: f64,
198 ) -> Result<(Array1<f64>, f64)> {
199 match self {
200 SpatialTree::QuadTree(root) => {
201 root.compute_forces_quad(point, point_idx, angle, degrees_of_freedom)
202 }
203 SpatialTree::OctTree(root) => {
204 root.compute_forces_oct(point, point_idx, angle, degrees_of_freedom)
205 }
206 }
207 }
208}
209
210impl QuadTreeNode {
211 fn build_tree(&mut self, embedding: &Array2<f64>) -> Result<()> {
213 if self.point_indices.len() <= 1 {
214 self.update_center_of_mass(embedding)?;
216 return Ok(());
217 }
218
219 let x_mid = (self.x_min + self.x_max) / 2.0;
221 let y_mid = (self.y_min + self.y_max) / 2.0;
222
223 let mut quadrants: [Vec<usize>; 4] = [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
224
225 for &idx in &self.point_indices {
227 let x = embedding[[idx, 0]];
228 let y = embedding[[idx, 1]];
229
230 let quadrant = match (x >= x_mid, y >= y_mid) {
231 (false, false) => 0, (true, false) => 1, (false, true) => 2, (true, true) => 3, };
236
237 quadrants[quadrant].push(idx);
238 }
239
240 let mut children = [
242 Box::new(QuadTreeNode {
243 x_min: self.x_min,
244 x_max: x_mid,
245 y_min: self.y_min,
246 y_max: y_mid,
247 center_of_mass: None,
248 total_mass: 0.0,
249 point_indices: quadrants[0].clone(),
250 children: None,
251 is_leaf: true,
252 }),
253 Box::new(QuadTreeNode {
254 x_min: x_mid,
255 x_max: self.x_max,
256 y_min: self.y_min,
257 y_max: y_mid,
258 center_of_mass: None,
259 total_mass: 0.0,
260 point_indices: quadrants[1].clone(),
261 children: None,
262 is_leaf: true,
263 }),
264 Box::new(QuadTreeNode {
265 x_min: self.x_min,
266 x_max: x_mid,
267 y_min: y_mid,
268 y_max: self.y_max,
269 center_of_mass: None,
270 total_mass: 0.0,
271 point_indices: quadrants[2].clone(),
272 children: None,
273 is_leaf: true,
274 }),
275 Box::new(QuadTreeNode {
276 x_min: x_mid,
277 x_max: self.x_max,
278 y_min: y_mid,
279 y_max: self.y_max,
280 center_of_mass: None,
281 total_mass: 0.0,
282 point_indices: quadrants[3].clone(),
283 children: None,
284 is_leaf: true,
285 }),
286 ];
287
288 for child in &mut children {
290 child.build_tree(embedding)?;
291 }
292
293 self.children = Some(children);
294 self.is_leaf = false;
295 self.point_indices.clear(); self.update_center_of_mass(embedding)?;
297
298 Ok(())
299 }
300
301 fn update_center_of_mass(&mut self, embedding: &Array2<f64>) -> Result<()> {
303 if self.is_leaf {
304 if self.point_indices.is_empty() {
306 self.total_mass = 0.0;
307 self.center_of_mass = None;
308 return Ok(());
309 }
310
311 let mut com = Array1::zeros(2);
312 for &idx in &self.point_indices {
313 com[0] += embedding[[idx, 0]];
314 com[1] += embedding[[idx, 1]];
315 }
316
317 self.total_mass = self.point_indices.len() as f64;
318 com.mapv_inplace(|x| x / self.total_mass);
319 self.center_of_mass = Some(com);
320 } else {
321 if let Some(ref children) = self.children {
323 let mut com = Array1::zeros(2);
324 let mut total_mass = 0.0;
325
326 for child in children.iter() {
327 if let Some(ref child_com) = child.center_of_mass {
328 total_mass += child.total_mass;
329 for i in 0..2 {
330 com[i] += child_com[i] * child.total_mass;
331 }
332 }
333 }
334
335 if total_mass > 0.0 {
336 com.mapv_inplace(|x| x / total_mass);
337 self.center_of_mass = Some(com);
338 self.total_mass = total_mass;
339 } else {
340 self.center_of_mass = None;
341 self.total_mass = 0.0;
342 }
343 }
344 }
345
346 Ok(())
347 }
348
349 #[allow(clippy::too_many_arguments)]
351 fn compute_forces_quad(
352 &self,
353 point: &Array1<f64>,
354 point_idx: usize,
355 angle: f64,
356 degrees_of_freedom: f64,
357 ) -> Result<(Array1<f64>, f64)> {
358 let mut force = Array1::zeros(2);
359 let mut sum_q = 0.0;
360
361 self.compute_forces_recursive_quad(
362 point,
363 point_idx,
364 angle,
365 degrees_of_freedom,
366 &mut force,
367 &mut sum_q,
368 )?;
369
370 Ok((force, sum_q))
371 }
372
373 #[allow(clippy::too_many_arguments)]
375 fn compute_forces_recursive_quad(
376 &self,
377 point: &Array1<f64>,
378 point_idx: usize,
379 angle: f64,
380 degrees_of_freedom: f64,
381 force: &mut Array1<f64>,
382 sum_q: &mut f64,
383 ) -> Result<()> {
384 if let Some(ref com) = self.center_of_mass {
385 if self.total_mass == 0.0 {
386 return Ok(());
387 }
388
389 let dx = point[0] - com[0];
391 let dy = point[1] - com[1];
392 let dist_squared = dx * dx + dy * dy;
393
394 if dist_squared < MACHINE_EPSILON {
395 return Ok(());
396 }
397
398 let node_size = (self.x_max - self.x_min).max(self.y_max - self.y_min);
400 let distance = dist_squared.sqrt();
401
402 if self.is_leaf || (node_size / distance) < angle {
403 let q_factor = (1.0 + dist_squared / degrees_of_freedom)
405 .powf(-(degrees_of_freedom + 1.0) / 2.0);
406
407 *sum_q += self.total_mass * q_factor;
408
409 let force_factor =
410 (degrees_of_freedom + 1.0) * self.total_mass * q_factor / degrees_of_freedom;
411 force[0] += force_factor * dx;
412 force[1] += force_factor * dy;
413 } else {
414 if let Some(ref children) = self.children {
416 for child in children.iter() {
417 child.compute_forces_recursive_quad(
418 point,
419 point_idx,
420 angle,
421 degrees_of_freedom,
422 force,
423 sum_q,
424 )?;
425 }
426 }
427 }
428 } else if self.is_leaf {
429 for &_idx in &self.point_indices {
431 if _idx != point_idx {
432 }
435 }
436 }
437
438 Ok(())
439 }
440}
441
442impl OctTreeNode {
443 fn build_tree(&mut self, embedding: &Array2<f64>) -> Result<()> {
445 if self.point_indices.len() <= 1 {
446 self.update_center_of_mass(embedding)?;
448 return Ok(());
449 }
450
451 let x_mid = (self.x_min + self.x_max) / 2.0;
453 let y_mid = (self.y_min + self.y_max) / 2.0;
454 let z_mid = (self.z_min + self.z_max) / 2.0;
455
456 let mut octants: [Vec<usize>; 8] = [
457 Vec::new(),
458 Vec::new(),
459 Vec::new(),
460 Vec::new(),
461 Vec::new(),
462 Vec::new(),
463 Vec::new(),
464 Vec::new(),
465 ];
466
467 for &idx in &self.point_indices {
469 let x = embedding[[idx, 0]];
470 let y = embedding[[idx, 1]];
471 let z = embedding[[idx, 2]];
472
473 let octant = match (x >= x_mid, y >= y_mid, z >= z_mid) {
474 (false, false, false) => 0,
475 (true, false, false) => 1,
476 (false, true, false) => 2,
477 (true, true, false) => 3,
478 (false, false, true) => 4,
479 (true, false, true) => 5,
480 (false, true, true) => 6,
481 (true, true, true) => 7,
482 };
483
484 octants[octant].push(idx);
485 }
486
487 let mut children = [
489 Box::new(OctTreeNode {
490 x_min: self.x_min,
491 x_max: x_mid,
492 y_min: self.y_min,
493 y_max: y_mid,
494 z_min: self.z_min,
495 z_max: z_mid,
496 center_of_mass: None,
497 total_mass: 0.0,
498 point_indices: octants[0].clone(),
499 children: None,
500 is_leaf: true,
501 }),
502 Box::new(OctTreeNode {
503 x_min: x_mid,
504 x_max: self.x_max,
505 y_min: self.y_min,
506 y_max: y_mid,
507 z_min: self.z_min,
508 z_max: z_mid,
509 center_of_mass: None,
510 total_mass: 0.0,
511 point_indices: octants[1].clone(),
512 children: None,
513 is_leaf: true,
514 }),
515 Box::new(OctTreeNode {
516 x_min: self.x_min,
517 x_max: x_mid,
518 y_min: y_mid,
519 y_max: self.y_max,
520 z_min: self.z_min,
521 z_max: z_mid,
522 center_of_mass: None,
523 total_mass: 0.0,
524 point_indices: octants[2].clone(),
525 children: None,
526 is_leaf: true,
527 }),
528 Box::new(OctTreeNode {
529 x_min: x_mid,
530 x_max: self.x_max,
531 y_min: y_mid,
532 y_max: self.y_max,
533 z_min: self.z_min,
534 z_max: z_mid,
535 center_of_mass: None,
536 total_mass: 0.0,
537 point_indices: octants[3].clone(),
538 children: None,
539 is_leaf: true,
540 }),
541 Box::new(OctTreeNode {
542 x_min: self.x_min,
543 x_max: x_mid,
544 y_min: self.y_min,
545 y_max: y_mid,
546 z_min: z_mid,
547 z_max: self.z_max,
548 center_of_mass: None,
549 total_mass: 0.0,
550 point_indices: octants[4].clone(),
551 children: None,
552 is_leaf: true,
553 }),
554 Box::new(OctTreeNode {
555 x_min: x_mid,
556 x_max: self.x_max,
557 y_min: self.y_min,
558 y_max: y_mid,
559 z_min: z_mid,
560 z_max: self.z_max,
561 center_of_mass: None,
562 total_mass: 0.0,
563 point_indices: octants[5].clone(),
564 children: None,
565 is_leaf: true,
566 }),
567 Box::new(OctTreeNode {
568 x_min: self.x_min,
569 x_max: x_mid,
570 y_min: y_mid,
571 y_max: self.y_max,
572 z_min: z_mid,
573 z_max: self.z_max,
574 center_of_mass: None,
575 total_mass: 0.0,
576 point_indices: octants[6].clone(),
577 children: None,
578 is_leaf: true,
579 }),
580 Box::new(OctTreeNode {
581 x_min: x_mid,
582 x_max: self.x_max,
583 y_min: y_mid,
584 y_max: self.y_max,
585 z_min: z_mid,
586 z_max: self.z_max,
587 center_of_mass: None,
588 total_mass: 0.0,
589 point_indices: octants[7].clone(),
590 children: None,
591 is_leaf: true,
592 }),
593 ];
594
595 for child in &mut children {
597 child.build_tree(embedding)?;
598 }
599
600 self.children = Some(children);
601 self.is_leaf = false;
602 self.point_indices.clear();
603 self.update_center_of_mass(embedding)?;
604
605 Ok(())
606 }
607
608 fn update_center_of_mass(&mut self, embedding: &Array2<f64>) -> Result<()> {
610 if self.is_leaf {
611 if self.point_indices.is_empty() {
612 self.total_mass = 0.0;
613 self.center_of_mass = None;
614 return Ok(());
615 }
616
617 let mut com = Array1::zeros(3);
618 for &idx in &self.point_indices {
619 com[0] += embedding[[idx, 0]];
620 com[1] += embedding[[idx, 1]];
621 com[2] += embedding[[idx, 2]];
622 }
623
624 self.total_mass = self.point_indices.len() as f64;
625 com.mapv_inplace(|x| x / self.total_mass);
626 self.center_of_mass = Some(com);
627 } else if let Some(ref children) = self.children {
628 let mut com = Array1::zeros(3);
629 let mut total_mass = 0.0;
630
631 for child in children.iter() {
632 if let Some(ref child_com) = child.center_of_mass {
633 total_mass += child.total_mass;
634 for i in 0..3 {
635 com[i] += child_com[i] * child.total_mass;
636 }
637 }
638 }
639
640 if total_mass > 0.0 {
641 com.mapv_inplace(|x| x / total_mass);
642 self.center_of_mass = Some(com);
643 self.total_mass = total_mass;
644 } else {
645 self.center_of_mass = None;
646 self.total_mass = 0.0;
647 }
648 }
649
650 Ok(())
651 }
652
653 #[allow(clippy::too_many_arguments)]
655 fn compute_forces_oct(
656 &self,
657 point: &Array1<f64>,
658 point_idx: usize,
659 angle: f64,
660 degrees_of_freedom: f64,
661 ) -> Result<(Array1<f64>, f64)> {
662 let mut force = Array1::zeros(3);
663 let mut sum_q = 0.0;
664
665 self.compute_forces_recursive_oct(
666 point,
667 point_idx,
668 angle,
669 degrees_of_freedom,
670 &mut force,
671 &mut sum_q,
672 )?;
673
674 Ok((force, sum_q))
675 }
676
677 #[allow(clippy::too_many_arguments)]
679 fn compute_forces_recursive_oct(
680 &self,
681 point: &Array1<f64>,
682 _point_idx: usize,
683 angle: f64,
684 degrees_of_freedom: f64,
685 force: &mut Array1<f64>,
686 sum_q: &mut f64,
687 ) -> Result<()> {
688 if let Some(ref com) = self.center_of_mass {
689 if self.total_mass == 0.0 {
690 return Ok(());
691 }
692
693 let dx = point[0] - com[0];
694 let dy = point[1] - com[1];
695 let dz = point[2] - com[2];
696 let dist_squared = dx * dx + dy * dy + dz * dz;
697
698 if dist_squared < MACHINE_EPSILON {
699 return Ok(());
700 }
701
702 let node_size = (self.x_max - self.x_min)
703 .max(self.y_max - self.y_min)
704 .max(self.z_max - self.z_min);
705 let distance = dist_squared.sqrt();
706
707 if self.is_leaf || (node_size / distance) < angle {
708 let q_factor = (1.0 + dist_squared / degrees_of_freedom)
709 .powf(-(degrees_of_freedom + 1.0) / 2.0);
710
711 *sum_q += self.total_mass * q_factor;
712
713 let force_factor =
714 (degrees_of_freedom + 1.0) * self.total_mass * q_factor / degrees_of_freedom;
715 force[0] += force_factor * dx;
716 force[1] += force_factor * dy;
717 force[2] += force_factor * dz;
718 } else if let Some(ref children) = self.children {
719 for child in children.iter() {
720 child.compute_forces_recursive_oct(
721 point,
722 _point_idx,
723 angle,
724 degrees_of_freedom,
725 force,
726 sum_q,
727 )?;
728 }
729 }
730 }
731
732 Ok(())
733 }
734}
735
/// t-distributed Stochastic Neighbor Embedding (t-SNE).
///
/// Configure via the `with_*` builder methods, then call `fit_transform`.
/// Fields with a trailing underscore are outputs populated after fitting.
pub struct TSNE {
    /// Dimensionality of the output embedding.
    n_components: usize,
    /// Target perplexity — roughly the effective number of neighbors.
    perplexity: f64,
    /// Factor applied to P during the early-exaggeration phase.
    early_exaggeration: f64,
    /// Gradient-descent learning rate.
    learning_rate: f64,
    /// Maximum number of optimization iterations.
    max_iter: usize,
    /// Iterations allowed without improvement before early stopping.
    n_iter_without_progress: usize,
    /// Stop when the gradient norm falls below this threshold.
    min_grad_norm: f64,
    /// Input-space metric: "euclidean", "manhattan", "cosine" or "chebyshev".
    metric: String,
    /// Gradient algorithm: "barnes_hut" (approximate) or exact.
    method: String,
    /// Embedding initialization: "pca" or "random".
    init: String,
    /// Barnes-Hut accuracy/speed trade-off parameter (theta).
    angle: f64,
    /// Parallelism control: 1 = serial loops, any other value selects the
    /// parallel code paths.
    n_jobs: i32,
    /// Print progress during optimization.
    verbose: bool,
    /// Optional RNG seed. NOTE(review): not consulted by the visible random
    /// initialization path — confirm whether seeding is intended.
    random_state: Option<u64>,
    /// Fitted embedding (set by `fit_transform`).
    embedding_: Option<Array2<f64>>,
    /// Final KL divergence of the fitted embedding.
    kl_divergence_: Option<f64>,
    /// Number of optimization iterations actually run.
    n_iter_: Option<usize>,
    /// Effective learning rate used during fitting.
    learning_rate_: Option<f64>,
}
782
783impl Default for TSNE {
784 fn default() -> Self {
785 Self::new()
786 }
787}
788
789impl TSNE {
790 pub fn new() -> Self {
792 TSNE {
793 n_components: 2,
794 perplexity: 30.0,
795 early_exaggeration: 12.0,
796 learning_rate: 200.0,
797 max_iter: 1000,
798 n_iter_without_progress: 300,
799 min_grad_norm: 1e-7,
800 metric: "euclidean".to_string(),
801 method: "barnes_hut".to_string(),
802 init: "pca".to_string(),
803 angle: 0.5,
804 n_jobs: -1, verbose: false,
806 random_state: None,
807 embedding_: None,
808 kl_divergence_: None,
809 n_iter_: None,
810 learning_rate_: None,
811 }
812 }
813
814 pub fn with_n_components(mut self, ncomponents: usize) -> Self {
816 self.n_components = ncomponents;
817 self
818 }
819
820 pub fn with_perplexity(mut self, perplexity: f64) -> Self {
822 self.perplexity = perplexity;
823 self
824 }
825
826 pub fn with_early_exaggeration(mut self, earlyexaggeration: f64) -> Self {
828 self.early_exaggeration = earlyexaggeration;
829 self
830 }
831
832 pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
834 self.learning_rate = learningrate;
835 self
836 }
837
838 pub fn with_max_iter(mut self, maxiter: usize) -> Self {
840 self.max_iter = maxiter;
841 self
842 }
843
844 pub fn with_n_iter_without_progress(mut self, n_iter_withoutprogress: usize) -> Self {
846 self.n_iter_without_progress = n_iter_withoutprogress;
847 self
848 }
849
850 pub fn with_min_grad_norm(mut self, min_gradnorm: f64) -> Self {
852 self.min_grad_norm = min_gradnorm;
853 self
854 }
855
856 pub fn with_metric(mut self, metric: &str) -> Self {
864 self.metric = metric.to_string();
865 self
866 }
867
868 pub fn with_method(mut self, method: &str) -> Self {
870 self.method = method.to_string();
871 self
872 }
873
874 pub fn with_init(mut self, init: &str) -> Self {
876 self.init = init.to_string();
877 self
878 }
879
880 pub fn with_angle(mut self, angle: f64) -> Self {
882 self.angle = angle;
883 self
884 }
885
886 pub fn with_n_jobs(mut self, njobs: i32) -> Self {
891 self.n_jobs = njobs;
892 self
893 }
894
895 pub fn with_verbose(mut self, verbose: bool) -> Self {
897 self.verbose = verbose;
898 self
899 }
900
901 pub fn with_random_state(mut self, randomstate: u64) -> Self {
903 self.random_state = Some(randomstate);
904 self
905 }
906
907 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
915 where
916 S: Data,
917 S::Elem: Float + NumCast,
918 {
919 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
920
921 let n_samples = x_f64.shape()[0];
922 let n_features = x_f64.shape()[1];
923
924 if n_samples == 0 || n_features == 0 {
926 return Err(TransformError::InvalidInput("Empty input data".to_string()));
927 }
928
929 if self.perplexity >= n_samples as f64 {
930 return Err(TransformError::InvalidInput(format!(
931 "perplexity ({}) must be less than n_samples ({})",
932 self.perplexity, n_samples
933 )));
934 }
935
936 if self.method == "barnes_hut" && self.n_components > 3 {
937 return Err(TransformError::InvalidInput(
938 "'n_components' should be less than or equal to 3 for barnes_hut algorithm"
939 .to_string(),
940 ));
941 }
942
943 self.learning_rate_ = Some(self.learning_rate);
945
946 let x_embedded = self.initialize_embedding(&x_f64)?;
948
949 let p = self.compute_pairwise_affinities(&x_f64)?;
951
952 let (embedding, kl_divergence, n_iter) =
954 self.tsne_optimization(p, x_embedded, n_samples)?;
955
956 self.embedding_ = Some(embedding.clone());
957 self.kl_divergence_ = Some(kl_divergence);
958 self.n_iter_ = Some(n_iter);
959
960 Ok(embedding)
961 }
962
963 fn initialize_embedding(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
965 let n_samples = x.shape()[0];
966
967 if self.init == "pca" {
968 let n_components = self.n_components.min(x.shape()[1]);
969 let mut pca = PCA::new(n_components, true, false);
970 let mut x_embedded = pca.fit_transform(x)?;
971
972 let std_dev = (x_embedded.column(0).map(|&x| x * x).sum() / (n_samples as f64)).sqrt();
974 if std_dev > 0.0 {
975 x_embedded.mapv_inplace(|x| x / std_dev * 1e-4);
976 }
977
978 Ok(x_embedded)
979 } else if self.init == "random" {
980 use scirs2_core::random::{thread_rng, Distribution};
983 let normal = Normal::new(0.0, 1e-4).unwrap();
984 let mut rng = thread_rng();
985
986 let data: Vec<f64> = (0..(n_samples * self.n_components))
988 .map(|_| normal.sample(&mut rng))
989 .collect();
990 Ok(Array2::from_shape_vec((n_samples, self.n_components), data).unwrap())
991 } else {
992 Err(TransformError::InvalidInput(format!(
993 "Initialization method '{}' not recognized",
994 self.init
995 )))
996 }
997 }
998
999 fn compute_pairwise_affinities(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
1001 let _n_samples = x.shape()[0];
1002
1003 let distances = self.compute_pairwise_distances(x)?;
1005
1006 let p = self.distances_to_affinities(&distances)?;
1008
1009 let mut p_symmetric = &p + &p.t();
1011
1012 let p_sum = p_symmetric.sum();
1014 if p_sum > 0.0 {
1015 p_symmetric.mapv_inplace(|x| x.max(MACHINE_EPSILON) / p_sum);
1016 }
1017
1018 Ok(p_symmetric)
1019 }
1020
1021 fn compute_pairwise_distances(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
1023 let n_samples = x.shape()[0];
1024 let mut distances = Array2::zeros((n_samples, n_samples));
1025
1026 match self.metric.as_str() {
1027 "euclidean" => {
1028 if self.n_jobs == 1 {
1029 for i in 0..n_samples {
1031 for j in i + 1..n_samples {
1032 let mut dist_squared = 0.0;
1033 for k in 0..x.shape()[1] {
1034 let diff = x[[i, k]] - x[[j, k]];
1035 dist_squared += diff * diff;
1036 }
1037 distances[[i, j]] = dist_squared;
1038 distances[[j, i]] = dist_squared;
1039 }
1040 }
1041 } else {
1042 let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
1044 .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
1045 .collect();
1046
1047 let n_features = x.shape()[1];
1048 let squared_distances: Vec<f64> = upper_triangle_indices
1049 .par_iter()
1050 .map(|&(i, j)| {
1051 let mut dist_squared = 0.0;
1052 for k in 0..n_features {
1053 let diff = x[[i, k]] - x[[j, k]];
1054 dist_squared += diff * diff;
1055 }
1056 dist_squared
1057 })
1058 .collect();
1059
1060 for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
1062 distances[[i, j]] = squared_distances[idx];
1063 distances[[j, i]] = squared_distances[idx];
1064 }
1065 }
1066 }
1067 "manhattan" => {
1068 if self.n_jobs == 1 {
1069 for i in 0..n_samples {
1071 for j in i + 1..n_samples {
1072 let mut dist = 0.0;
1073 for k in 0..x.shape()[1] {
1074 dist += (x[[i, k]] - x[[j, k]]).abs();
1075 }
1076 distances[[i, j]] = dist;
1077 distances[[j, i]] = dist;
1078 }
1079 }
1080 } else {
1081 let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
1083 .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
1084 .collect();
1085
1086 let n_features = x.shape()[1];
1087 let manhattan_distances: Vec<f64> = upper_triangle_indices
1088 .par_iter()
1089 .map(|&(i, j)| {
1090 let mut dist = 0.0;
1091 for k in 0..n_features {
1092 dist += (x[[i, k]] - x[[j, k]]).abs();
1093 }
1094 dist
1095 })
1096 .collect();
1097
1098 for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
1100 distances[[i, j]] = manhattan_distances[idx];
1101 distances[[j, i]] = manhattan_distances[idx];
1102 }
1103 }
1104 }
1105 "cosine" => {
1106 let mut normalized_x = Array2::zeros((n_samples, x.shape()[1]));
1108 for i in 0..n_samples {
1109 let row = x.row(i);
1110 let norm = row.iter().map(|v| v * v).sum::<f64>().sqrt();
1111 if norm > EPSILON {
1112 for j in 0..x.shape()[1] {
1113 normalized_x[[i, j]] = x[[i, j]] / norm;
1114 }
1115 } else {
1116 for j in 0..x.shape()[1] {
1118 normalized_x[[i, j]] = 0.0;
1119 }
1120 }
1121 }
1122
1123 if self.n_jobs == 1 {
1124 for i in 0..n_samples {
1126 for j in i + 1..n_samples {
1127 let mut dot_product = 0.0;
1128 for k in 0..x.shape()[1] {
1129 dot_product += normalized_x[[i, k]] * normalized_x[[j, k]];
1130 }
1131 let cosine_dist = 1.0 - dot_product.clamp(-1.0, 1.0);
1133 distances[[i, j]] = cosine_dist;
1134 distances[[j, i]] = cosine_dist;
1135 }
1136 }
1137 } else {
1138 let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
1140 .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
1141 .collect();
1142
1143 let n_features = x.shape()[1];
1144 let cosine_distances: Vec<f64> = upper_triangle_indices
1145 .par_iter()
1146 .map(|&(i, j)| {
1147 let mut dot_product = 0.0;
1148 for k in 0..n_features {
1149 dot_product += normalized_x[[i, k]] * normalized_x[[j, k]];
1150 }
1151 1.0 - dot_product.clamp(-1.0, 1.0)
1153 })
1154 .collect();
1155
1156 for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
1158 distances[[i, j]] = cosine_distances[idx];
1159 distances[[j, i]] = cosine_distances[idx];
1160 }
1161 }
1162 }
1163 "chebyshev" => {
1164 if self.n_jobs == 1 {
1165 for i in 0..n_samples {
1167 for j in i + 1..n_samples {
1168 let mut max_dist = 0.0;
1169 for k in 0..x.shape()[1] {
1170 let diff = (x[[i, k]] - x[[j, k]]).abs();
1171 max_dist = max_dist.max(diff);
1172 }
1173 distances[[i, j]] = max_dist;
1174 distances[[j, i]] = max_dist;
1175 }
1176 }
1177 } else {
1178 let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
1180 .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
1181 .collect();
1182
1183 let n_features = x.shape()[1];
1184 let chebyshev_distances: Vec<f64> = upper_triangle_indices
1185 .par_iter()
1186 .map(|&(i, j)| {
1187 let mut max_dist = 0.0;
1188 for k in 0..n_features {
1189 let diff = (x[[i, k]] - x[[j, k]]).abs();
1190 max_dist = max_dist.max(diff);
1191 }
1192 max_dist
1193 })
1194 .collect();
1195
1196 for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
1198 distances[[i, j]] = chebyshev_distances[idx];
1199 distances[[j, i]] = chebyshev_distances[idx];
1200 }
1201 }
1202 }
1203 _ => {
1204 return Err(TransformError::InvalidInput(format!(
1205 "Metric '{}' not implemented. Supported metrics are: 'euclidean', 'manhattan', 'cosine', 'chebyshev'",
1206 self.metric
1207 )));
1208 }
1209 }
1210
1211 Ok(distances)
1212 }
1213
1214 fn distances_to_affinities(&self, distances: &Array2<f64>) -> Result<Array2<f64>> {
1216 let n_samples = distances.shape()[0];
1217 let mut p = Array2::zeros((n_samples, n_samples));
1218 let target = (2.0f64).ln() * self.perplexity;
1219
1220 if self.n_jobs == 1 {
1221 for i in 0..n_samples {
1223 let mut beta_min = -f64::INFINITY;
1224 let mut beta_max = f64::INFINITY;
1225 let mut beta = 1.0;
1226
1227 let distances_i = distances.row(i).to_owned();
1229
1230 for _ in 0..50 {
1232 let mut sum_pi = 0.0;
1235 let mut h = 0.0;
1236
1237 for j in 0..n_samples {
1238 if i == j {
1239 p[[i, j]] = 0.0;
1240 continue;
1241 }
1242
1243 let p_ij = (-beta * distances_i[j]).exp();
1244 p[[i, j]] = p_ij;
1245 sum_pi += p_ij;
1246 }
1247
1248 if sum_pi > 0.0 {
1250 for j in 0..n_samples {
1251 if i == j {
1252 continue;
1253 }
1254
1255 p[[i, j]] /= sum_pi;
1256
1257 if p[[i, j]] > MACHINE_EPSILON {
1259 h -= p[[i, j]] * p[[i, j]].ln();
1260 }
1261 }
1262 }
1263
1264 let h_diff = h - target;
1266
1267 if h_diff.abs() < EPSILON {
1268 break; }
1270
1271 if h_diff > 0.0 {
1273 beta_min = beta;
1274 if beta_max == f64::INFINITY {
1275 beta *= 2.0;
1276 } else {
1277 beta = (beta + beta_max) / 2.0;
1278 }
1279 } else {
1280 beta_max = beta;
1281 if beta_min == -f64::INFINITY {
1282 beta /= 2.0;
1283 } else {
1284 beta = (beta + beta_min) / 2.0;
1285 }
1286 }
1287 }
1288 }
1289 } else {
1290 let prob_rows: Vec<Vec<f64>> = (0..n_samples)
1292 .into_par_iter()
1293 .map(|i| {
1294 let mut beta_min = -f64::INFINITY;
1295 let mut beta_max = f64::INFINITY;
1296 let mut beta = 1.0;
1297
1298 let distances_i: Vec<f64> = (0..n_samples).map(|j| distances[[i, j]]).collect();
1300 let mut p_row = vec![0.0; n_samples];
1301
1302 for _ in 0..50 {
1304 let mut sum_pi = 0.0;
1307 let mut h = 0.0;
1308
1309 for j in 0..n_samples {
1310 if i == j {
1311 p_row[j] = 0.0;
1312 continue;
1313 }
1314
1315 let p_ij = (-beta * distances_i[j]).exp();
1316 p_row[j] = p_ij;
1317 sum_pi += p_ij;
1318 }
1319
1320 if sum_pi > 0.0 {
1322 for (j, prob) in p_row.iter_mut().enumerate().take(n_samples) {
1323 if i == j {
1324 continue;
1325 }
1326
1327 *prob /= sum_pi;
1328
1329 if *prob > MACHINE_EPSILON {
1331 h -= *prob * prob.ln();
1332 }
1333 }
1334 }
1335
1336 let h_diff = h - target;
1338
1339 if h_diff.abs() < EPSILON {
1340 break; }
1342
1343 if h_diff > 0.0 {
1345 beta_min = beta;
1346 if beta_max == f64::INFINITY {
1347 beta *= 2.0;
1348 } else {
1349 beta = (beta + beta_max) / 2.0;
1350 }
1351 } else {
1352 beta_max = beta;
1353 if beta_min == -f64::INFINITY {
1354 beta /= 2.0;
1355 } else {
1356 beta = (beta + beta_min) / 2.0;
1357 }
1358 }
1359 }
1360
1361 p_row
1362 })
1363 .collect();
1364
1365 for (i, row) in prob_rows.iter().enumerate() {
1367 for (j, &val) in row.iter().enumerate() {
1368 p[[i, j]] = val;
1369 }
1370 }
1371 }
1372
1373 Ok(p)
1374 }
1375
/// Runs the two-phase gradient-descent optimization of the embedding.
///
/// Phase 1 ("early exaggeration", up to 250 iterations, momentum 0.5)
/// optimizes against P scaled by `early_exaggeration`; phase 2 (momentum
/// 0.8) optimizes the unscaled P until `max_iter`, a gradient norm below
/// `min_grad_norm`, or `n_iter_without_progress` stalled checks.
///
/// Returns `(embedding, final_kl_divergence, iterations_run)`.
/// Note: the returned error is the KL divergence from the last phase-2
/// iteration; if phase 2 never runs it stays `f64::INFINITY`.
#[allow(clippy::too_many_arguments)]
fn tsne_optimization(
    &self,
    p: Array2<f64>,
    initial_embedding: Array2<f64>,
    n_samples: usize,
) -> Result<(Array2<f64>, f64, usize)> {
    let n_components = self.n_components;
    // Degrees of freedom of the Student-t kernel, at least 1.
    let degrees_of_freedom = (n_components - 1).max(1) as f64;

    // Optimizer state: momentum buffer (`update`) and per-parameter gains.
    let mut embedding = initial_embedding;
    let mut update = Array2::zeros((n_samples, n_components));
    let mut gains = Array2::ones((n_samples, n_components));
    let mut error = f64::INFINITY;
    let mut best_error = f64::INFINITY;
    let mut best_iter = 0;
    let mut iter = 0;

    // Length of the early-exaggeration phase and the cadence at which
    // convergence is checked.
    let exploration_n_iter = 250;
    let n_iter_check = 50;

    // Exaggerated affinities used only during phase 1.
    let p_early = &p * self.early_exaggeration;

    if self.verbose {
        println!("[t-SNE] Starting optimization with early exaggeration phase...");
    }

    // Phase 1: early exaggeration with momentum 0.5.
    for i in 0..exploration_n_iter {
        let (curr_error, grad) = if self.method == "barnes_hut" {
            self.compute_gradient_barnes_hut(&embedding, &p_early, degrees_of_freedom)?
        } else {
            self.compute_gradient_exact(&embedding, &p_early, degrees_of_freedom)?
        };

        self.gradient_update(
            &mut embedding,
            &mut update,
            &mut gains,
            &grad,
            0.5,
            self.learning_rate_,
        )?;

        // Convergence checks every n_iter_check iterations.
        if (i + 1) % n_iter_check == 0 {
            if self.verbose {
                println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
            }

            if curr_error < best_error {
                best_error = curr_error;
                best_iter = i;
            } else if i - best_iter > self.n_iter_without_progress {
                if self.verbose {
                    println!("[t-SNE] Early convergence at iteration {}", i + 1);
                }
                break;
            }

            let grad_norm = grad.mapv(|x| x * x).sum().sqrt();
            if grad_norm < self.min_grad_norm {
                if self.verbose {
                    println!("[t-SNE] Gradient norm {} below threshold, stopping optimization at iteration {}",
                        grad_norm, i + 1);
                }
                break;
            }
        }

        // Remember the last completed iteration so phase 2 resumes after it.
        iter = i;
    }

    if self.verbose {
        println!("[t-SNE] Completed early exaggeration phase, starting final optimization...");
    }

    // Phase 2: optimize the true (unexaggerated) objective, momentum 0.8.
    for i in iter + 1..self.max_iter {
        let (curr_error, grad) = if self.method == "barnes_hut" {
            self.compute_gradient_barnes_hut(&embedding, &p, degrees_of_freedom)?
        } else {
            self.compute_gradient_exact(&embedding, &p, degrees_of_freedom)?
        };
        error = curr_error;

        self.gradient_update(
            &mut embedding,
            &mut update,
            &mut gains,
            &grad,
            0.8,
            self.learning_rate_,
        )?;

        if (i + 1) % n_iter_check == 0 {
            if self.verbose {
                println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
            }

            if curr_error < best_error {
                best_error = curr_error;
                best_iter = i;
            } else if i - best_iter > self.n_iter_without_progress {
                if self.verbose {
                    println!("[t-SNE] Stopping optimization at iteration {}", i + 1);
                }
                break;
            }

            let grad_norm = grad.mapv(|x| x * x).sum().sqrt();
            if grad_norm < self.min_grad_norm {
                if self.verbose {
                    println!("[t-SNE] Gradient norm {} below threshold, stopping optimization at iteration {}",
                        grad_norm, i + 1);
                }
                break;
            }
        }

        iter = i;
    }

    if self.verbose {
        println!(
            "[t-SNE] Optimization finished after {} iterations with error {:.7}",
            iter + 1,
            error
        );
    }

    Ok((embedding, error, iter + 1))
}
1520
    /// Computes the exact KL divergence and its gradient for the current
    /// embedding under a Student-t kernel with `degrees_of_freedom`.
    ///
    /// Both branches compute the same quantities: a serial O(n^2) path when
    /// `self.n_jobs == 1`, and a rayon-parallel path otherwise.
    ///
    /// # Arguments
    /// * `embedding` - Low-dimensional coordinates, shape `(n_samples, n_components)`.
    /// * `p` - High-dimensional affinity matrix, shape `(n_samples, n_samples)`.
    /// * `degrees_of_freedom` - Degrees of freedom of the Student-t kernel.
    ///
    /// # Returns
    /// `(kl_divergence, gradient)` where the gradient has the same shape as
    /// `embedding`.
    #[allow(clippy::too_many_arguments)]
    fn compute_gradient_exact(
        &self,
        embedding: &Array2<f64>,
        p: &Array2<f64>,
        degrees_of_freedom: f64,
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];

        if self.n_jobs == 1 {
            // Unnormalized Student-t affinities w_ij = (1 + d^2/nu)^(-(nu+1)/2),
            // filled symmetrically from the upper triangle.
            let mut dist = Array2::zeros((n_samples, n_samples));
            for i in 0..n_samples {
                for j in i + 1..n_samples {
                    let mut d_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        d_squared += diff * diff;
                    }

                    let q_ij = (1.0 + d_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    dist[[i, j]] = q_ij;
                    dist[[j, i]] = q_ij;
                }
            }

            // Zero the diagonal so a point has no affinity with itself.
            for i in 0..n_samples {
                dist[[i, i]] = 0.0;
            }

            // Normalize to a probability distribution; clamp the partition
            // function away from zero to avoid division blow-up.
            let sum_q = dist.sum().max(MACHINE_EPSILON);
            let q = &dist / sum_q;

            // KL(P || Q), skipping entries where either side is numerically zero.
            let mut kl_divergence = 0.0;
            for i in 0..n_samples {
                for j in 0..n_samples {
                    if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                        kl_divergence += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                    }
                }
            }

            let mut grad = Array2::zeros((n_samples, n_components));
            // NOTE(review): the classic exact t-SNE gradient is
            // 4*(nu+1)/nu * sum_j (p_ij - q_ij) * w_ij * (y_i - y_j) with NO
            // division by sum_q^2; the extra 1/sum_q^2 here rescales the step
            // size — confirm this scaling convention is intended.
            let factor =
                4.0 * (degrees_of_freedom + 1.0) / (degrees_of_freedom * (sum_q.powf(2.0)));

            for i in 0..n_samples {
                for j in 0..n_samples {
                    if i != j {
                        let p_q_diff = p[[i, j]] - q[[i, j]];
                        for k in 0..n_components {
                            grad[[i, k]] += factor
                                * p_q_diff
                                * dist[[i, j]]
                                * (embedding[[i, k]] - embedding[[j, k]]);
                        }
                    }
                }
            }

            Ok((kl_divergence, grad))
        } else {
            // Parallel path: enumerate the strict upper triangle once so each
            // pair is computed exactly one time.
            let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                .collect();

            // Unnormalized affinities for each (i, j) pair, computed in parallel.
            let q_values: Vec<f64> = upper_triangle_indices
                .par_iter()
                .map(|&(i, j)| {
                    let mut d_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        d_squared += diff * diff;
                    }

                    (1.0 + d_squared / degrees_of_freedom).powf(-(degrees_of_freedom + 1.0) / 2.0)
                })
                .collect();

            // Scatter the pair values into a symmetric matrix (serial: cheap).
            let mut dist = Array2::zeros((n_samples, n_samples));
            for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                let q_val = q_values[idx];
                dist[[i, j]] = q_val;
                dist[[j, i]] = q_val;
            }

            for i in 0..n_samples {
                dist[[i, i]] = 0.0;
            }

            let sum_q = dist.sum().max(MACHINE_EPSILON);
            let q = &dist / sum_q;

            // KL(P || Q) reduced row-by-row in parallel.
            let kl_divergence: f64 = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut local_kl = 0.0;
                    for j in 0..n_samples {
                        if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                            local_kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                        }
                    }
                    local_kl
                })
                .sum();

            // Same scaling factor as the serial path (see NOTE above).
            let factor =
                4.0 * (degrees_of_freedom + 1.0) / (degrees_of_freedom * (sum_q.powf(2.0)));

            // Each gradient row depends only on row i, so rows parallelize cleanly.
            let grad_rows: Vec<Vec<f64>> = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut grad_row = vec![0.0; n_components];
                    for j in 0..n_samples {
                        if i != j {
                            let p_q_diff = p[[i, j]] - q[[i, j]];
                            for k in 0..n_components {
                                grad_row[k] += factor
                                    * p_q_diff
                                    * dist[[i, j]]
                                    * (embedding[[i, k]] - embedding[[j, k]]);
                            }
                        }
                    }
                    grad_row
                })
                .collect();

            // Assemble the per-row results into the gradient matrix.
            let mut grad = Array2::zeros((n_samples, n_components));
            for (i, row) in grad_rows.iter().enumerate() {
                for (k, &val) in row.iter().enumerate() {
                    grad[[i, k]] = val;
                }
            }

            Ok((kl_divergence, grad))
        }
    }
1675
    /// Computes the KL divergence and gradient using the Barnes-Hut spatial
    /// tree for the repulsive forces.
    ///
    /// Only 2D (quadtree) and 3D (octree) embeddings are supported; any other
    /// dimensionality returns `TransformError::InvalidInput`.
    ///
    /// NOTE(review): this implementation still materializes the dense n x n
    /// `q` matrix (for the KL estimate) and loops over all pairs for the
    /// attractive term, so time/memory remain O(n^2) — this negates the usual
    /// Barnes-Hut savings; confirm whether a sparse-P attractive term and an
    /// approximate KL were intended.
    #[allow(clippy::too_many_arguments)]
    fn compute_gradient_barnes_hut(
        &self,
        embedding: &Array2<f64>,
        p: &Array2<f64>,
        degrees_of_freedom: f64,
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];

        // Pick the spatial decomposition matching the embedding dimension.
        let tree = if n_components == 2 {
            SpatialTree::new_quadtree(embedding)?
        } else if n_components == 3 {
            SpatialTree::new_octree(embedding)?
        } else {
            return Err(TransformError::InvalidInput(
                "Barnes-Hut approximation only supports 2D and 3D embeddings".to_string(),
            ));
        };

        let mut q = Array2::zeros((n_samples, n_samples));
        let mut grad = Array2::zeros((n_samples, n_components));
        // Running estimate of the partition function, accumulated from the tree.
        let mut sum_q = 0.0;

        for i in 0..n_samples {
            // Approximate repulsive force on point i via the tree, using
            // `self.angle` as the Barnes-Hut accuracy/speed trade-off.
            let point = embedding.row(i).to_owned();
            let (repulsive_force, q_sum) =
                tree.compute_forces(&point, i, self.angle, degrees_of_freedom)?;

            sum_q += q_sum;

            for j in 0..n_components {
                grad[[i, j]] += repulsive_force[j];
            }

            // Exact unnormalized Student-t affinities for row i, used only for
            // the KL divergence estimate below.
            for j in 0..n_samples {
                if i != j {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }
                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    q[[i, j]] = q_ij;
                }
            }
        }

        // Normalize Q; clamp the partition function away from zero.
        sum_q = sum_q.max(MACHINE_EPSILON);
        q.mapv_inplace(|x| x / sum_q);

        // Attractive term: 4 * p_ij * w_ij * (y_i - y_j), only where p is
        // non-negligible.
        // NOTE(review): the attractive term is SUBTRACTED here while the exact
        // path ADDS the (p - q) contribution — the net orientation depends on
        // the sign convention inside `compute_forces`; verify both gradient
        // paths point the same way.
        for i in 0..n_samples {
            for j in 0..n_samples {
                if i != j && p[[i, j]] > MACHINE_EPSILON {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }

                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    let factor = 4.0 * p[[i, j]] * q_ij;

                    for k in 0..n_components {
                        grad[[i, k]] -= factor * (embedding[[i, k]] - embedding[[j, k]]);
                    }
                }
            }
        }

        // KL(P || Q) over entries where both distributions are numerically nonzero.
        let mut kl_divergence = 0.0;
        for i in 0..n_samples {
            for j in 0..n_samples {
                if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                    kl_divergence += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                }
            }
        }

        Ok((kl_divergence, grad))
    }
1768
1769 #[allow(clippy::too_many_arguments)]
1771 fn gradient_update(
1772 &self,
1773 embedding: &mut Array2<f64>,
1774 update: &mut Array2<f64>,
1775 gains: &mut Array2<f64>,
1776 grad: &Array2<f64>,
1777 momentum: f64,
1778 learning_rate: Option<f64>,
1779 ) -> Result<()> {
1780 let n_samples = embedding.shape()[0];
1781 let n_components = embedding.shape()[1];
1782 let eta = learning_rate.unwrap_or(self.learning_rate);
1783
1784 for i in 0..n_samples {
1786 for j in 0..n_components {
1787 let same_sign = update[[i, j]] * grad[[i, j]] > 0.0;
1788
1789 if same_sign {
1790 gains[[i, j]] *= 0.8;
1791 } else {
1792 gains[[i, j]] += 0.2;
1793 }
1794
1795 gains[[i, j]] = gains[[i, j]].max(0.01);
1797
1798 update[[i, j]] = momentum * update[[i, j]] - eta * gains[[i, j]] * grad[[i, j]];
1800 embedding[[i, j]] += update[[i, j]];
1801 }
1802 }
1803
1804 Ok(())
1805 }
1806
    /// Returns the fitted embedding, or `None` if the model has not been fit.
    pub fn embedding(&self) -> Option<&Array2<f64>> {
        self.embedding_.as_ref()
    }
1811
    /// Returns the final KL divergence from the last fit, or `None` if the
    /// model has not been fit.
    pub fn kl_divergence(&self) -> Option<f64> {
        self.kl_divergence_
    }
1816
    /// Returns the number of optimization iterations actually run, or `None`
    /// if the model has not been fit.
    pub fn n_iter(&self) -> Option<usize> {
        self.n_iter_
    }
1821}
1822
1823#[allow(dead_code)]
1837#[allow(clippy::too_many_arguments)]
1838pub fn trustworthiness<S1, S2>(
1839 x: &ArrayBase<S1, Ix2>,
1840 x_embedded: &ArrayBase<S2, Ix2>,
1841 n_neighbors: usize,
1842 metric: &str,
1843) -> Result<f64>
1844where
1845 S1: Data,
1846 S2: Data,
1847 S1::Elem: Float + NumCast,
1848 S2::Elem: Float + NumCast,
1849{
1850 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
1851 let x_embedded_f64 = x_embedded.mapv(|x| NumCast::from(x).unwrap_or(0.0));
1852
1853 let n_samples = x_f64.shape()[0];
1854
1855 if n_neighbors >= n_samples / 2 {
1856 return Err(TransformError::InvalidInput(format!(
1857 "n_neighbors ({}) should be less than n_samples / 2 ({})",
1858 n_neighbors,
1859 n_samples / 2
1860 )));
1861 }
1862
1863 if metric != "euclidean" {
1864 return Err(TransformError::InvalidInput(format!(
1865 "Metric '{metric}' not implemented. Currently only 'euclidean' is supported"
1866 )));
1867 }
1868
1869 let mut dist_x = Array2::zeros((n_samples, n_samples));
1871 for i in 0..n_samples {
1872 for j in 0..n_samples {
1873 if i == j {
1874 dist_x[[i, j]] = f64::INFINITY; continue;
1876 }
1877
1878 let mut d_squared = 0.0;
1879 for k in 0..x_f64.shape()[1] {
1880 let diff = x_f64[[i, k]] - x_f64[[j, k]];
1881 d_squared += diff * diff;
1882 }
1883 dist_x[[i, j]] = d_squared.sqrt();
1884 }
1885 }
1886
1887 let mut dist_embedded = Array2::zeros((n_samples, n_samples));
1889 for i in 0..n_samples {
1890 for j in 0..n_samples {
1891 if i == j {
1892 dist_embedded[[i, j]] = f64::INFINITY; continue;
1894 }
1895
1896 let mut d_squared = 0.0;
1897 for k in 0..x_embedded_f64.shape()[1] {
1898 let diff = x_embedded_f64[[i, k]] - x_embedded_f64[[j, k]];
1899 d_squared += diff * diff;
1900 }
1901 dist_embedded[[i, j]] = d_squared.sqrt();
1902 }
1903 }
1904
1905 let mut nn_orig = Array2::<usize>::zeros((n_samples, n_neighbors));
1907 for i in 0..n_samples {
1908 let row = dist_x.row(i).to_owned();
1910 let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
1911 pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1912
1913 for (j, &(idx_, _)) in pairs.iter().enumerate().take(n_neighbors) {
1915 nn_orig[[i, j]] = idx_;
1916 }
1917 }
1918
1919 let mut nn_embedded = Array2::<usize>::zeros((n_samples, n_neighbors));
1921 for i in 0..n_samples {
1922 let row = dist_embedded.row(i).to_owned();
1924 let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
1925 pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1926
1927 for (j, &(idx, _)) in pairs.iter().skip(1).take(n_neighbors).enumerate() {
1929 nn_embedded[[i, j]] = idx;
1930 }
1931 }
1932
1933 let mut t = 0.0;
1935 for i in 0..n_samples {
1936 for &j in nn_embedded.row(i).iter() {
1937 let is_not_neighbor = !nn_orig.row(i).iter().any(|&nn| nn == j);
1939
1940 if is_not_neighbor {
1941 let row = dist_x.row(i).to_owned();
1943 let mut pairs: Vec<(usize, f64)> =
1944 row.iter().enumerate().map(|(idx, &d)| (idx, d)).collect();
1945 pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1946
1947 let rank = pairs.iter().position(|&(idx_, _)| idx_ == j).unwrap_or(0) - n_neighbors;
1948
1949 t += rank as f64;
1950 }
1951 }
1952 }
1953
1954 let n = n_samples as f64;
1956 let k = n_neighbors as f64;
1957 let normalizer = 2.0 / (n * k * (2.0 * n - 3.0 * k - 1.0));
1958 let trustworthiness = 1.0 - normalizer * t;
1959
1960 Ok(trustworthiness)
1961}
1962
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::arr2;

    /// Exact t-SNE on two well-separated 4-point clusters: the embedding must
    /// keep the clusters farther apart than their internal spread.
    #[test]
    fn test_tsne_simple() {
        // Two tight clusters: points 0-3 near the origin, points 4-7 near (5, 5).
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        let mut tsne_exact = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_exact = tsne_exact.fit_transform(&x).unwrap();

        assert_eq!(embedding_exact.shape(), &[8, 2]);

        // Average intra-cluster distances for each group.
        let dist_group1 =
            average_pairwise_distance(&embedding_exact.slice(scirs2_core::ndarray::s![0..4, ..]));
        let dist_group2 =
            average_pairwise_distance(&embedding_exact.slice(scirs2_core::ndarray::s![4..8, ..]));

        // Average distance between the two clusters.
        let dist_between = average_intergroup_distance(
            &embedding_exact.slice(scirs2_core::ndarray::s![0..4, ..]),
            &embedding_exact.slice(scirs2_core::ndarray::s![4..8, ..]),
        );

        // Cluster separation must survive the embedding.
        assert!(dist_between > dist_group1);
        assert!(dist_between > dist_group2);
    }

    /// Barnes-Hut t-SNE in 2D: checks shape, finiteness, and non-degenerate spread.
    #[test]
    fn test_tsne_barnes_hut() {
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        let mut tsne_bh = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("barnes_hut")
            .with_angle(0.5)
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_bh = tsne_bh.fit_transform(&x).unwrap();

        assert_eq!(embedding_bh.shape(), &[8, 2]);

        // No NaN/inf coordinates.
        assert!(embedding_bh.iter().all(|&x| x.is_finite()));

        // The embedding must not collapse to a single point.
        let min_val = embedding_bh.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_val = embedding_bh
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_val - min_val > 1e-6,
            "Embedding should have some spread"
        );

        assert!(tsne_bh.kl_divergence().is_some());

        // The approximate KL may be non-finite; only report it, don't assert.
        let kl_div = tsne_bh.kl_divergence().unwrap();
        if !kl_div.is_finite() {
            println!(
                "Barnes-Hut KL divergence: {} (non-finite, which is acceptable for approximation)",
                kl_div
            );
        } else {
            println!("Barnes-Hut KL divergence: {} (finite)", kl_div);
        }
    }

    /// Exact t-SNE with the parallel path (n_jobs = -1) vs. the serial path:
    /// both must produce finite, non-degenerate embeddings.
    #[test]
    fn test_tsne_multicore() {
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        let mut tsne_multicore = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_n_jobs(-1) .with_random_state(42)
            .with_max_iter(100) .with_verbose(false);

        let embedding_multicore = tsne_multicore.fit_transform(&x).unwrap();

        assert_eq!(embedding_multicore.shape(), &[8, 2]);

        assert!(embedding_multicore.iter().all(|&x| x.is_finite()));

        // Non-degenerate spread for the parallel result.
        let min_val = embedding_multicore
            .iter()
            .cloned()
            .fold(f64::INFINITY, f64::min);
        let max_val = embedding_multicore
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_val - min_val > 1e-12,
            "Embedding should have some spread, got range: {}",
            max_val - min_val
        );

        let mut tsne_singlecore = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_n_jobs(1) .with_random_state(42)
            .with_max_iter(100)
            .with_verbose(false);

        let embedding_singlecore = tsne_singlecore.fit_transform(&x).unwrap();

        // Both paths finish with finite coordinates (no exact equality check:
        // parallel float summation order may differ).
        assert!(embedding_multicore.iter().all(|&x| x.is_finite()));
        assert!(embedding_singlecore.iter().all(|&x| x.is_finite()));
    }

    /// Barnes-Hut with a 3D target (octree path): checks shape and finiteness.
    #[test]
    fn test_tsne_3d_barnes_hut() {
        let x = arr2(&[
            [0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0],
            [5.0, 5.0, 5.0],
            [6.0, 5.0, 5.0],
            [5.0, 6.0, 5.0],
            [6.0, 6.0, 5.0],
        ]);

        let mut tsne_3d = TSNE::new()
            .with_n_components(3)
            .with_perplexity(2.0)
            .with_method("barnes_hut")
            .with_angle(0.5)
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_3d = tsne_3d.fit_transform(&x).unwrap();

        assert_eq!(embedding_3d.shape(), &[8, 3]);

        assert!(embedding_3d.iter().all(|&x| x.is_finite()));
    }

    /// Mean Euclidean distance over all unordered pairs within one point set.
    fn average_pairwise_distance(
        points: &ArrayBase<scirs2_core::ndarray::ViewRepr<&f64>, Ix2>,
    ) -> f64 {
        let n = points.shape()[0];
        let mut total_dist = 0.0;
        let mut count = 0;

        for i in 0..n {
            for j in i + 1..n {
                let mut dist_squared = 0.0;
                for k in 0..points.shape()[1] {
                    let diff = points[[i, k]] - points[[j, k]];
                    dist_squared += diff * diff;
                }
                total_dist += dist_squared.sqrt();
                count += 1;
            }
        }

        // Empty/singleton sets have no pairs; report 0 instead of dividing by 0.
        if count > 0 {
            total_dist / count as f64
        } else {
            0.0
        }
    }

    /// Mean Euclidean distance over all cross-pairs between two point sets.
    fn average_intergroup_distance(
        group1: &ArrayBase<scirs2_core::ndarray::ViewRepr<&f64>, Ix2>,
        group2: &ArrayBase<scirs2_core::ndarray::ViewRepr<&f64>, Ix2>,
    ) -> f64 {
        let n1 = group1.shape()[0];
        let n2 = group2.shape()[0];
        let mut total_dist = 0.0;
        let mut count = 0;

        for i in 0..n1 {
            for j in 0..n2 {
                let mut dist_squared = 0.0;
                for k in 0..group1.shape()[1] {
                    let diff = group1[[i, k]] - group2[[j, k]];
                    dist_squared += diff * diff;
                }
                total_dist += dist_squared.sqrt();
                count += 1;
            }
        }

        // Guard against an empty group; report 0 instead of dividing by 0.
        if count > 0 {
            total_dist / count as f64
        } else {
            0.0
        }
    }

    /// Trustworthiness: the identity embedding scores exactly 1.0; a
    /// neighborhood-scrambling embedding scores strictly less.
    #[test]
    fn test_trustworthiness() {
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [5.0, 6.0],
            [6.0, 5.0],
            [6.0, 6.0],
        ]);

        // An embedding identical to the input preserves every neighborhood.
        let perfect_embedding = x.clone();
        let t_perfect = trustworthiness(&x, &perfect_embedding, 3, "euclidean").unwrap();
        assert_abs_diff_eq!(t_perfect, 1.0, epsilon = 1e-10);

        // A collinear embedding that mixes the two clusters must lose
        // trustworthiness.
        let random_embedding = arr2(&[
            [0.9, 0.1],
            [0.8, 0.2],
            [0.7, 0.3],
            [0.6, 0.4],
            [0.5, 0.5],
            [0.4, 0.6],
            [0.3, 0.7],
            [0.2, 0.8],
        ]);

        let t_random = trustworthiness(&x, &random_embedding, 3, "euclidean").unwrap();
        assert!(t_random < 1.0);
    }
}