use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
use ndarray_rand::rand::SeedableRng;
use ndarray_rand::rand_distr::Normal;
use ndarray_rand::RandomExt;
use num_traits::{Float, NumCast};
use scirs2_core::parallel_ops::*;

use crate::error::{Result, TransformError};
use crate::reduction::PCA;
/// Numerical floor used to avoid division by zero and log(0) in probability
/// computations; values below this are treated as effectively zero.
const MACHINE_EPSILON: f64 = 1e-14;
/// Convergence tolerance for the perplexity binary search and related checks.
const EPSILON: f64 = 1e-7;
22
/// Spatial partitioning tree used by the Barnes-Hut approximation:
/// a quadtree for 2D embeddings, an octree for 3D embeddings.
#[derive(Debug, Clone)]
enum SpatialTree {
    /// Barnes-Hut tree over a 2D embedding.
    QuadTree(QuadTreeNode),
    /// Barnes-Hut tree over a 3D embedding.
    OctTree(OctTreeNode),
}
29
/// A node of the Barnes-Hut quadtree over a 2D embedding.
#[derive(Debug, Clone)]
struct QuadTreeNode {
    // Axis-aligned bounding box of this cell.
    x_min: f64,
    x_max: f64,
    y_min: f64,
    y_max: f64,
    // Mass-weighted centroid of the points in this cell (None when empty).
    center_of_mass: Option<Array1<f64>>,
    // Number of points in this cell (each point carries unit mass).
    total_mass: f64,
    // Indices of stored points; cleared once the node is subdivided.
    point_indices: Vec<usize>,
    // Four children (one per quadrant) once subdivided.
    children: Option<[Box<QuadTreeNode>; 4]>,
    // True until the node is split in `build_tree`.
    is_leaf: bool,
}
49
/// A node of the Barnes-Hut octree over a 3D embedding.
#[derive(Debug, Clone)]
struct OctTreeNode {
    // Axis-aligned bounding box of this cell.
    x_min: f64,
    x_max: f64,
    y_min: f64,
    y_max: f64,
    z_min: f64,
    z_max: f64,
    // Mass-weighted centroid of the points in this cell (None when empty).
    center_of_mass: Option<Array1<f64>>,
    // Number of points in this cell (each point carries unit mass).
    total_mass: f64,
    // Indices of stored points; cleared once the node is subdivided.
    point_indices: Vec<usize>,
    // Eight children (one per octant) once subdivided.
    children: Option<[Box<OctTreeNode>; 8]>,
    // True until the node is split in `build_tree`.
    is_leaf: bool,
}
71
72impl SpatialTree {
73 fn new_quadtree(embedding: &Array2<f64>) -> Result<Self> {
75 let n_samples = embedding.shape()[0];
76
77 if embedding.shape()[1] != 2 {
78 return Err(TransformError::InvalidInput(
79 "QuadTree requires 2D _embedding".to_string(),
80 ));
81 }
82
83 let mut x_min = f64::INFINITY;
85 let mut x_max = f64::NEG_INFINITY;
86 let mut y_min = f64::INFINITY;
87 let mut y_max = f64::NEG_INFINITY;
88
89 for i in 0..n_samples {
90 let x = embedding[[i, 0]];
91 let y = embedding[[i, 1]];
92 x_min = x_min.min(x);
93 x_max = x_max.max(x);
94 y_min = y_min.min(y);
95 y_max = y_max.max(y);
96 }
97
98 let margin = 0.01 * ((x_max - x_min) + (y_max - y_min));
100 x_min -= margin;
101 x_max += margin;
102 y_min -= margin;
103 y_max += margin;
104
105 let point_indices: Vec<usize> = (0..n_samples).collect();
107
108 let mut root = QuadTreeNode {
110 x_min,
111 x_max,
112 y_min,
113 y_max,
114 center_of_mass: None,
115 total_mass: 0.0,
116 point_indices,
117 children: None,
118 is_leaf: true,
119 };
120
121 root.build_tree(embedding)?;
123
124 Ok(SpatialTree::QuadTree(root))
125 }
126
127 fn new_octree(embedding: &Array2<f64>) -> Result<Self> {
129 let n_samples = embedding.shape()[0];
130
131 if embedding.shape()[1] != 3 {
132 return Err(TransformError::InvalidInput(
133 "OctTree requires 3D _embedding".to_string(),
134 ));
135 }
136
137 let mut x_min = f64::INFINITY;
139 let mut x_max = f64::NEG_INFINITY;
140 let mut y_min = f64::INFINITY;
141 let mut y_max = f64::NEG_INFINITY;
142 let mut z_min = f64::INFINITY;
143 let mut z_max = f64::NEG_INFINITY;
144
145 for i in 0..n_samples {
146 let x = embedding[[i, 0]];
147 let y = embedding[[i, 1]];
148 let z = embedding[[i, 2]];
149 x_min = x_min.min(x);
150 x_max = x_max.max(x);
151 y_min = y_min.min(y);
152 y_max = y_max.max(y);
153 z_min = z_min.min(z);
154 z_max = z_max.max(z);
155 }
156
157 let margin = 0.01 * ((x_max - x_min) + (y_max - y_min) + (z_max - z_min));
159 x_min -= margin;
160 x_max += margin;
161 y_min -= margin;
162 y_max += margin;
163 z_min -= margin;
164 z_max += margin;
165
166 let point_indices: Vec<usize> = (0..n_samples).collect();
168
169 let mut root = OctTreeNode {
171 x_min,
172 x_max,
173 y_min,
174 y_max,
175 z_min,
176 z_max,
177 center_of_mass: None,
178 total_mass: 0.0,
179 point_indices,
180 children: None,
181 is_leaf: true,
182 };
183
184 root.build_tree(embedding)?;
186
187 Ok(SpatialTree::OctTree(root))
188 }
189
190 #[allow(clippy::too_many_arguments)]
192 fn compute_forces(
193 &self,
194 point: &Array1<f64>,
195 point_idx: usize,
196 angle: f64,
197 degrees_of_freedom: f64,
198 ) -> Result<(Array1<f64>, f64)> {
199 match self {
200 SpatialTree::QuadTree(root) => {
201 root.compute_forces_quad(point, point_idx, angle, degrees_of_freedom)
202 }
203 SpatialTree::OctTree(root) => {
204 root.compute_forces_oct(point, point_idx, angle, degrees_of_freedom)
205 }
206 }
207 }
208}
209
210impl QuadTreeNode {
211 fn build_tree(&mut self, embedding: &Array2<f64>) -> Result<()> {
213 if self.point_indices.len() <= 1 {
214 self.update_center_of_mass(embedding)?;
216 return Ok(());
217 }
218
219 let x_mid = (self.x_min + self.x_max) / 2.0;
221 let y_mid = (self.y_min + self.y_max) / 2.0;
222
223 let mut quadrants: [Vec<usize>; 4] = [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
224
225 for &idx in &self.point_indices {
227 let x = embedding[[idx, 0]];
228 let y = embedding[[idx, 1]];
229
230 let quadrant = match (x >= x_mid, y >= y_mid) {
231 (false, false) => 0, (true, false) => 1, (false, true) => 2, (true, true) => 3, };
236
237 quadrants[quadrant].push(idx);
238 }
239
240 let mut children = [
242 Box::new(QuadTreeNode {
243 x_min: self.x_min,
244 x_max: x_mid,
245 y_min: self.y_min,
246 y_max: y_mid,
247 center_of_mass: None,
248 total_mass: 0.0,
249 point_indices: quadrants[0].clone(),
250 children: None,
251 is_leaf: true,
252 }),
253 Box::new(QuadTreeNode {
254 x_min: x_mid,
255 x_max: self.x_max,
256 y_min: self.y_min,
257 y_max: y_mid,
258 center_of_mass: None,
259 total_mass: 0.0,
260 point_indices: quadrants[1].clone(),
261 children: None,
262 is_leaf: true,
263 }),
264 Box::new(QuadTreeNode {
265 x_min: self.x_min,
266 x_max: x_mid,
267 y_min: y_mid,
268 y_max: self.y_max,
269 center_of_mass: None,
270 total_mass: 0.0,
271 point_indices: quadrants[2].clone(),
272 children: None,
273 is_leaf: true,
274 }),
275 Box::new(QuadTreeNode {
276 x_min: x_mid,
277 x_max: self.x_max,
278 y_min: y_mid,
279 y_max: self.y_max,
280 center_of_mass: None,
281 total_mass: 0.0,
282 point_indices: quadrants[3].clone(),
283 children: None,
284 is_leaf: true,
285 }),
286 ];
287
288 for child in &mut children {
290 child.build_tree(embedding)?;
291 }
292
293 self.children = Some(children);
294 self.is_leaf = false;
295 self.point_indices.clear(); self.update_center_of_mass(embedding)?;
297
298 Ok(())
299 }
300
301 fn update_center_of_mass(&mut self, embedding: &Array2<f64>) -> Result<()> {
303 if self.is_leaf {
304 if self.point_indices.is_empty() {
306 self.total_mass = 0.0;
307 self.center_of_mass = None;
308 return Ok(());
309 }
310
311 let mut com = Array1::zeros(2);
312 for &idx in &self.point_indices {
313 com[0] += embedding[[idx, 0]];
314 com[1] += embedding[[idx, 1]];
315 }
316
317 self.total_mass = self.point_indices.len() as f64;
318 com.mapv_inplace(|x| x / self.total_mass);
319 self.center_of_mass = Some(com);
320 } else {
321 if let Some(ref children) = self.children {
323 let mut com = Array1::zeros(2);
324 let mut total_mass = 0.0;
325
326 for child in children.iter() {
327 if let Some(ref child_com) = child.center_of_mass {
328 total_mass += child.total_mass;
329 for i in 0..2 {
330 com[i] += child_com[i] * child.total_mass;
331 }
332 }
333 }
334
335 if total_mass > 0.0 {
336 com.mapv_inplace(|x| x / total_mass);
337 self.center_of_mass = Some(com);
338 self.total_mass = total_mass;
339 } else {
340 self.center_of_mass = None;
341 self.total_mass = 0.0;
342 }
343 }
344 }
345
346 Ok(())
347 }
348
349 #[allow(clippy::too_many_arguments)]
351 fn compute_forces_quad(
352 &self,
353 point: &Array1<f64>,
354 point_idx: usize,
355 angle: f64,
356 degrees_of_freedom: f64,
357 ) -> Result<(Array1<f64>, f64)> {
358 let mut force = Array1::zeros(2);
359 let mut sum_q = 0.0;
360
361 self.compute_forces_recursive_quad(
362 point,
363 point_idx,
364 angle,
365 degrees_of_freedom,
366 &mut force,
367 &mut sum_q,
368 )?;
369
370 Ok((force, sum_q))
371 }
372
373 #[allow(clippy::too_many_arguments)]
375 fn compute_forces_recursive_quad(
376 &self,
377 point: &Array1<f64>,
378 point_idx: usize,
379 angle: f64,
380 degrees_of_freedom: f64,
381 force: &mut Array1<f64>,
382 sum_q: &mut f64,
383 ) -> Result<()> {
384 if let Some(ref com) = self.center_of_mass {
385 if self.total_mass == 0.0 {
386 return Ok(());
387 }
388
389 let dx = point[0] - com[0];
391 let dy = point[1] - com[1];
392 let dist_squared = dx * dx + dy * dy;
393
394 if dist_squared < MACHINE_EPSILON {
395 return Ok(());
396 }
397
398 let node_size = (self.x_max - self.x_min).max(self.y_max - self.y_min);
400 let distance = dist_squared.sqrt();
401
402 if self.is_leaf || (node_size / distance) < angle {
403 let q_factor = (1.0 + dist_squared / degrees_of_freedom)
405 .powf(-(degrees_of_freedom + 1.0) / 2.0);
406
407 *sum_q += self.total_mass * q_factor;
408
409 let force_factor =
410 (degrees_of_freedom + 1.0) * self.total_mass * q_factor / degrees_of_freedom;
411 force[0] += force_factor * dx;
412 force[1] += force_factor * dy;
413 } else {
414 if let Some(ref children) = self.children {
416 for child in children.iter() {
417 child.compute_forces_recursive_quad(
418 point,
419 point_idx,
420 angle,
421 degrees_of_freedom,
422 force,
423 sum_q,
424 )?;
425 }
426 }
427 }
428 } else if self.is_leaf {
429 for &_idx in &self.point_indices {
431 if _idx != point_idx {
432 }
435 }
436 }
437
438 Ok(())
439 }
440}
441
442impl OctTreeNode {
443 fn build_tree(&mut self, embedding: &Array2<f64>) -> Result<()> {
445 if self.point_indices.len() <= 1 {
446 self.update_center_of_mass(embedding)?;
448 return Ok(());
449 }
450
451 let x_mid = (self.x_min + self.x_max) / 2.0;
453 let y_mid = (self.y_min + self.y_max) / 2.0;
454 let z_mid = (self.z_min + self.z_max) / 2.0;
455
456 let mut octants: [Vec<usize>; 8] = [
457 Vec::new(),
458 Vec::new(),
459 Vec::new(),
460 Vec::new(),
461 Vec::new(),
462 Vec::new(),
463 Vec::new(),
464 Vec::new(),
465 ];
466
467 for &idx in &self.point_indices {
469 let x = embedding[[idx, 0]];
470 let y = embedding[[idx, 1]];
471 let z = embedding[[idx, 2]];
472
473 let octant = match (x >= x_mid, y >= y_mid, z >= z_mid) {
474 (false, false, false) => 0,
475 (true, false, false) => 1,
476 (false, true, false) => 2,
477 (true, true, false) => 3,
478 (false, false, true) => 4,
479 (true, false, true) => 5,
480 (false, true, true) => 6,
481 (true, true, true) => 7,
482 };
483
484 octants[octant].push(idx);
485 }
486
487 let mut children = [
489 Box::new(OctTreeNode {
490 x_min: self.x_min,
491 x_max: x_mid,
492 y_min: self.y_min,
493 y_max: y_mid,
494 z_min: self.z_min,
495 z_max: z_mid,
496 center_of_mass: None,
497 total_mass: 0.0,
498 point_indices: octants[0].clone(),
499 children: None,
500 is_leaf: true,
501 }),
502 Box::new(OctTreeNode {
503 x_min: x_mid,
504 x_max: self.x_max,
505 y_min: self.y_min,
506 y_max: y_mid,
507 z_min: self.z_min,
508 z_max: z_mid,
509 center_of_mass: None,
510 total_mass: 0.0,
511 point_indices: octants[1].clone(),
512 children: None,
513 is_leaf: true,
514 }),
515 Box::new(OctTreeNode {
516 x_min: self.x_min,
517 x_max: x_mid,
518 y_min: y_mid,
519 y_max: self.y_max,
520 z_min: self.z_min,
521 z_max: z_mid,
522 center_of_mass: None,
523 total_mass: 0.0,
524 point_indices: octants[2].clone(),
525 children: None,
526 is_leaf: true,
527 }),
528 Box::new(OctTreeNode {
529 x_min: x_mid,
530 x_max: self.x_max,
531 y_min: y_mid,
532 y_max: self.y_max,
533 z_min: self.z_min,
534 z_max: z_mid,
535 center_of_mass: None,
536 total_mass: 0.0,
537 point_indices: octants[3].clone(),
538 children: None,
539 is_leaf: true,
540 }),
541 Box::new(OctTreeNode {
542 x_min: self.x_min,
543 x_max: x_mid,
544 y_min: self.y_min,
545 y_max: y_mid,
546 z_min: z_mid,
547 z_max: self.z_max,
548 center_of_mass: None,
549 total_mass: 0.0,
550 point_indices: octants[4].clone(),
551 children: None,
552 is_leaf: true,
553 }),
554 Box::new(OctTreeNode {
555 x_min: x_mid,
556 x_max: self.x_max,
557 y_min: self.y_min,
558 y_max: y_mid,
559 z_min: z_mid,
560 z_max: self.z_max,
561 center_of_mass: None,
562 total_mass: 0.0,
563 point_indices: octants[5].clone(),
564 children: None,
565 is_leaf: true,
566 }),
567 Box::new(OctTreeNode {
568 x_min: self.x_min,
569 x_max: x_mid,
570 y_min: y_mid,
571 y_max: self.y_max,
572 z_min: z_mid,
573 z_max: self.z_max,
574 center_of_mass: None,
575 total_mass: 0.0,
576 point_indices: octants[6].clone(),
577 children: None,
578 is_leaf: true,
579 }),
580 Box::new(OctTreeNode {
581 x_min: x_mid,
582 x_max: self.x_max,
583 y_min: y_mid,
584 y_max: self.y_max,
585 z_min: z_mid,
586 z_max: self.z_max,
587 center_of_mass: None,
588 total_mass: 0.0,
589 point_indices: octants[7].clone(),
590 children: None,
591 is_leaf: true,
592 }),
593 ];
594
595 for child in &mut children {
597 child.build_tree(embedding)?;
598 }
599
600 self.children = Some(children);
601 self.is_leaf = false;
602 self.point_indices.clear();
603 self.update_center_of_mass(embedding)?;
604
605 Ok(())
606 }
607
608 fn update_center_of_mass(&mut self, embedding: &Array2<f64>) -> Result<()> {
610 if self.is_leaf {
611 if self.point_indices.is_empty() {
612 self.total_mass = 0.0;
613 self.center_of_mass = None;
614 return Ok(());
615 }
616
617 let mut com = Array1::zeros(3);
618 for &idx in &self.point_indices {
619 com[0] += embedding[[idx, 0]];
620 com[1] += embedding[[idx, 1]];
621 com[2] += embedding[[idx, 2]];
622 }
623
624 self.total_mass = self.point_indices.len() as f64;
625 com.mapv_inplace(|x| x / self.total_mass);
626 self.center_of_mass = Some(com);
627 } else if let Some(ref children) = self.children {
628 let mut com = Array1::zeros(3);
629 let mut total_mass = 0.0;
630
631 for child in children.iter() {
632 if let Some(ref child_com) = child.center_of_mass {
633 total_mass += child.total_mass;
634 for i in 0..3 {
635 com[i] += child_com[i] * child.total_mass;
636 }
637 }
638 }
639
640 if total_mass > 0.0 {
641 com.mapv_inplace(|x| x / total_mass);
642 self.center_of_mass = Some(com);
643 self.total_mass = total_mass;
644 } else {
645 self.center_of_mass = None;
646 self.total_mass = 0.0;
647 }
648 }
649
650 Ok(())
651 }
652
653 #[allow(clippy::too_many_arguments)]
655 fn compute_forces_oct(
656 &self,
657 point: &Array1<f64>,
658 point_idx: usize,
659 angle: f64,
660 degrees_of_freedom: f64,
661 ) -> Result<(Array1<f64>, f64)> {
662 let mut force = Array1::zeros(3);
663 let mut sum_q = 0.0;
664
665 self.compute_forces_recursive_oct(
666 point,
667 point_idx,
668 angle,
669 degrees_of_freedom,
670 &mut force,
671 &mut sum_q,
672 )?;
673
674 Ok((force, sum_q))
675 }
676
677 #[allow(clippy::too_many_arguments)]
679 fn compute_forces_recursive_oct(
680 &self,
681 point: &Array1<f64>,
682 _point_idx: usize,
683 angle: f64,
684 degrees_of_freedom: f64,
685 force: &mut Array1<f64>,
686 sum_q: &mut f64,
687 ) -> Result<()> {
688 if let Some(ref com) = self.center_of_mass {
689 if self.total_mass == 0.0 {
690 return Ok(());
691 }
692
693 let dx = point[0] - com[0];
694 let dy = point[1] - com[1];
695 let dz = point[2] - com[2];
696 let dist_squared = dx * dx + dy * dy + dz * dz;
697
698 if dist_squared < MACHINE_EPSILON {
699 return Ok(());
700 }
701
702 let node_size = (self.x_max - self.x_min)
703 .max(self.y_max - self.y_min)
704 .max(self.z_max - self.z_min);
705 let distance = dist_squared.sqrt();
706
707 if self.is_leaf || (node_size / distance) < angle {
708 let q_factor = (1.0 + dist_squared / degrees_of_freedom)
709 .powf(-(degrees_of_freedom + 1.0) / 2.0);
710
711 *sum_q += self.total_mass * q_factor;
712
713 let force_factor =
714 (degrees_of_freedom + 1.0) * self.total_mass * q_factor / degrees_of_freedom;
715 force[0] += force_factor * dx;
716 force[1] += force_factor * dy;
717 force[2] += force_factor * dz;
718 } else if let Some(ref children) = self.children {
719 for child in children.iter() {
720 child.compute_forces_recursive_oct(
721 point,
722 _point_idx,
723 angle,
724 degrees_of_freedom,
725 force,
726 sum_q,
727 )?;
728 }
729 }
730 }
731
732 Ok(())
733 }
734}
735
/// t-distributed Stochastic Neighbor Embedding (t-SNE).
///
/// Non-linear dimensionality reduction that embeds data in a low-dimensional
/// space by matching pairwise similarity distributions, with an exact
/// gradient and a Barnes-Hut approximation.
pub struct TSNE {
    // Dimensionality of the embedded space (must be <= 3 for "barnes_hut").
    n_components: usize,
    // Perplexity: roughly the effective number of neighbors per point; must
    // be smaller than the number of samples.
    perplexity: f64,
    // Factor applied to P during the early-exaggeration phase.
    early_exaggeration: f64,
    // Gradient-descent learning rate.
    learning_rate: f64,
    // Maximum number of optimization iterations.
    max_iter: usize,
    // Iterations without improvement tolerated before stopping early.
    n_iter_without_progress: usize,
    // Optimization stops once the gradient norm drops below this value.
    min_grad_norm: f64,
    // Distance metric: "euclidean", "manhattan", "cosine", or "chebyshev".
    metric: String,
    // Gradient algorithm; "barnes_hut" selects the approximate method,
    // anything else falls back to the exact gradient.
    method: String,
    // Embedding initialization: "pca" or "random".
    init: String,
    // Barnes-Hut accuracy/speed trade-off (theta); smaller is more accurate.
    angle: f64,
    // Parallelism hint: 1 forces the serial code paths.
    n_jobs: i32,
    // Print progress information during optimization.
    verbose: bool,
    // Optional random seed for reproducibility.
    random_state: Option<u64>,
    // Fitted results (None until `fit_transform` succeeds).
    embedding_: Option<Array2<f64>>,
    kl_divergence_: Option<f64>,
    n_iter_: Option<usize>,
    learning_rate_: Option<f64>,
}
782
impl Default for TSNE {
    /// Equivalent to [`TSNE::new`].
    fn default() -> Self {
        Self::new()
    }
}
788
789impl TSNE {
790 pub fn new() -> Self {
792 TSNE {
793 n_components: 2,
794 perplexity: 30.0,
795 early_exaggeration: 12.0,
796 learning_rate: 200.0,
797 max_iter: 1000,
798 n_iter_without_progress: 300,
799 min_grad_norm: 1e-7,
800 metric: "euclidean".to_string(),
801 method: "barnes_hut".to_string(),
802 init: "pca".to_string(),
803 angle: 0.5,
804 n_jobs: -1, verbose: false,
806 random_state: None,
807 embedding_: None,
808 kl_divergence_: None,
809 n_iter_: None,
810 learning_rate_: None,
811 }
812 }
813
    /// Sets the dimensionality of the embedded space.
    pub fn with_n_components(mut self, ncomponents: usize) -> Self {
        self.n_components = ncomponents;
        self
    }

    /// Sets the perplexity (roughly the effective number of neighbors).
    pub fn with_perplexity(mut self, perplexity: f64) -> Self {
        self.perplexity = perplexity;
        self
    }

    /// Sets the factor by which P is scaled during early exaggeration.
    pub fn with_early_exaggeration(mut self, earlyexaggeration: f64) -> Self {
        self.early_exaggeration = earlyexaggeration;
        self
    }

    /// Sets the gradient-descent learning rate.
    pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
        self.learning_rate = learningrate;
        self
    }

    /// Sets the maximum number of optimization iterations.
    pub fn with_max_iter(mut self, maxiter: usize) -> Self {
        self.max_iter = maxiter;
        self
    }

    /// Sets how many iterations without progress are tolerated before the
    /// optimization stops early.
    pub fn with_n_iter_without_progress(mut self, n_iter_withoutprogress: usize) -> Self {
        self.n_iter_without_progress = n_iter_withoutprogress;
        self
    }

    /// Sets the gradient-norm threshold below which optimization stops.
    pub fn with_min_grad_norm(mut self, min_gradnorm: f64) -> Self {
        self.min_grad_norm = min_gradnorm;
        self
    }

    /// Sets the distance metric ("euclidean", "manhattan", "cosine", "chebyshev").
    pub fn with_metric(mut self, metric: &str) -> Self {
        self.metric = metric.to_string();
        self
    }

    /// Sets the gradient computation method (e.g. "barnes_hut").
    pub fn with_method(mut self, method: &str) -> Self {
        self.method = method.to_string();
        self
    }

    /// Sets the embedding initialization strategy ("pca" or "random").
    pub fn with_init(mut self, init: &str) -> Self {
        self.init = init.to_string();
        self
    }

    /// Sets the Barnes-Hut angular resolution (theta).
    pub fn with_angle(mut self, angle: f64) -> Self {
        self.angle = angle;
        self
    }

    /// Sets the parallelism hint; 1 forces the serial code paths.
    pub fn with_n_jobs(mut self, njobs: i32) -> Self {
        self.n_jobs = njobs;
        self
    }

    /// Enables or disables progress output.
    pub fn with_verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Sets the random seed.
    pub fn with_random_state(mut self, randomstate: u64) -> Self {
        self.random_state = Some(randomstate);
        self
    }
906
907 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
915 where
916 S: Data,
917 S::Elem: Float + NumCast,
918 {
919 let x_f64 = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
920
921 let n_samples = x_f64.shape()[0];
922 let n_features = x_f64.shape()[1];
923
924 if n_samples == 0 || n_features == 0 {
926 return Err(TransformError::InvalidInput("Empty input data".to_string()));
927 }
928
929 if self.perplexity >= n_samples as f64 {
930 return Err(TransformError::InvalidInput(format!(
931 "perplexity ({}) must be less than n_samples ({})",
932 self.perplexity, n_samples
933 )));
934 }
935
936 if self.method == "barnes_hut" && self.n_components > 3 {
937 return Err(TransformError::InvalidInput(
938 "'n_components' should be less than or equal to 3 for barnes_hut algorithm"
939 .to_string(),
940 ));
941 }
942
943 self.learning_rate_ = Some(self.learning_rate);
945
946 let x_embedded = self.initialize_embedding(&x_f64)?;
948
949 let p = self.compute_pairwise_affinities(&x_f64)?;
951
952 let (embedding, kl_divergence, n_iter) =
954 self.tsne_optimization(p, x_embedded, n_samples)?;
955
956 self.embedding_ = Some(embedding.clone());
957 self.kl_divergence_ = Some(kl_divergence);
958 self.n_iter_ = Some(n_iter);
959
960 Ok(embedding)
961 }
962
963 fn initialize_embedding(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
965 let n_samples = x.shape()[0];
966
967 if self.init == "pca" {
968 let n_components = self.n_components.min(x.shape()[1]);
969 let mut pca = PCA::new(n_components, true, false);
970 let mut x_embedded = pca.fit_transform(x)?;
971
972 let std_dev = (x_embedded.column(0).map(|&x| x * x).sum() / (n_samples as f64)).sqrt();
974 if std_dev > 0.0 {
975 x_embedded.mapv_inplace(|x| x / std_dev * 1e-4);
976 }
977
978 Ok(x_embedded)
979 } else if self.init == "random" {
980 let normal = Normal::new(0.0, 1e-4).unwrap();
983
984 Ok(Array2::random((n_samples, self.n_components), normal))
986 } else {
987 Err(TransformError::InvalidInput(format!(
988 "Initialization method '{}' not recognized",
989 self.init
990 )))
991 }
992 }
993
994 fn compute_pairwise_affinities(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
996 let _n_samples = x.shape()[0];
997
998 let distances = self.compute_pairwise_distances(x)?;
1000
1001 let p = self.distances_to_affinities(&distances)?;
1003
1004 let mut p_symmetric = &p + &p.t();
1006
1007 let p_sum = p_symmetric.sum();
1009 if p_sum > 0.0 {
1010 p_symmetric.mapv_inplace(|x| x.max(MACHINE_EPSILON) / p_sum);
1011 }
1012
1013 Ok(p_symmetric)
1014 }
1015
1016 fn compute_pairwise_distances(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
1018 let n_samples = x.shape()[0];
1019 let mut distances = Array2::zeros((n_samples, n_samples));
1020
1021 match self.metric.as_str() {
1022 "euclidean" => {
1023 if self.n_jobs == 1 {
1024 for i in 0..n_samples {
1026 for j in i + 1..n_samples {
1027 let mut dist_squared = 0.0;
1028 for k in 0..x.shape()[1] {
1029 let diff = x[[i, k]] - x[[j, k]];
1030 dist_squared += diff * diff;
1031 }
1032 distances[[i, j]] = dist_squared;
1033 distances[[j, i]] = dist_squared;
1034 }
1035 }
1036 } else {
1037 let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
1039 .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
1040 .collect();
1041
1042 let n_features = x.shape()[1];
1043 let squared_distances: Vec<f64> = upper_triangle_indices
1044 .par_iter()
1045 .map(|&(i, j)| {
1046 let mut dist_squared = 0.0;
1047 for k in 0..n_features {
1048 let diff = x[[i, k]] - x[[j, k]];
1049 dist_squared += diff * diff;
1050 }
1051 dist_squared
1052 })
1053 .collect();
1054
1055 for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
1057 distances[[i, j]] = squared_distances[idx];
1058 distances[[j, i]] = squared_distances[idx];
1059 }
1060 }
1061 }
1062 "manhattan" => {
1063 if self.n_jobs == 1 {
1064 for i in 0..n_samples {
1066 for j in i + 1..n_samples {
1067 let mut dist = 0.0;
1068 for k in 0..x.shape()[1] {
1069 dist += (x[[i, k]] - x[[j, k]]).abs();
1070 }
1071 distances[[i, j]] = dist;
1072 distances[[j, i]] = dist;
1073 }
1074 }
1075 } else {
1076 let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
1078 .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
1079 .collect();
1080
1081 let n_features = x.shape()[1];
1082 let manhattan_distances: Vec<f64> = upper_triangle_indices
1083 .par_iter()
1084 .map(|&(i, j)| {
1085 let mut dist = 0.0;
1086 for k in 0..n_features {
1087 dist += (x[[i, k]] - x[[j, k]]).abs();
1088 }
1089 dist
1090 })
1091 .collect();
1092
1093 for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
1095 distances[[i, j]] = manhattan_distances[idx];
1096 distances[[j, i]] = manhattan_distances[idx];
1097 }
1098 }
1099 }
1100 "cosine" => {
1101 let mut normalized_x = Array2::zeros((n_samples, x.shape()[1]));
1103 for i in 0..n_samples {
1104 let row = x.row(i);
1105 let norm = row.iter().map(|v| v * v).sum::<f64>().sqrt();
1106 if norm > EPSILON {
1107 for j in 0..x.shape()[1] {
1108 normalized_x[[i, j]] = x[[i, j]] / norm;
1109 }
1110 } else {
1111 for j in 0..x.shape()[1] {
1113 normalized_x[[i, j]] = 0.0;
1114 }
1115 }
1116 }
1117
1118 if self.n_jobs == 1 {
1119 for i in 0..n_samples {
1121 for j in i + 1..n_samples {
1122 let mut dot_product = 0.0;
1123 for k in 0..x.shape()[1] {
1124 dot_product += normalized_x[[i, k]] * normalized_x[[j, k]];
1125 }
1126 let cosine_dist = 1.0 - dot_product.clamp(-1.0, 1.0);
1128 distances[[i, j]] = cosine_dist;
1129 distances[[j, i]] = cosine_dist;
1130 }
1131 }
1132 } else {
1133 let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
1135 .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
1136 .collect();
1137
1138 let n_features = x.shape()[1];
1139 let cosine_distances: Vec<f64> = upper_triangle_indices
1140 .par_iter()
1141 .map(|&(i, j)| {
1142 let mut dot_product = 0.0;
1143 for k in 0..n_features {
1144 dot_product += normalized_x[[i, k]] * normalized_x[[j, k]];
1145 }
1146 1.0 - dot_product.clamp(-1.0, 1.0)
1148 })
1149 .collect();
1150
1151 for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
1153 distances[[i, j]] = cosine_distances[idx];
1154 distances[[j, i]] = cosine_distances[idx];
1155 }
1156 }
1157 }
1158 "chebyshev" => {
1159 if self.n_jobs == 1 {
1160 for i in 0..n_samples {
1162 for j in i + 1..n_samples {
1163 let mut max_dist = 0.0;
1164 for k in 0..x.shape()[1] {
1165 let diff = (x[[i, k]] - x[[j, k]]).abs();
1166 max_dist = max_dist.max(diff);
1167 }
1168 distances[[i, j]] = max_dist;
1169 distances[[j, i]] = max_dist;
1170 }
1171 }
1172 } else {
1173 let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
1175 .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
1176 .collect();
1177
1178 let n_features = x.shape()[1];
1179 let chebyshev_distances: Vec<f64> = upper_triangle_indices
1180 .par_iter()
1181 .map(|&(i, j)| {
1182 let mut max_dist = 0.0;
1183 for k in 0..n_features {
1184 let diff = (x[[i, k]] - x[[j, k]]).abs();
1185 max_dist = max_dist.max(diff);
1186 }
1187 max_dist
1188 })
1189 .collect();
1190
1191 for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
1193 distances[[i, j]] = chebyshev_distances[idx];
1194 distances[[j, i]] = chebyshev_distances[idx];
1195 }
1196 }
1197 }
1198 _ => {
1199 return Err(TransformError::InvalidInput(format!(
1200 "Metric '{}' not implemented. Supported metrics are: 'euclidean', 'manhattan', 'cosine', 'chebyshev'",
1201 self.metric
1202 )));
1203 }
1204 }
1205
1206 Ok(distances)
1207 }
1208
1209 fn distances_to_affinities(&self, distances: &Array2<f64>) -> Result<Array2<f64>> {
1211 let n_samples = distances.shape()[0];
1212 let mut p = Array2::zeros((n_samples, n_samples));
1213 let target = (2.0f64).ln() * self.perplexity;
1214
1215 if self.n_jobs == 1 {
1216 for i in 0..n_samples {
1218 let mut beta_min = -f64::INFINITY;
1219 let mut beta_max = f64::INFINITY;
1220 let mut beta = 1.0;
1221
1222 let distances_i = distances.row(i).to_owned();
1224
1225 for _ in 0..50 {
1227 let mut sum_pi = 0.0;
1230 let mut h = 0.0;
1231
1232 for j in 0..n_samples {
1233 if i == j {
1234 p[[i, j]] = 0.0;
1235 continue;
1236 }
1237
1238 let p_ij = (-beta * distances_i[j]).exp();
1239 p[[i, j]] = p_ij;
1240 sum_pi += p_ij;
1241 }
1242
1243 if sum_pi > 0.0 {
1245 for j in 0..n_samples {
1246 if i == j {
1247 continue;
1248 }
1249
1250 p[[i, j]] /= sum_pi;
1251
1252 if p[[i, j]] > MACHINE_EPSILON {
1254 h -= p[[i, j]] * p[[i, j]].ln();
1255 }
1256 }
1257 }
1258
1259 let h_diff = h - target;
1261
1262 if h_diff.abs() < EPSILON {
1263 break; }
1265
1266 if h_diff > 0.0 {
1268 beta_min = beta;
1269 if beta_max == f64::INFINITY {
1270 beta *= 2.0;
1271 } else {
1272 beta = (beta + beta_max) / 2.0;
1273 }
1274 } else {
1275 beta_max = beta;
1276 if beta_min == -f64::INFINITY {
1277 beta /= 2.0;
1278 } else {
1279 beta = (beta + beta_min) / 2.0;
1280 }
1281 }
1282 }
1283 }
1284 } else {
1285 let prob_rows: Vec<Vec<f64>> = (0..n_samples)
1287 .into_par_iter()
1288 .map(|i| {
1289 let mut beta_min = -f64::INFINITY;
1290 let mut beta_max = f64::INFINITY;
1291 let mut beta = 1.0;
1292
1293 let distances_i: Vec<f64> = (0..n_samples).map(|j| distances[[i, j]]).collect();
1295 let mut p_row = vec![0.0; n_samples];
1296
1297 for _ in 0..50 {
1299 let mut sum_pi = 0.0;
1302 let mut h = 0.0;
1303
1304 for j in 0..n_samples {
1305 if i == j {
1306 p_row[j] = 0.0;
1307 continue;
1308 }
1309
1310 let p_ij = (-beta * distances_i[j]).exp();
1311 p_row[j] = p_ij;
1312 sum_pi += p_ij;
1313 }
1314
1315 if sum_pi > 0.0 {
1317 for (j, prob) in p_row.iter_mut().enumerate().take(n_samples) {
1318 if i == j {
1319 continue;
1320 }
1321
1322 *prob /= sum_pi;
1323
1324 if *prob > MACHINE_EPSILON {
1326 h -= *prob * prob.ln();
1327 }
1328 }
1329 }
1330
1331 let h_diff = h - target;
1333
1334 if h_diff.abs() < EPSILON {
1335 break; }
1337
1338 if h_diff > 0.0 {
1340 beta_min = beta;
1341 if beta_max == f64::INFINITY {
1342 beta *= 2.0;
1343 } else {
1344 beta = (beta + beta_max) / 2.0;
1345 }
1346 } else {
1347 beta_max = beta;
1348 if beta_min == -f64::INFINITY {
1349 beta /= 2.0;
1350 } else {
1351 beta = (beta + beta_min) / 2.0;
1352 }
1353 }
1354 }
1355
1356 p_row
1357 })
1358 .collect();
1359
1360 for (i, row) in prob_rows.iter().enumerate() {
1362 for (j, &val) in row.iter().enumerate() {
1363 p[[i, j]] = val;
1364 }
1365 }
1366 }
1367
1368 Ok(p)
1369 }
1370
1371 #[allow(clippy::too_many_arguments)]
1373 fn tsne_optimization(
1374 &self,
1375 p: Array2<f64>,
1376 initial_embedding: Array2<f64>,
1377 n_samples: usize,
1378 ) -> Result<(Array2<f64>, f64, usize)> {
1379 let n_components = self.n_components;
1380 let degrees_of_freedom = (n_components - 1).max(1) as f64;
1381
1382 let mut embedding = initial_embedding;
1384 let mut update = Array2::zeros((n_samples, n_components));
1385 let mut gains = Array2::ones((n_samples, n_components));
1386 let mut error = f64::INFINITY;
1387 let mut best_error = f64::INFINITY;
1388 let mut best_iter = 0;
1389 let mut iter = 0;
1390
1391 let exploration_n_iter = 250;
1393 let n_iter_check = 50;
1394
1395 let p_early = &p * self.early_exaggeration;
1397
1398 if self.verbose {
1399 println!("[t-SNE] Starting optimization with early exaggeration phase...");
1400 }
1401
1402 for i in 0..exploration_n_iter {
1404 let (curr_error, grad) = if self.method == "barnes_hut" {
1406 self.compute_gradient_barnes_hut(&embedding, &p_early, degrees_of_freedom)?
1407 } else {
1408 self.compute_gradient_exact(&embedding, &p_early, degrees_of_freedom)?
1409 };
1410
1411 self.gradient_update(
1413 &mut embedding,
1414 &mut update,
1415 &mut gains,
1416 &grad,
1417 0.5,
1418 self.learning_rate_,
1419 )?;
1420
1421 if (i + 1) % n_iter_check == 0 {
1423 if self.verbose {
1424 println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
1425 }
1426
1427 if curr_error < best_error {
1428 best_error = curr_error;
1429 best_iter = i;
1430 } else if i - best_iter > self.n_iter_without_progress {
1431 if self.verbose {
1432 println!("[t-SNE] Early convergence at iteration {}", i + 1);
1433 }
1434 break;
1435 }
1436
1437 let grad_norm = grad.mapv(|x| x * x).sum().sqrt();
1439 if grad_norm < self.min_grad_norm {
1440 if self.verbose {
1441 println!("[t-SNE] Gradient norm {} below threshold, stopping optimization at iteration {}",
1442 grad_norm, i + 1);
1443 }
1444 break;
1445 }
1446 }
1447
1448 iter = i;
1449 }
1450
1451 if self.verbose {
1452 println!("[t-SNE] Completed early exaggeration phase, starting final optimization...");
1453 }
1454
1455 for i in iter + 1..self.max_iter {
1457 let (curr_error, grad) = if self.method == "barnes_hut" {
1459 self.compute_gradient_barnes_hut(&embedding, &p, degrees_of_freedom)?
1460 } else {
1461 self.compute_gradient_exact(&embedding, &p, degrees_of_freedom)?
1462 };
1463 error = curr_error;
1464
1465 self.gradient_update(
1467 &mut embedding,
1468 &mut update,
1469 &mut gains,
1470 &grad,
1471 0.8,
1472 self.learning_rate_,
1473 )?;
1474
1475 if (i + 1) % n_iter_check == 0 {
1477 if self.verbose {
1478 println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
1479 }
1480
1481 if curr_error < best_error {
1482 best_error = curr_error;
1483 best_iter = i;
1484 } else if i - best_iter > self.n_iter_without_progress {
1485 if self.verbose {
1486 println!("[t-SNE] Stopping optimization at iteration {}", i + 1);
1487 }
1488 break;
1489 }
1490
1491 let grad_norm = grad.mapv(|x| x * x).sum().sqrt();
1493 if grad_norm < self.min_grad_norm {
1494 if self.verbose {
1495 println!("[t-SNE] Gradient norm {} below threshold, stopping optimization at iteration {}",
1496 grad_norm, i + 1);
1497 }
1498 break;
1499 }
1500 }
1501
1502 iter = i;
1503 }
1504
1505 if self.verbose {
1506 println!(
1507 "[t-SNE] Optimization finished after {} iterations with error {:.7}",
1508 iter + 1,
1509 error
1510 );
1511 }
1512
1513 Ok((embedding, error, iter + 1))
1514 }
1515
    /// Computes the exact (non-approximated) KL divergence and its gradient
    /// for the current embedding.
    ///
    /// Pairwise similarities in the embedding follow a Student-t kernel:
    /// `w_ij = (1 + d_ij^2 / nu)^(-(nu + 1) / 2)` with `nu = degrees_of_freedom`,
    /// normalized by `sum_q = sum(w)` to give `q_ij`. The returned gradient is
    /// `factor * (p_ij - q_ij) * w_ij * (y_i - y_k)` accumulated over all pairs.
    ///
    /// Both an O(n^2) serial path (`self.n_jobs == 1`) and a parallel path are
    /// provided; they compute the same quantities but may differ in the order
    /// of floating-point accumulation.
    ///
    /// # Arguments
    /// * `embedding` - Current low-dimensional coordinates, shape (n_samples, n_components).
    /// * `p` - High-dimensional joint probabilities, shape (n_samples, n_samples).
    /// * `degrees_of_freedom` - Student-t degrees of freedom `nu`.
    ///
    /// # Returns
    /// `(kl_divergence, gradient)` where the gradient has the same shape as
    /// `embedding`.
    #[allow(clippy::too_many_arguments)]
    fn compute_gradient_exact(
        &self,
        embedding: &Array2<f64>,
        p: &Array2<f64>,
        degrees_of_freedom: f64,
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];

        if self.n_jobs == 1 {
            // --- Serial path ---
            // `dist` holds the UNNORMALIZED Student-t similarities w_ij
            // (symmetric; only the upper triangle is computed).
            let mut dist = Array2::zeros((n_samples, n_samples));
            for i in 0..n_samples {
                for j in i + 1..n_samples {
                    let mut d_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        d_squared += diff * diff;
                    }

                    let q_ij = (1.0 + d_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    dist[[i, j]] = q_ij;
                    dist[[j, i]] = q_ij;
                }
            }

            // Self-similarity is excluded from the distribution.
            for i in 0..n_samples {
                dist[[i, i]] = 0.0;
            }

            // Normalize; clamp to avoid division by zero for degenerate embeddings.
            let sum_q = dist.sum().max(MACHINE_EPSILON);
            let q = &dist / sum_q;

            // KL(P || Q); terms where either probability underflows are skipped
            // to avoid ln(0) / 0 * inf artifacts.
            let mut kl_divergence = 0.0;
            for i in 0..n_samples {
                for j in 0..n_samples {
                    if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                        kl_divergence += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                    }
                }
            }

            let mut grad = Array2::zeros((n_samples, n_components));
            // Global gradient scale; constant for the whole pass so it is
            // hoisted out of the loops.
            let factor =
                4.0 * (degrees_of_freedom + 1.0) / (degrees_of_freedom * (sum_q.powf(2.0)));

            for i in 0..n_samples {
                for j in 0..n_samples {
                    if i != j {
                        let p_q_diff = p[[i, j]] - q[[i, j]];
                        for k in 0..n_components {
                            grad[[i, k]] += factor
                                * p_q_diff
                                * dist[[i, j]]
                                * (embedding[[i, k]] - embedding[[j, k]]);
                        }
                    }
                }
            }

            Ok((kl_divergence, grad))
        } else {
            // --- Parallel path ---
            // Enumerate unique (i, j) pairs of the upper triangle once so the
            // similarity computation can be distributed without data races.
            let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                .collect();

            let q_values: Vec<f64> = upper_triangle_indices
                .par_iter()
                .map(|&(i, j)| {
                    let mut d_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        d_squared += diff * diff;
                    }

                    (1.0 + d_squared / degrees_of_freedom).powf(-(degrees_of_freedom + 1.0) / 2.0)
                })
                .collect();

            // Scatter the pair results back into the symmetric matrix serially.
            let mut dist = Array2::zeros((n_samples, n_samples));
            for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                let q_val = q_values[idx];
                dist[[i, j]] = q_val;
                dist[[j, i]] = q_val;
            }

            for i in 0..n_samples {
                dist[[i, i]] = 0.0;
            }

            let sum_q = dist.sum().max(MACHINE_EPSILON);
            let q = &dist / sum_q;

            // Per-row KL contributions are independent, so reduce in parallel.
            let kl_divergence: f64 = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut local_kl = 0.0;
                    for j in 0..n_samples {
                        if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                            local_kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                        }
                    }
                    local_kl
                })
                .sum();

            let factor =
                4.0 * (degrees_of_freedom + 1.0) / (degrees_of_freedom * (sum_q.powf(2.0)));

            // Each gradient row only reads shared data, so rows are computed
            // independently and assembled afterwards.
            let grad_rows: Vec<Vec<f64>> = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut grad_row = vec![0.0; n_components];
                    for j in 0..n_samples {
                        if i != j {
                            let p_q_diff = p[[i, j]] - q[[i, j]];
                            for k in 0..n_components {
                                grad_row[k] += factor
                                    * p_q_diff
                                    * dist[[i, j]]
                                    * (embedding[[i, k]] - embedding[[j, k]]);
                            }
                        }
                    }
                    grad_row
                })
                .collect();

            let mut grad = Array2::zeros((n_samples, n_components));
            for (i, row) in grad_rows.iter().enumerate() {
                for (k, &val) in row.iter().enumerate() {
                    grad[[i, k]] = val;
                }
            }

            Ok((kl_divergence, grad))
        }
    }
1670
    /// Computes the KL divergence and gradient using a Barnes-Hut spatial
    /// tree for the repulsive forces.
    ///
    /// A quadtree (2D) or octree (3D) approximates the repulsive term per
    /// point via `tree.compute_forces`, controlled by `self.angle`; the
    /// attractive term and the KL divergence are still computed over all
    /// pairs.
    ///
    /// NOTE(review): the dense `q` matrix and the exact per-pair loops below
    /// are O(n^2) in both time and memory, so this path does not realize the
    /// O(n log n) complexity Barnes-Hut is normally used for — only the
    /// repulsive-force evaluation is approximated. Worth revisiting.
    ///
    /// # Arguments
    /// * `embedding` - Current coordinates, shape (n_samples, n_components);
    ///   only 2 or 3 components are supported.
    /// * `p` - High-dimensional joint probabilities.
    /// * `degrees_of_freedom` - Student-t degrees of freedom.
    ///
    /// # Errors
    /// `TransformError::InvalidInput` when `n_components` is not 2 or 3.
    #[allow(clippy::too_many_arguments)]
    fn compute_gradient_barnes_hut(
        &self,
        embedding: &Array2<f64>,
        p: &Array2<f64>,
        degrees_of_freedom: f64,
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];

        // Barnes-Hut needs a spatial partition of the embedding space.
        let tree = if n_components == 2 {
            SpatialTree::new_quadtree(embedding)?
        } else if n_components == 3 {
            SpatialTree::new_octree(embedding)?
        } else {
            return Err(TransformError::InvalidInput(
                "Barnes-Hut approximation only supports 2D and 3D embeddings".to_string(),
            ));
        };

        let mut q = Array2::zeros((n_samples, n_samples));
        let mut grad = Array2::zeros((n_samples, n_components));
        let mut sum_q = 0.0;

        for i in 0..n_samples {
            // Approximate repulsive force on point i from the tree; q_sum is
            // the point's contribution to the normalization constant.
            let point = embedding.row(i).to_owned();
            let (repulsive_force, q_sum) =
                tree.compute_forces(&point, i, self.angle, degrees_of_freedom)?;

            sum_q += q_sum;

            for j in 0..n_components {
                grad[[i, j]] += repulsive_force[j];
            }

            // Exact unnormalized Student-t similarities, used for the KL
            // divergence below (see the O(n^2) note above).
            for j in 0..n_samples {
                if i != j {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }
                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    q[[i, j]] = q_ij;
                }
            }
        }

        // Normalize Q; clamp so a degenerate embedding cannot divide by zero.
        sum_q = sum_q.max(MACHINE_EPSILON);
        q.mapv_inplace(|x| x / sum_q);

        // Attractive forces: only pairs with non-negligible p_ij contribute.
        // Note the kernel here uses the UNNORMALIZED q_ij (recomputed), as in
        // the standard Barnes-Hut t-SNE gradient split.
        for i in 0..n_samples {
            for j in 0..n_samples {
                if i != j && p[[i, j]] > MACHINE_EPSILON {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }

                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    let factor = 4.0 * p[[i, j]] * q_ij;

                    for k in 0..n_components {
                        grad[[i, k]] -= factor * (embedding[[i, k]] - embedding[[j, k]]);
                    }
                }
            }
        }

        // KL(P || Q) on the normalized similarities; underflowing terms skipped.
        let mut kl_divergence = 0.0;
        for i in 0..n_samples {
            for j in 0..n_samples {
                if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                    kl_divergence += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                }
            }
        }

        Ok((kl_divergence, grad))
    }
1763
1764 #[allow(clippy::too_many_arguments)]
1766 fn gradient_update(
1767 &self,
1768 embedding: &mut Array2<f64>,
1769 update: &mut Array2<f64>,
1770 gains: &mut Array2<f64>,
1771 grad: &Array2<f64>,
1772 momentum: f64,
1773 learning_rate: Option<f64>,
1774 ) -> Result<()> {
1775 let n_samples = embedding.shape()[0];
1776 let n_components = embedding.shape()[1];
1777 let eta = learning_rate.unwrap_or(self.learning_rate);
1778
1779 for i in 0..n_samples {
1781 for j in 0..n_components {
1782 let same_sign = update[[i, j]] * grad[[i, j]] > 0.0;
1783
1784 if same_sign {
1785 gains[[i, j]] *= 0.8;
1786 } else {
1787 gains[[i, j]] += 0.2;
1788 }
1789
1790 gains[[i, j]] = gains[[i, j]].max(0.01);
1792
1793 update[[i, j]] = momentum * update[[i, j]] - eta * gains[[i, j]] * grad[[i, j]];
1795 embedding[[i, j]] += update[[i, j]];
1796 }
1797 }
1798
1799 Ok(())
1800 }
1801
    /// Returns the fitted embedding, or `None` if the model has not been
    /// fitted yet.
    pub fn embedding(&self) -> Option<&Array2<f64>> {
        self.embedding_.as_ref()
    }
1806
    /// Returns the final KL divergence reached by the last fit, or `None` if
    /// the model has not been fitted yet.
    pub fn kl_divergence(&self) -> Option<f64> {
        self.kl_divergence_
    }
1811
    /// Returns the number of optimization iterations actually run by the
    /// last fit, or `None` if the model has not been fitted yet.
    pub fn n_iter(&self) -> Option<usize> {
        self.n_iter_
    }
1816}
1817
1818#[allow(dead_code)]
1832#[allow(clippy::too_many_arguments)]
1833pub fn trustworthiness<S1, S2>(
1834 x: &ArrayBase<S1, Ix2>,
1835 x_embedded: &ArrayBase<S2, Ix2>,
1836 n_neighbors: usize,
1837 metric: &str,
1838) -> Result<f64>
1839where
1840 S1: Data,
1841 S2: Data,
1842 S1::Elem: Float + NumCast,
1843 S2::Elem: Float + NumCast,
1844{
1845 let x_f64 = x.mapv(|x| num_traits::cast::<S1::Elem, f64>(x).unwrap_or(0.0));
1846 let x_embedded_f64 = x_embedded.mapv(|x| num_traits::cast::<S2::Elem, f64>(x).unwrap_or(0.0));
1847
1848 let n_samples = x_f64.shape()[0];
1849
1850 if n_neighbors >= n_samples / 2 {
1851 return Err(TransformError::InvalidInput(format!(
1852 "n_neighbors ({}) should be less than n_samples / 2 ({})",
1853 n_neighbors,
1854 n_samples / 2
1855 )));
1856 }
1857
1858 if metric != "euclidean" {
1859 return Err(TransformError::InvalidInput(format!(
1860 "Metric '{metric}' not implemented. Currently only 'euclidean' is supported"
1861 )));
1862 }
1863
1864 let mut dist_x = Array2::zeros((n_samples, n_samples));
1866 for i in 0..n_samples {
1867 for j in 0..n_samples {
1868 if i == j {
1869 dist_x[[i, j]] = f64::INFINITY; continue;
1871 }
1872
1873 let mut d_squared = 0.0;
1874 for k in 0..x_f64.shape()[1] {
1875 let diff = x_f64[[i, k]] - x_f64[[j, k]];
1876 d_squared += diff * diff;
1877 }
1878 dist_x[[i, j]] = d_squared.sqrt();
1879 }
1880 }
1881
1882 let mut dist_embedded = Array2::zeros((n_samples, n_samples));
1884 for i in 0..n_samples {
1885 for j in 0..n_samples {
1886 if i == j {
1887 dist_embedded[[i, j]] = f64::INFINITY; continue;
1889 }
1890
1891 let mut d_squared = 0.0;
1892 for k in 0..x_embedded_f64.shape()[1] {
1893 let diff = x_embedded_f64[[i, k]] - x_embedded_f64[[j, k]];
1894 d_squared += diff * diff;
1895 }
1896 dist_embedded[[i, j]] = d_squared.sqrt();
1897 }
1898 }
1899
1900 let mut nn_orig = Array2::<usize>::zeros((n_samples, n_neighbors));
1902 for i in 0..n_samples {
1903 let row = dist_x.row(i).to_owned();
1905 let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
1906 pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1907
1908 for (j, &(idx_, _)) in pairs.iter().enumerate().take(n_neighbors) {
1910 nn_orig[[i, j]] = idx_;
1911 }
1912 }
1913
1914 let mut nn_embedded = Array2::<usize>::zeros((n_samples, n_neighbors));
1916 for i in 0..n_samples {
1917 let row = dist_embedded.row(i).to_owned();
1919 let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
1920 pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1921
1922 for (j, &(idx, _)) in pairs.iter().skip(1).take(n_neighbors).enumerate() {
1924 nn_embedded[[i, j]] = idx;
1925 }
1926 }
1927
1928 let mut t = 0.0;
1930 for i in 0..n_samples {
1931 for &j in nn_embedded.row(i).iter() {
1932 let is_not_neighbor = !nn_orig.row(i).iter().any(|&nn| nn == j);
1934
1935 if is_not_neighbor {
1936 let row = dist_x.row(i).to_owned();
1938 let mut pairs: Vec<(usize, f64)> =
1939 row.iter().enumerate().map(|(idx, &d)| (idx, d)).collect();
1940 pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1941
1942 let rank = pairs.iter().position(|&(idx_, _)| idx_ == j).unwrap_or(0) - n_neighbors;
1943
1944 t += rank as f64;
1945 }
1946 }
1947 }
1948
1949 let n = n_samples as f64;
1951 let k = n_neighbors as f64;
1952 let normalizer = 2.0 / (n * k * (2.0 * n - 3.0 * k - 1.0));
1953 let trustworthiness = 1.0 - normalizer * t;
1954
1955 Ok(trustworthiness)
1956}
1957
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::arr2;

    /// Exact-method t-SNE should keep two well-separated clusters apart in
    /// the embedding.
    #[test]
    fn test_tsne_simple() {
        // Two tight 4-point clusters, far apart in the input space.
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        let mut tsne_exact = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_exact = tsne_exact.fit_transform(&x).unwrap();

        assert_eq!(embedding_exact.shape(), &[8, 2]);

        // Cluster separation should survive the embedding: intra-cluster
        // spread must be smaller than the distance between the clusters.
        let dist_group1 = average_pairwise_distance(&embedding_exact.slice(ndarray::s![0..4, ..]));
        let dist_group2 = average_pairwise_distance(&embedding_exact.slice(ndarray::s![4..8, ..]));

        let dist_between = average_intergroup_distance(
            &embedding_exact.slice(ndarray::s![0..4, ..]),
            &embedding_exact.slice(ndarray::s![4..8, ..]),
        );

        assert!(dist_between > dist_group1);
        assert!(dist_between > dist_group2);
    }

    /// Barnes-Hut t-SNE (2-D) should produce a finite, non-degenerate
    /// embedding and report a KL divergence.
    #[test]
    fn test_tsne_barnes_hut() {
        // Same two-cluster fixture as test_tsne_simple.
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        let mut tsne_bh = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("barnes_hut")
            .with_angle(0.5)
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_bh = tsne_bh.fit_transform(&x).unwrap();

        assert_eq!(embedding_bh.shape(), &[8, 2]);

        // Sanity: no NaN/inf coordinates.
        assert!(embedding_bh.iter().all(|&x| x.is_finite()));

        // The embedding must not collapse to a single point.
        let min_val = embedding_bh.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_val = embedding_bh
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_val - min_val > 1e-6,
            "Embedding should have some spread"
        );

        assert!(tsne_bh.kl_divergence().is_some());

        // The approximation can yield a non-finite KL divergence; only log it,
        // don't fail on it.
        let kl_div = tsne_bh.kl_divergence().unwrap();
        if !kl_div.is_finite() {
            println!(
                "Barnes-Hut KL divergence: {} (non-finite, which is acceptable for approximation)",
                kl_div
            );
        } else {
            println!("Barnes-Hut KL divergence: {} (finite)", kl_div);
        }
    }

    /// The parallel (n_jobs = -1) exact path should behave like the serial
    /// one: finite, non-degenerate embeddings.
    #[test]
    fn test_tsne_multicore() {
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        // n_jobs = -1 selects the multicore code path.
        let mut tsne_multicore = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_n_jobs(-1) .with_random_state(42)
            .with_max_iter(100) .with_verbose(false);

        let embedding_multicore = tsne_multicore.fit_transform(&x).unwrap();

        assert_eq!(embedding_multicore.shape(), &[8, 2]);

        assert!(embedding_multicore.iter().all(|&x| x.is_finite()));

        // The embedding must have some spread (not collapsed).
        let min_val = embedding_multicore
            .iter()
            .cloned()
            .fold(f64::INFINITY, f64::min);
        let max_val = embedding_multicore
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_val - min_val > 1e-12,
            "Embedding should have some spread, got range: {}",
            max_val - min_val
        );

        // Same configuration on the serial path for comparison. Note: exact
        // numeric equality with the parallel run is NOT asserted — summation
        // order differs — only finiteness of both results.
        let mut tsne_singlecore = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_n_jobs(1) .with_random_state(42)
            .with_max_iter(100)
            .with_verbose(false);

        let embedding_singlecore = tsne_singlecore.fit_transform(&x).unwrap();

        assert!(embedding_multicore.iter().all(|&x| x.is_finite()));
        assert!(embedding_singlecore.iter().all(|&x| x.is_finite()));
    }

    /// Barnes-Hut should also work for 3-D embeddings (octree path).
    #[test]
    fn test_tsne_3d_barnes_hut() {
        // Two 4-point clusters in 3-D.
        let x = arr2(&[
            [0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0],
            [5.0, 5.0, 5.0],
            [6.0, 5.0, 5.0],
            [5.0, 6.0, 5.0],
            [6.0, 6.0, 5.0],
        ]);

        let mut tsne_3d = TSNE::new()
            .with_n_components(3)
            .with_perplexity(2.0)
            .with_method("barnes_hut")
            .with_angle(0.5)
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_3d = tsne_3d.fit_transform(&x).unwrap();

        assert_eq!(embedding_3d.shape(), &[8, 3]);

        assert!(embedding_3d.iter().all(|&x| x.is_finite()));
    }

    /// Mean Euclidean distance over all unordered point pairs in `points`
    /// (0.0 when there are fewer than two points).
    fn average_pairwise_distance(points: &ArrayBase<ndarray::ViewRepr<&f64>, Ix2>) -> f64 {
        let n = points.shape()[0];
        let mut total_dist = 0.0;
        let mut count = 0;

        for i in 0..n {
            for j in i + 1..n {
                let mut dist_squared = 0.0;
                for k in 0..points.shape()[1] {
                    let diff = points[[i, k]] - points[[j, k]];
                    dist_squared += diff * diff;
                }
                total_dist += dist_squared.sqrt();
                count += 1;
            }
        }

        if count > 0 {
            total_dist / count as f64
        } else {
            0.0
        }
    }

    /// Mean Euclidean distance over all cross pairs (one point from each
    /// group); 0.0 when either group is empty.
    fn average_intergroup_distance(
        group1: &ArrayBase<ndarray::ViewRepr<&f64>, Ix2>,
        group2: &ArrayBase<ndarray::ViewRepr<&f64>, Ix2>,
    ) -> f64 {
        let n1 = group1.shape()[0];
        let n2 = group2.shape()[0];
        let mut total_dist = 0.0;
        let mut count = 0;

        for i in 0..n1 {
            for j in 0..n2 {
                let mut dist_squared = 0.0;
                for k in 0..group1.shape()[1] {
                    let diff = group1[[i, k]] - group2[[j, k]];
                    dist_squared += diff * diff;
                }
                total_dist += dist_squared.sqrt();
                count += 1;
            }
        }

        if count > 0 {
            total_dist / count as f64
        } else {
            0.0
        }
    }

    /// An identity embedding must score a perfect 1.0; a neighborhood-
    /// scrambling embedding must score strictly less.
    #[test]
    fn test_trustworthiness() {
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [5.0, 6.0],
            [6.0, 5.0],
            [6.0, 6.0],
        ]);

        // Identical data => neighborhoods perfectly preserved => T = 1.
        let perfect_embedding = x.clone();
        let t_perfect = trustworthiness(&x, &perfect_embedding, 3, "euclidean").unwrap();
        assert_abs_diff_eq!(t_perfect, 1.0, epsilon = 1e-10);

        // A line of points that mixes members of both original clusters.
        let random_embedding = arr2(&[
            [0.9, 0.1],
            [0.8, 0.2],
            [0.7, 0.3],
            [0.6, 0.4],
            [0.5, 0.5],
            [0.4, 0.6],
            [0.3, 0.7],
            [0.2, 0.8],
        ]);

        let t_random = trustworthiness(&x, &random_embedding, 3, "euclidean").unwrap();
        assert!(t_random < 1.0);
    }
}