scirs2_transform/reduction/tsne.rs

//! t-SNE (t-distributed Stochastic Neighbor Embedding) implementation
//!
//! This module provides an implementation of t-SNE, a technique for dimensionality
//! reduction particularly well-suited for visualization of high-dimensional data.
//!
//! t-SNE converts similarities between data points to joint probabilities and tries
//! to minimize the Kullback-Leibler divergence between the joint probabilities of
//! the low-dimensional embedding and the high-dimensional data.
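//!
//! # Example
//!
//! A minimal usage sketch (hedged: it assumes this type is re-exported as
//! `scirs2_transform::reduction::TSNE`, so it is marked `ignore` rather than
//! run as a doctest):
//!
//! ```ignore
//! use ndarray::Array2;
//! use scirs2_transform::reduction::TSNE;
//!
//! // Four 3-D points, embedded into 2-D. Perplexity must be < n_samples.
//! let x = Array2::from_shape_vec(
//!     (4, 3),
//!     vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 5.0, 5.0, 5.0],
//! )
//! .unwrap();
//!
//! let mut tsne = TSNE::new().with_n_components(2).with_perplexity(2.0);
//! let embedding = tsne.fit_transform(&x).unwrap();
//! assert_eq!(embedding.shape(), &[4, 2]);
//! ```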

use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
use ndarray_rand::rand_distr::Normal;
use ndarray_rand::RandomExt;
use num_traits::{Float, NumCast};
use scirs2_core::parallel_ops::*;

use crate::error::{Result, TransformError};
use crate::reduction::PCA;

// Constants for numerical stability
const MACHINE_EPSILON: f64 = 1e-14;
const EPSILON: f64 = 1e-7;

/// Spatial tree data structure for Barnes-Hut approximation
#[derive(Debug, Clone)]
enum SpatialTree {
    QuadTree(QuadTreeNode),
    OctTree(OctTreeNode),
}

/// Node in a quadtree (for 2D embeddings)
#[derive(Debug, Clone)]
struct QuadTreeNode {
    /// Bounding box of this node
    x_min: f64,
    x_max: f64,
    y_min: f64,
    y_max: f64,
    /// Center of mass
    center_of_mass: Option<Array1<f64>>,
    /// Total mass (number of points)
    total_mass: f64,
    /// Point indices in this node (for leaf nodes)
    point_indices: Vec<usize>,
    /// Children nodes (SW, SE, NW, NE)
    children: Option<[Box<QuadTreeNode>; 4]>,
    /// Whether this is a leaf node
    is_leaf: bool,
}

/// Node in an octree (for 3D embeddings)
#[derive(Debug, Clone)]
struct OctTreeNode {
    /// Bounding box of this node
    x_min: f64,
    x_max: f64,
    y_min: f64,
    y_max: f64,
    z_min: f64,
    z_max: f64,
    /// Center of mass
    center_of_mass: Option<Array1<f64>>,
    /// Total mass (number of points)
    total_mass: f64,
    /// Point indices in this node (for leaf nodes)
    point_indices: Vec<usize>,
    /// Children nodes (8 octants)
    children: Option<[Box<OctTreeNode>; 8]>,
    /// Whether this is a leaf node
    is_leaf: bool,
}

impl SpatialTree {
    /// Create a new quadtree for 2D embeddings
    fn new_quadtree(embedding: &Array2<f64>) -> Result<Self> {
        let n_samples = embedding.shape()[0];

        if embedding.shape()[1] != 2 {
            return Err(TransformError::InvalidInput(
                "QuadTree requires a 2D embedding".to_string(),
            ));
        }

        // Find bounding box
        let mut x_min = f64::INFINITY;
        let mut x_max = f64::NEG_INFINITY;
        let mut y_min = f64::INFINITY;
        let mut y_max = f64::NEG_INFINITY;

        for i in 0..n_samples {
            let x = embedding[[i, 0]];
            let y = embedding[[i, 1]];
            x_min = x_min.min(x);
            x_max = x_max.max(x);
            y_min = y_min.min(y);
            y_max = y_max.max(y);
        }

        // Add small margin to avoid edge cases
        let margin = 0.01 * ((x_max - x_min) + (y_max - y_min));
        x_min -= margin;
        x_max += margin;
        y_min -= margin;
        y_max += margin;

        // Collect all point indices
        let point_indices: Vec<usize> = (0..n_samples).collect();

        // Create root node
        let mut root = QuadTreeNode {
            x_min,
            x_max,
            y_min,
            y_max,
            center_of_mass: None,
            total_mass: 0.0,
            point_indices,
            children: None,
            is_leaf: true,
        };

        // Build the tree
        root.build_tree(embedding)?;

        Ok(SpatialTree::QuadTree(root))
    }

    /// Create a new octree for 3D embeddings
    fn new_octree(embedding: &Array2<f64>) -> Result<Self> {
        let n_samples = embedding.shape()[0];

        if embedding.shape()[1] != 3 {
            return Err(TransformError::InvalidInput(
                "OctTree requires a 3D embedding".to_string(),
            ));
        }

        // Find bounding box
        let mut x_min = f64::INFINITY;
        let mut x_max = f64::NEG_INFINITY;
        let mut y_min = f64::INFINITY;
        let mut y_max = f64::NEG_INFINITY;
        let mut z_min = f64::INFINITY;
        let mut z_max = f64::NEG_INFINITY;

        for i in 0..n_samples {
            let x = embedding[[i, 0]];
            let y = embedding[[i, 1]];
            let z = embedding[[i, 2]];
            x_min = x_min.min(x);
            x_max = x_max.max(x);
            y_min = y_min.min(y);
            y_max = y_max.max(y);
            z_min = z_min.min(z);
            z_max = z_max.max(z);
        }

        // Add small margin to avoid edge cases
        let margin = 0.01 * ((x_max - x_min) + (y_max - y_min) + (z_max - z_min));
        x_min -= margin;
        x_max += margin;
        y_min -= margin;
        y_max += margin;
        z_min -= margin;
        z_max += margin;

        // Collect all point indices
        let point_indices: Vec<usize> = (0..n_samples).collect();

        // Create root node
        let mut root = OctTreeNode {
            x_min,
            x_max,
            y_min,
            y_max,
            z_min,
            z_max,
            center_of_mass: None,
            total_mass: 0.0,
            point_indices,
            children: None,
            is_leaf: true,
        };

        // Build the tree
        root.build_tree(embedding)?;

        Ok(SpatialTree::OctTree(root))
    }

    /// Compute forces on a point using Barnes-Hut approximation
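    ///
    /// A cell is summarized by its center of mass when it satisfies the
    /// Barnes-Hut criterion `node_size / distance < angle`: for example, with
    /// `angle = 0.5`, a cell of side 1.0 is treated as a single body once the
    /// query point is more than 2.0 away from its center of mass. Smaller
    /// angles give more exact (and slower) force evaluation.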
    #[allow(clippy::too_many_arguments)]
    fn compute_forces(
        &self,
        point: &Array1<f64>,
        point_idx: usize,
        angle: f64,
        degrees_of_freedom: f64,
    ) -> Result<(Array1<f64>, f64)> {
        match self {
            SpatialTree::QuadTree(root) => {
                root.compute_forces_quad(point, point_idx, angle, degrees_of_freedom)
            }
            SpatialTree::OctTree(root) => {
                root.compute_forces_oct(point, point_idx, angle, degrees_of_freedom)
            }
        }
    }
}

impl QuadTreeNode {
    /// Build the quadtree recursively
    fn build_tree(&mut self, embedding: &Array2<f64>) -> Result<()> {
        if self.point_indices.len() <= 1 {
            // Leaf node with 0 or 1 points
            self.update_center_of_mass(embedding)?;
            return Ok(());
        }

        // Split into 4 quadrants
        let x_mid = (self.x_min + self.x_max) / 2.0;
        let y_mid = (self.y_min + self.y_max) / 2.0;

        let mut quadrants: [Vec<usize>; 4] = [Vec::new(), Vec::new(), Vec::new(), Vec::new()];

        // Distribute points to quadrants
        for &idx in &self.point_indices {
            let x = embedding[[idx, 0]];
            let y = embedding[[idx, 1]];

            let quadrant = match (x >= x_mid, y >= y_mid) {
                (false, false) => 0, // SW
                (true, false) => 1,  // SE
                (false, true) => 2,  // NW
                (true, true) => 3,   // NE
            };

            quadrants[quadrant].push(idx);
        }

        // Create child nodes
        let mut children = [
            Box::new(QuadTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: self.y_min,
                y_max: y_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: quadrants[0].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(QuadTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: self.y_min,
                y_max: y_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: quadrants[1].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(QuadTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: y_mid,
                y_max: self.y_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: quadrants[2].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(QuadTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: y_mid,
                y_max: self.y_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: quadrants[3].clone(),
                children: None,
                is_leaf: true,
            }),
        ];

        // Recursively build children
        for child in &mut children {
            child.build_tree(embedding)?;
        }

        self.children = Some(children);
        self.is_leaf = false;
        self.point_indices.clear(); // Clear points as they are now in children
        self.update_center_of_mass(embedding)?;

        Ok(())
    }

    /// Update center of mass for this node
    fn update_center_of_mass(&mut self, embedding: &Array2<f64>) -> Result<()> {
        if self.is_leaf {
            // Leaf node: compute center of mass from points
            if self.point_indices.is_empty() {
                self.total_mass = 0.0;
                self.center_of_mass = None;
                return Ok(());
            }

            let mut com = Array1::zeros(2);
            for &idx in &self.point_indices {
                com[0] += embedding[[idx, 0]];
                com[1] += embedding[[idx, 1]];
            }

            self.total_mass = self.point_indices.len() as f64;
            com.mapv_inplace(|x| x / self.total_mass);
            self.center_of_mass = Some(com);
        } else {
            // Internal node: compute center of mass from children
            if let Some(ref children) = self.children {
                let mut com = Array1::zeros(2);
                let mut total_mass = 0.0;

                for child in children.iter() {
                    if let Some(ref child_com) = child.center_of_mass {
                        total_mass += child.total_mass;
                        for i in 0..2 {
                            com[i] += child_com[i] * child.total_mass;
                        }
                    }
                }

                if total_mass > 0.0 {
                    com.mapv_inplace(|x| x / total_mass);
                    self.center_of_mass = Some(com);
                    self.total_mass = total_mass;
                } else {
                    self.center_of_mass = None;
                    self.total_mass = 0.0;
                }
            }
        }

        Ok(())
    }

    /// Compute forces using Barnes-Hut approximation for quadtree
    #[allow(clippy::too_many_arguments)]
    fn compute_forces_quad(
        &self,
        point: &Array1<f64>,
        point_idx: usize,
        angle: f64,
        degrees_of_freedom: f64,
    ) -> Result<(Array1<f64>, f64)> {
        let mut force = Array1::zeros(2);
        let mut sum_q = 0.0;

        self.compute_forces_recursive_quad(
            point,
            point_idx,
            angle,
            degrees_of_freedom,
            &mut force,
            &mut sum_q,
        )?;

        Ok((force, sum_q))
    }

    /// Recursive force computation for quadtree
    #[allow(clippy::too_many_arguments)]
    fn compute_forces_recursive_quad(
        &self,
        point: &Array1<f64>,
        _point_idx: usize,
        angle: f64,
        degrees_of_freedom: f64,
        force: &mut Array1<f64>,
        sum_q: &mut f64,
    ) -> Result<()> {
        if let Some(ref com) = self.center_of_mass {
            if self.total_mass == 0.0 {
                return Ok(());
            }

            // Compute distance to center of mass
            let dx = point[0] - com[0];
            let dy = point[1] - com[1];
            let dist_squared = dx * dx + dy * dy;

            if dist_squared < MACHINE_EPSILON {
                return Ok(());
            }

            // Check if we can use this node's center of mass (Barnes-Hut criterion)
            let node_size = (self.x_max - self.x_min).max(self.y_max - self.y_min);
            let distance = dist_squared.sqrt();

            if self.is_leaf || (node_size / distance) < angle {
                // Use center of mass approximation
                let q_factor = (1.0 + dist_squared / degrees_of_freedom)
                    .powf(-(degrees_of_freedom + 1.0) / 2.0);

                *sum_q += self.total_mass * q_factor;

                let force_factor =
                    (degrees_of_freedom + 1.0) * self.total_mass * q_factor / degrees_of_freedom;
                force[0] += force_factor * dx;
                force[1] += force_factor * dy;
            } else {
                // Recursively compute forces from children
                if let Some(ref children) = self.children {
                    for child in children.iter() {
                        child.compute_forces_recursive_quad(
                            point,
                            _point_idx,
                            angle,
                            degrees_of_freedom,
                            force,
                            sum_q,
                        )?;
                    }
                }
            }
        } else if self.is_leaf {
            // Empty leaf node (no center of mass): it contributes no repulsive
            // force; attractive forces are handled separately in the main
            // gradient computation.
        }

        Ok(())
    }
}
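
// A quadtree smoke test, added as an illustrative sketch (it is not part of
// the original test suite): building a tree over the four corners of the unit
// square should give the root a mass of 4 and a center of mass at the mean.
#[cfg(test)]
mod quadtree_smoke_tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn root_mass_and_center_of_mass_match_mean() {
        let embedding = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        match SpatialTree::new_quadtree(&embedding).unwrap() {
            SpatialTree::QuadTree(root) => {
                assert!((root.total_mass - 4.0).abs() < 1e-12);
                let com = root
                    .center_of_mass
                    .expect("non-empty tree has a center of mass");
                assert!((com[0] - 0.5).abs() < 1e-12);
                assert!((com[1] - 0.5).abs() < 1e-12);
            }
            SpatialTree::OctTree(_) => panic!("expected a quadtree for 2D data"),
        }
    }
}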

impl OctTreeNode {
    /// Build the octree recursively
    fn build_tree(&mut self, embedding: &Array2<f64>) -> Result<()> {
        if self.point_indices.len() <= 1 {
            // Leaf node with 0 or 1 points
            self.update_center_of_mass(embedding)?;
            return Ok(());
        }

        // Split into 8 octants
        let x_mid = (self.x_min + self.x_max) / 2.0;
        let y_mid = (self.y_min + self.y_max) / 2.0;
        let z_mid = (self.z_min + self.z_max) / 2.0;

        let mut octants: [Vec<usize>; 8] = [
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
        ];

        // Distribute points to octants
        for &idx in &self.point_indices {
            let x = embedding[[idx, 0]];
            let y = embedding[[idx, 1]];
            let z = embedding[[idx, 2]];

            let octant = match (x >= x_mid, y >= y_mid, z >= z_mid) {
                (false, false, false) => 0,
                (true, false, false) => 1,
                (false, true, false) => 2,
                (true, true, false) => 3,
                (false, false, true) => 4,
                (true, false, true) => 5,
                (false, true, true) => 6,
                (true, true, true) => 7,
            };

            octants[octant].push(idx);
        }

        // Create child nodes
        let mut children = [
            Box::new(OctTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: self.y_min,
                y_max: y_mid,
                z_min: self.z_min,
                z_max: z_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[0].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: self.y_min,
                y_max: y_mid,
                z_min: self.z_min,
                z_max: z_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[1].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: y_mid,
                y_max: self.y_max,
                z_min: self.z_min,
                z_max: z_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[2].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: y_mid,
                y_max: self.y_max,
                z_min: self.z_min,
                z_max: z_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[3].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: self.y_min,
                y_max: y_mid,
                z_min: z_mid,
                z_max: self.z_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[4].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: self.y_min,
                y_max: y_mid,
                z_min: z_mid,
                z_max: self.z_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[5].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: y_mid,
                y_max: self.y_max,
                z_min: z_mid,
                z_max: self.z_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[6].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: y_mid,
                y_max: self.y_max,
                z_min: z_mid,
                z_max: self.z_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[7].clone(),
                children: None,
                is_leaf: true,
            }),
        ];

        // Recursively build children
        for child in &mut children {
            child.build_tree(embedding)?;
        }

        self.children = Some(children);
        self.is_leaf = false;
        self.point_indices.clear();
        self.update_center_of_mass(embedding)?;

        Ok(())
    }

    /// Update center of mass for this octree node
    fn update_center_of_mass(&mut self, embedding: &Array2<f64>) -> Result<()> {
        if self.is_leaf {
            if self.point_indices.is_empty() {
                self.total_mass = 0.0;
                self.center_of_mass = None;
                return Ok(());
            }

            let mut com = Array1::zeros(3);
            for &idx in &self.point_indices {
                com[0] += embedding[[idx, 0]];
                com[1] += embedding[[idx, 1]];
                com[2] += embedding[[idx, 2]];
            }

            self.total_mass = self.point_indices.len() as f64;
            com.mapv_inplace(|x| x / self.total_mass);
            self.center_of_mass = Some(com);
        } else if let Some(ref children) = self.children {
            let mut com = Array1::zeros(3);
            let mut total_mass = 0.0;

            for child in children.iter() {
                if let Some(ref child_com) = child.center_of_mass {
                    total_mass += child.total_mass;
                    for i in 0..3 {
                        com[i] += child_com[i] * child.total_mass;
                    }
                }
            }

            if total_mass > 0.0 {
                com.mapv_inplace(|x| x / total_mass);
                self.center_of_mass = Some(com);
                self.total_mass = total_mass;
            } else {
                self.center_of_mass = None;
                self.total_mass = 0.0;
            }
        }

        Ok(())
    }

    /// Compute forces using Barnes-Hut approximation for octree
    #[allow(clippy::too_many_arguments)]
    fn compute_forces_oct(
        &self,
        point: &Array1<f64>,
        point_idx: usize,
        angle: f64,
        degrees_of_freedom: f64,
    ) -> Result<(Array1<f64>, f64)> {
        let mut force = Array1::zeros(3);
        let mut sum_q = 0.0;

        self.compute_forces_recursive_oct(
            point,
            point_idx,
            angle,
            degrees_of_freedom,
            &mut force,
            &mut sum_q,
        )?;

        Ok((force, sum_q))
    }

    /// Recursive force computation for octree
    #[allow(clippy::too_many_arguments)]
    fn compute_forces_recursive_oct(
        &self,
        point: &Array1<f64>,
        _point_idx: usize,
        angle: f64,
        degrees_of_freedom: f64,
        force: &mut Array1<f64>,
        sum_q: &mut f64,
    ) -> Result<()> {
        if let Some(ref com) = self.center_of_mass {
            if self.total_mass == 0.0 {
                return Ok(());
            }

            let dx = point[0] - com[0];
            let dy = point[1] - com[1];
            let dz = point[2] - com[2];
            let dist_squared = dx * dx + dy * dy + dz * dz;

            if dist_squared < MACHINE_EPSILON {
                return Ok(());
            }

            let node_size = (self.x_max - self.x_min)
                .max(self.y_max - self.y_min)
                .max(self.z_max - self.z_min);
            let distance = dist_squared.sqrt();

            if self.is_leaf || (node_size / distance) < angle {
                let q_factor = (1.0 + dist_squared / degrees_of_freedom)
                    .powf(-(degrees_of_freedom + 1.0) / 2.0);

                *sum_q += self.total_mass * q_factor;

                let force_factor =
                    (degrees_of_freedom + 1.0) * self.total_mass * q_factor / degrees_of_freedom;
                force[0] += force_factor * dx;
                force[1] += force_factor * dy;
                force[2] += force_factor * dz;
            } else if let Some(ref children) = self.children {
                for child in children.iter() {
                    child.compute_forces_recursive_oct(
                        point,
                        _point_idx,
                        angle,
                        degrees_of_freedom,
                        force,
                        sum_q,
                    )?;
                }
            }
        }

        Ok(())
    }
}

/// t-SNE (t-distributed Stochastic Neighbor Embedding) for dimensionality reduction
///
/// t-SNE is a nonlinear dimensionality reduction technique well-suited for
/// embedding high-dimensional data for visualization in a low-dimensional space
/// (typically 2D or 3D). It models each high-dimensional object by a two- or
/// three-dimensional point in such a way that similar objects are modeled by
/// nearby points and dissimilar objects are modeled by distant points with
/// high probability.
pub struct TSNE {
    /// Number of components in the embedded space
    n_components: usize,
    /// Perplexity parameter that balances attention between local and global structure
    perplexity: f64,
    /// Weight of early exaggeration phase
    early_exaggeration: f64,
    /// Learning rate for optimization
    learning_rate: f64,
    /// Maximum number of iterations
    max_iter: usize,
    /// Maximum iterations without progress before early stopping
    n_iter_without_progress: usize,
    /// Minimum gradient norm for convergence
    min_grad_norm: f64,
    /// Method to compute pairwise distances
    metric: String,
    /// Method to perform dimensionality reduction
    method: String,
    /// Initialization method
    init: String,
    /// Angle for Barnes-Hut approximation
    angle: f64,
    /// Number of parallel jobs (-1 uses all available cores)
    n_jobs: i32,
    /// Verbosity level
    verbose: bool,
    /// Random state for reproducibility
    random_state: Option<u64>,
    /// The embedding vectors
    embedding_: Option<Array2<f64>>,
    /// KL divergence after optimization
    kl_divergence_: Option<f64>,
    /// Total number of iterations run
    n_iter_: Option<usize>,
    /// Effective learning rate used
    learning_rate_: Option<f64>,
}

impl Default for TSNE {
    fn default() -> Self {
        Self::new()
    }
}

impl TSNE {
    /// Creates a new t-SNE instance with default parameters
    pub fn new() -> Self {
        TSNE {
            n_components: 2,
            perplexity: 30.0,
            early_exaggeration: 12.0,
            learning_rate: 200.0,
            max_iter: 1000,
            n_iter_without_progress: 300,
            min_grad_norm: 1e-7,
            metric: "euclidean".to_string(),
            method: "barnes_hut".to_string(),
            init: "pca".to_string(),
            angle: 0.5,
            n_jobs: -1, // Use all available cores by default
            verbose: false,
            random_state: None,
            embedding_: None,
            kl_divergence_: None,
            n_iter_: None,
            learning_rate_: None,
        }
    }

    /// Sets the number of components in the embedded space
    pub fn with_n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Sets the perplexity parameter
    pub fn with_perplexity(mut self, perplexity: f64) -> Self {
        self.perplexity = perplexity;
        self
    }

    /// Sets the early exaggeration factor
    pub fn with_early_exaggeration(mut self, early_exaggeration: f64) -> Self {
        self.early_exaggeration = early_exaggeration;
        self
    }

    /// Sets the learning rate for gradient descent
    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Sets the maximum number of iterations
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the number of iterations without progress before early stopping
    pub fn with_n_iter_without_progress(mut self, n_iter_without_progress: usize) -> Self {
        self.n_iter_without_progress = n_iter_without_progress;
        self
    }

    /// Sets the minimum gradient norm for convergence
    pub fn with_min_grad_norm(mut self, min_grad_norm: f64) -> Self {
        self.min_grad_norm = min_grad_norm;
        self
    }

    /// Sets the metric for pairwise distance computation
    ///
    /// Supported metrics:
    /// - "euclidean": Euclidean distance (L2 norm) - default
    /// - "manhattan": Manhattan distance (L1 norm)
    /// - "cosine": Cosine distance (1 - cosine similarity)
    /// - "chebyshev": Chebyshev distance (maximum coordinate difference)
    pub fn with_metric(mut self, metric: &str) -> Self {
        self.metric = metric.to_string();
        self
    }

    /// Sets the method for dimensionality reduction
    pub fn with_method(mut self, method: &str) -> Self {
        self.method = method.to_string();
        self
    }

    /// Sets the initialization method
    pub fn with_init(mut self, init: &str) -> Self {
        self.init = init.to_string();
        self
    }

    /// Sets the angle for Barnes-Hut approximation
    pub fn with_angle(mut self, angle: f64) -> Self {
        self.angle = angle;
        self
    }

    /// Sets the number of parallel jobs to run
    /// * n_jobs = -1: Use all available cores
    /// * n_jobs = 1: Use single-core (disable multicore)
    /// * n_jobs > 1: Use specific number of cores
    pub fn with_n_jobs(mut self, n_jobs: i32) -> Self {
        self.n_jobs = n_jobs;
        self
    }

    /// Sets the verbosity level
    pub fn with_verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Sets the random state for reproducibility
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Fit t-SNE to input data and transform it to the embedded space
    ///
    /// # Arguments
    /// * `x` - Input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - Embedding of the training data, shape (n_samples, n_components)
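    ///
    /// # Example
    ///
    /// A hedged sketch (marked `ignore` since the public re-export path is assumed):
    ///
    /// ```ignore
    /// use ndarray::Array2;
    /// use scirs2_transform::reduction::TSNE;
    ///
    /// // 10 samples in 10-D; perplexity must stay below n_samples.
    /// let x = Array2::<f64>::eye(10);
    /// let mut tsne = TSNE::new()
    ///     .with_perplexity(3.0)
    ///     .with_method("exact")
    ///     .with_max_iter(500);
    /// let y = tsne.fit_transform(&x).unwrap();
    /// assert_eq!(y.shape(), &[10, 2]);
    /// ```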
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|v| num_traits::cast::<S::Elem, f64>(v).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        // Input validation
        if n_samples == 0 || n_features == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }

        if self.perplexity >= n_samples as f64 {
            return Err(TransformError::InvalidInput(format!(
                "perplexity ({}) must be less than n_samples ({})",
                self.perplexity, n_samples
            )));
        }

        if self.method == "barnes_hut" && self.n_components > 3 {
            return Err(TransformError::InvalidInput(
                "'n_components' must be at most 3 for the barnes_hut method".to_string(),
            ));
        }

        // Record the effective learning rate (no "auto" mode is implemented;
        // the configured value is used as-is)
        self.learning_rate_ = Some(self.learning_rate);

        // Initialize embedding
        let x_embedded = self.initialize_embedding(&x_f64)?;

        // Compute pairwise affinities (P)
        let p = self.compute_pairwise_affinities(&x_f64)?;

        // Run t-SNE optimization
        let (embedding, kl_divergence, n_iter) =
            self.tsne_optimization(p, x_embedded, n_samples)?;

        self.embedding_ = Some(embedding.clone());
        self.kl_divergence_ = Some(kl_divergence);
        self.n_iter_ = Some(n_iter);

        Ok(embedding)
    }

    /// Initialize embedding either with PCA or random
    fn initialize_embedding(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let n_samples = x.shape()[0];

        if self.init == "pca" {
            let n_components = self.n_components.min(x.shape()[1]);
            let mut pca = PCA::new(n_components, true, false);
            let mut x_embedded = pca.fit_transform(x)?;

            // Scale PCA initialization
            let std_dev = (x_embedded.column(0).map(|&x| x * x).sum() / (n_samples as f64)).sqrt();
            if std_dev > 0.0 {
                x_embedded.mapv_inplace(|x| x / std_dev * 1e-4);
            }

            Ok(x_embedded)
        } else if self.init == "random" {
            // Random initialization from a normal distribution with small variance.
            // NOTE: random_state is currently ignored here; reproducible
            // initialization would need a seeded RNG (e.g. `random_using`).
            let normal = Normal::new(0.0, 1e-4).unwrap();

            Ok(Array2::random((n_samples, self.n_components), normal))
        } else {
            Err(TransformError::InvalidInput(format!(
                "Initialization method '{}' not recognized",
                self.init
            )))
        }
    }

    /// Compute pairwise affinities with perplexity-based normalization
    fn compute_pairwise_affinities(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let _n_samples = x.shape()[0];

        // Compute pairwise distances
        let distances = self.compute_pairwise_distances(x)?;

        // Convert distances to affinities using binary search for sigma
        let p = self.distances_to_affinities(&distances)?;

        // Symmetrize and normalize the affinity matrix
        let mut p_symmetric = &p + &p.t();

        // Normalize
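        // Each conditional row of P sums to 1, so the total of P + Pᵀ is
        // about 2·n_samples; dividing by it matches the paper's
        // P = (P + Pᵀ) / (2n)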
        let p_sum = p_symmetric.sum();
        if p_sum > 0.0 {
            p_symmetric.mapv_inplace(|x| x.max(MACHINE_EPSILON) / p_sum);
        }

        Ok(p_symmetric)
    }

    /// Compute pairwise distances with optional multicore support
    fn compute_pairwise_distances(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let n_samples = x.shape()[0];
        let mut distances = Array2::zeros((n_samples, n_samples));

        match self.metric.as_str() {
            "euclidean" => {
                if self.n_jobs == 1 {
                    // Single-core computation
                    for i in 0..n_samples {
                        for j in i + 1..n_samples {
                            let mut dist_squared = 0.0;
                            for k in 0..x.shape()[1] {
                                let diff = x[[i, k]] - x[[j, k]];
                                dist_squared += diff * diff;
                            }
                            distances[[i, j]] = dist_squared;
                            distances[[j, i]] = dist_squared;
                        }
                    }
                } else {
                    // Multi-core computation
                    let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                        .collect();

                    let n_features = x.shape()[1];
                    let squared_distances: Vec<f64> = upper_triangle_indices
                        .par_iter()
                        .map(|&(i, j)| {
                            let mut dist_squared = 0.0;
                            for k in 0..n_features {
                                let diff = x[[i, k]] - x[[j, k]];
                                dist_squared += diff * diff;
                            }
                            dist_squared
                        })
                        .collect();

                    // Fill the distance matrix
                    for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                        distances[[i, j]] = squared_distances[idx];
                        distances[[j, i]] = squared_distances[idx];
                    }
                }
            }
            "manhattan" => {
                if self.n_jobs == 1 {
                    // Single-core Manhattan distance computation
                    for i in 0..n_samples {
                        for j in i + 1..n_samples {
                            let mut dist = 0.0;
                            for k in 0..x.shape()[1] {
                                dist += (x[[i, k]] - x[[j, k]]).abs();
                            }
                            distances[[i, j]] = dist;
                            distances[[j, i]] = dist;
                        }
                    }
                } else {
                    // Multi-core Manhattan distance computation
                    let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                        .collect();

                    let n_features = x.shape()[1];
                    let manhattan_distances: Vec<f64> = upper_triangle_indices
                        .par_iter()
                        .map(|&(i, j)| {
                            let mut dist = 0.0;
                            for k in 0..n_features {
                                dist += (x[[i, k]] - x[[j, k]]).abs();
                            }
                            dist
                        })
                        .collect();

                    // Fill the distance matrix
                    for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                        distances[[i, j]] = manhattan_distances[idx];
                        distances[[j, i]] = manhattan_distances[idx];
                    }
                }
            }
            "cosine" => {
                // First normalize all vectors for cosine distance computation
                let mut normalized_x = Array2::zeros((n_samples, x.shape()[1]));
                for i in 0..n_samples {
                    let row = x.row(i);
                    let norm = row.iter().map(|v| v * v).sum::<f64>().sqrt();
                    if norm > EPSILON {
                        for j in 0..x.shape()[1] {
                            normalized_x[[i, j]] = x[[i, j]] / norm;
                        }
                    } else {
                        // Handle zero vectors
                        for j in 0..x.shape()[1] {
                            normalized_x[[i, j]] = 0.0;
                        }
                    }
                }

                if self.n_jobs == 1 {
                    // Single-core cosine distance computation
                    for i in 0..n_samples {
                        for j in i + 1..n_samples {
                            let mut dot_product = 0.0;
                            for k in 0..x.shape()[1] {
                                dot_product += normalized_x[[i, k]] * normalized_x[[j, k]];
                            }
                            // Cosine distance = 1 - cosine similarity
                            let cosine_dist = 1.0 - dot_product.clamp(-1.0, 1.0);
                            distances[[i, j]] = cosine_dist;
                            distances[[j, i]] = cosine_dist;
                        }
                    }
                } else {
                    // Multi-core cosine distance computation
                    let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                        .collect();

                    let n_features = x.shape()[1];
                    let cosine_distances: Vec<f64> = upper_triangle_indices
                        .par_iter()
                        .map(|&(i, j)| {
                            let mut dot_product = 0.0;
                            for k in 0..n_features {
                                dot_product += normalized_x[[i, k]] * normalized_x[[j, k]];
                            }
                            // Cosine distance = 1 - cosine similarity
                            1.0 - dot_product.clamp(-1.0, 1.0)
                        })
                        .collect();

                    // Fill the distance matrix
                    for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                        distances[[i, j]] = cosine_distances[idx];
                        distances[[j, i]] = cosine_distances[idx];
                    }
                }
            }
            "chebyshev" => {
                if self.n_jobs == 1 {
                    // Single-core Chebyshev distance computation
                    for i in 0..n_samples {
                        for j in i + 1..n_samples {
                            let mut max_dist = 0.0;
                            for k in 0..x.shape()[1] {
                                let diff = (x[[i, k]] - x[[j, k]]).abs();
                                max_dist = max_dist.max(diff);
                            }
                            distances[[i, j]] = max_dist;
                            distances[[j, i]] = max_dist;
                        }
                    }
                } else {
                    // Multi-core Chebyshev distance computation
                    let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                        .collect();

                    let n_features = x.shape()[1];
                    let chebyshev_distances: Vec<f64> = upper_triangle_indices
                        .par_iter()
                        .map(|&(i, j)| {
                            let mut max_dist = 0.0;
                            for k in 0..n_features {
                                let diff = (x[[i, k]] - x[[j, k]]).abs();
                                max_dist = max_dist.max(diff);
                            }
                            max_dist
                        })
                        .collect();

                    // Fill the distance matrix
                    for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                        distances[[i, j]] = chebyshev_distances[idx];
                        distances[[j, i]] = chebyshev_distances[idx];
                    }
                }
            }
            _ => {
                return Err(TransformError::InvalidInput(format!(
                    "Metric '{}' not implemented. Supported metrics are: 'euclidean', 'manhattan', 'cosine', 'chebyshev'",
                    self.metric
                )));
            }
        }

        Ok(distances)
    }

    /// Convert distances to affinities using perplexity-based normalization with optional multicore support
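    ///
    /// For each point `i`, a precision `beta_i` is found by binary search so
    /// that the conditional distribution `p_{j|i} ∝ exp(-beta_i * d_ij)` (with
    /// `d_ij` the precomputed pairwise distance) has Shannon entropy `H_i`
    /// satisfying `exp(H_i) = perplexity`.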
    fn distances_to_affinities(&self, distances: &Array2<f64>) -> Result<Array2<f64>> {
        let n_samples = distances.shape()[0];
        let mut p = Array2::zeros((n_samples, n_samples));
        // Target entropy in nats: perplexity = exp(H), so the target is ln(perplexity)
        let target = self.perplexity.ln();

        if self.n_jobs == 1 {
            // Single-core computation (original implementation)
            for i in 0..n_samples {
                let mut beta_min = -f64::INFINITY;
                let mut beta_max = f64::INFINITY;
                let mut beta = 1.0;

                // Get all distances from point i except self-distance (which is 0)
                let distances_i = distances.row(i).to_owned();

                // Binary search for beta
                for _ in 0..50 {
                    // Usually converges within 50 iterations
                    // Compute conditional probabilities with current beta
                    let mut sum_pi = 0.0;
                    let mut h = 0.0;

                    for j in 0..n_samples {
                        if i == j {
                            p[[i, j]] = 0.0;
                            continue;
                        }

                        let p_ij = (-beta * distances_i[j]).exp();
                        p[[i, j]] = p_ij;
                        sum_pi += p_ij;
                    }

                    // Normalize probabilities and compute entropy
                    if sum_pi > 0.0 {
                        for j in 0..n_samples {
                            if i == j {
                                continue;
                            }

                            p[[i, j]] /= sum_pi;

                            // Compute entropy
                            if p[[i, j]] > MACHINE_EPSILON {
                                h -= p[[i, j]] * p[[i, j]].ln();
                            }
                        }
                    }

                    // Adjust beta based on entropy difference from target
                    let h_diff = h - target;

                    if h_diff.abs() < EPSILON {
                        break; // Converged
                    }

                    // Update beta using binary search
                    if h_diff > 0.0 {
                        beta_min = beta;
                        if beta_max == f64::INFINITY {
                            beta *= 2.0;
                        } else {
                            beta = (beta + beta_max) / 2.0;
                        }
                    } else {
                        beta_max = beta;
                        if beta_min == -f64::INFINITY {
                            beta /= 2.0;
                        } else {
                            beta = (beta + beta_min) / 2.0;
                        }
                    }
                }
            }
        } else {
            // Multi-core computation of conditional probabilities for each point
            let prob_rows: Vec<Vec<f64>> = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut beta_min = -f64::INFINITY;
                    let mut beta_max = f64::INFINITY;
                    let mut beta = 1.0;

                    // Get all distances from point i except self-distance (which is 0)
                    let distances_i: Vec<f64> = (0..n_samples).map(|j| distances[[i, j]]).collect();
                    let mut p_row = vec![0.0; n_samples];

                    // Binary search for beta
                    for _ in 0..50 {
                        // Usually converges within 50 iterations
                        // Compute conditional probabilities with current beta
                        let mut sum_pi = 0.0;
                        let mut h = 0.0;

                        for j in 0..n_samples {
                            if i == j {
                                p_row[j] = 0.0;
                                continue;
                            }

                            let p_ij = (-beta * distances_i[j]).exp();
                            p_row[j] = p_ij;
                            sum_pi += p_ij;
                        }

                        // Normalize probabilities and compute entropy
                        if sum_pi > 0.0 {
                            for (j, prob) in p_row.iter_mut().enumerate().take(n_samples) {
                                if i == j {
                                    continue;
                                }

                                *prob /= sum_pi;

                                // Compute entropy
                                if *prob > MACHINE_EPSILON {
                                    h -= *prob * prob.ln();
                                }
                            }
                        }

                        // Adjust beta based on entropy difference from target
                        let h_diff = h - target;

                        if h_diff.abs() < EPSILON {
                            break; // Converged
                        }

                        // Update beta using binary search
                        if h_diff > 0.0 {
                            beta_min = beta;
                            if beta_max == f64::INFINITY {
                                beta *= 2.0;
                            } else {
                                beta = (beta + beta_max) / 2.0;
                            }
                        } else {
                            beta_max = beta;
                            if beta_min == -f64::INFINITY {
                                beta /= 2.0;
                            } else {
                                beta = (beta + beta_min) / 2.0;
                            }
                        }
                    }

                    p_row
                })
                .collect();

            // Copy results back to the main matrix
            for (i, row) in prob_rows.iter().enumerate() {
                for (j, &val) in row.iter().enumerate() {
                    p[[i, j]] = val;
                }
            }
        }

        Ok(p)
    }

    /// Main t-SNE optimization loop using gradient descent
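    ///
    /// Runs an early-exaggeration phase with momentum 0.5, followed by a final
    /// phase with momentum 0.8; this is the schedule used in van der Maaten's
    /// reference implementation.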
    #[allow(clippy::too_many_arguments)]
    fn tsne_optimization(
        &self,
        p: Array2<f64>,
        initial_embedding: Array2<f64>,
        n_samples: usize,
    ) -> Result<(Array2<f64>, f64, usize)> {
        let n_components = self.n_components;
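        // Student-t degrees of freedom, following the common heuristic (also
        // used by scikit-learn): max(n_components - 1, 1)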
        let degrees_of_freedom = (n_components - 1).max(1) as f64;

        // Initialize variables for optimization
        let mut embedding = initial_embedding;
        let mut update = Array2::zeros((n_samples, n_components));
        let mut gains = Array2::ones((n_samples, n_components));
        let mut error = f64::INFINITY;
        let mut best_error = f64::INFINITY;
        let mut best_iter = 0;
        let mut iter = 0;

        // Exploration phase with early exaggeration
        let exploration_n_iter = 250;
        let n_iter_check = 50;

        // Apply early exaggeration
        let p_early = &p * self.early_exaggeration;

        if self.verbose {
            println!("[t-SNE] Starting optimization with early exaggeration phase...");
        }

        // Early exaggeration phase
        for i in 0..exploration_n_iter {
            // Compute gradient and error for early exaggeration phase
            let (curr_error, grad) = if self.method == "barnes_hut" {
                self.compute_gradient_barnes_hut(&embedding, &p_early, degrees_of_freedom)?
            } else {
                self.compute_gradient_exact(&embedding, &p_early, degrees_of_freedom)?
            };

            // Perform gradient update with momentum and gains
            self.gradient_update(
                &mut embedding,
                &mut update,
                &mut gains,
                &grad,
                0.5,
                self.learning_rate_,
            )?;

            // Check for convergence
            if (i + 1) % n_iter_check == 0 {
                if self.verbose {
                    println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
                }

                if curr_error < best_error {
                    best_error = curr_error;
                    best_iter = i;
                } else if i - best_iter > self.n_iter_without_progress {
                    if self.verbose {
                        println!("[t-SNE] Early convergence at iteration {}", i + 1);
                    }
                    break;
                }

                // Check gradient norm
                let grad_norm = grad.mapv(|x| x * x).sum().sqrt();
                if grad_norm < self.min_grad_norm {
                    if self.verbose {
                        println!(
                            "[t-SNE] Gradient norm {} below threshold, stopping optimization at iteration {}",
                            grad_norm,
                            i + 1
                        );
                    }
                    break;
                }
            }

            iter = i;
        }

        if self.verbose {
            println!("[t-SNE] Completed early exaggeration phase, starting final optimization...");
        }

        // Final optimization phase without early exaggeration
        for i in iter + 1..self.max_iter {
            // Compute gradient and error for normal phase
            let (curr_error, grad) = if self.method == "barnes_hut" {
                self.compute_gradient_barnes_hut(&embedding, &p, degrees_of_freedom)?
            } else {
                self.compute_gradient_exact(&embedding, &p, degrees_of_freedom)?
            };
            error = curr_error;

            // Perform gradient update with momentum and gains
            self.gradient_update(
                &mut embedding,
                &mut update,
                &mut gains,
                &grad,
                0.8,
                self.learning_rate_,
            )?;

            // Check for convergence
            if (i + 1) % n_iter_check == 0 {
                if self.verbose {
                    println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
                }

                if curr_error < best_error {
                    best_error = curr_error;
                    best_iter = i;
                } else if i - best_iter > self.n_iter_without_progress {
                    if self.verbose {
                        println!("[t-SNE] Stopping optimization at iteration {}", i + 1);
                    }
                    break;
                }

                // Check gradient norm
                let grad_norm = grad.mapv(|x| x * x).sum().sqrt();
                if grad_norm < self.min_grad_norm {
                    if self.verbose {
                        println!(
                            "[t-SNE] Gradient norm {} below threshold, stopping optimization at iteration {}",
                            grad_norm,
                            i + 1
                        );
                    }
                    break;
                }
            }

            iter = i;
        }

        if self.verbose {
            println!(
                "[t-SNE] Optimization finished after {} iterations with error {:.7}",
                iter + 1,
                error
            );
        }

        Ok((embedding, error, iter + 1))
    }

    /// Compute gradient and error for exact t-SNE with optional multicore support
1517    #[allow(clippy::too_many_arguments)]
1518    fn compute_gradient_exact(
1519        &self,
1520        embedding: &Array2<f64>,
1521        p: &Array2<f64>,
1522        degrees_of_freedom: f64,
1523    ) -> Result<(f64, Array2<f64>)> {
1524        let n_samples = embedding.shape()[0];
1525        let n_components = embedding.shape()[1];
1526
1527        if self.n_jobs == 1 {
1528            // Single-core computation (original implementation)
1529            let mut dist = Array2::zeros((n_samples, n_samples));
1530            for i in 0..n_samples {
1531                for j in i + 1..n_samples {
1532                    let mut d_squared = 0.0;
1533                    for k in 0..n_components {
1534                        let diff = embedding[[i, k]] - embedding[[j, k]];
1535                        d_squared += diff * diff;
1536                    }
1537
1538                    // Convert squared distance to t-distribution's probability
1539                    let q_ij = (1.0 + d_squared / degrees_of_freedom)
1540                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
1541                    dist[[i, j]] = q_ij;
1542                    dist[[j, i]] = q_ij;
1543                }
1544            }
1545
1546            // Set diagonal to zero (self-distance)
1547            for i in 0..n_samples {
1548                dist[[i, i]] = 0.0;
1549            }
1550
1551            // Normalize Q matrix
1552            let sum_q = dist.sum().max(MACHINE_EPSILON);
1553            let q = &dist / sum_q;
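            // Here q[[i, j]] = (1 + d_ij^2 / nu)^(-(nu + 1) / 2) / sum_q, with
            // nu = degrees_of_freedom: the heavy-tailed Student-t affinities of
            // the low-dimensional embedding.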

            // Compute KL divergence
            let mut kl_divergence = 0.0;
            for i in 0..n_samples {
                for j in 0..n_samples {
                    if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                        kl_divergence += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                    }
                }
            }

            // Compute gradient
            let mut grad = Array2::zeros((n_samples, n_components));
            let factor =
                4.0 * (degrees_of_freedom + 1.0) / (degrees_of_freedom * (sum_q.powf(2.0)));

            for i in 0..n_samples {
                for j in 0..n_samples {
                    if i != j {
                        let p_q_diff = p[[i, j]] - q[[i, j]];
                        for k in 0..n_components {
                            grad[[i, k]] += factor
                                * p_q_diff
                                * dist[[i, j]]
                                * (embedding[[i, k]] - embedding[[j, k]]);
                        }
                    }
                }
            }

            Ok((kl_divergence, grad))
        } else {
            // Multi-core computation
            let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                .collect();

            let q_values: Vec<f64> = upper_triangle_indices
                .par_iter()
                .map(|&(i, j)| {
                    let mut d_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        d_squared += diff * diff;
                    }

                    // Convert squared distance to t-distribution's probability
                    (1.0 + d_squared / degrees_of_freedom).powf(-(degrees_of_freedom + 1.0) / 2.0)
                })
                .collect();

            // Fill the distance matrix
            let mut dist = Array2::zeros((n_samples, n_samples));
            for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                let q_val = q_values[idx];
                dist[[i, j]] = q_val;
                dist[[j, i]] = q_val;
            }

            // Set diagonal to zero (self-distance)
            for i in 0..n_samples {
                dist[[i, i]] = 0.0;
            }

            // Normalize Q matrix
            let sum_q = dist.sum().max(MACHINE_EPSILON);
            let q = &dist / sum_q;

            // Parallel computation of KL divergence
            let kl_divergence: f64 = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut local_kl = 0.0;
                    for j in 0..n_samples {
                        if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                            local_kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                        }
                    }
                    local_kl
                })
                .sum();

            // Parallel computation of gradient
            let factor =
                4.0 * (degrees_of_freedom + 1.0) / (degrees_of_freedom * (sum_q.powf(2.0)));

            let grad_rows: Vec<Vec<f64>> = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut grad_row = vec![0.0; n_components];
                    for j in 0..n_samples {
                        if i != j {
                            let p_q_diff = p[[i, j]] - q[[i, j]];
                            for k in 0..n_components {
                                grad_row[k] += factor
                                    * p_q_diff
                                    * dist[[i, j]]
                                    * (embedding[[i, k]] - embedding[[j, k]]);
                            }
                        }
                    }
                    grad_row
                })
                .collect();

            // Convert gradient rows back to array
            let mut grad = Array2::zeros((n_samples, n_components));
            for (i, row) in grad_rows.iter().enumerate() {
                for (k, &val) in row.iter().enumerate() {
                    grad[[i, k]] = val;
                }
            }

            Ok((kl_divergence, grad))
        }
    }

    /// Compute gradient and error using Barnes-Hut approximation
    #[allow(clippy::too_many_arguments)]
    fn compute_gradient_barnes_hut(
        &self,
        embedding: &Array2<f64>,
        p: &Array2<f64>,
        degrees_of_freedom: f64,
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];

        // Build spatial tree for Barnes-Hut approximation
        let tree = if n_components == 2 {
            SpatialTree::new_quadtree(embedding)?
        } else if n_components == 3 {
            SpatialTree::new_octree(embedding)?
        } else {
            return Err(TransformError::InvalidInput(
                "Barnes-Hut approximation only supports 2D and 3D embeddings".to_string(),
            ));
        };
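
        // The tree summarizes distant groups of points by their center of mass,
        // reducing the cost of the repulsive-force computation below; `self.angle`
        // is the usual Barnes-Hut accuracy knob (0.0 = exact, larger = coarser).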

        // Compute Q matrix and gradient using Barnes-Hut
        let mut q = Array2::zeros((n_samples, n_samples));
        let mut grad = Array2::zeros((n_samples, n_components));
        let mut sum_q = 0.0;

        // For each point, compute repulsive forces using Barnes-Hut
        for i in 0..n_samples {
            let point = embedding.row(i).to_owned();
            let (repulsive_force, q_sum) =
                tree.compute_forces(&point, i, self.angle, degrees_of_freedom)?;

            sum_q += q_sum;

            // Add repulsive forces to gradient
            for j in 0..n_components {
                grad[[i, j]] += repulsive_force[j];
            }

            // Compute Q matrix for KL divergence calculation
            for j in 0..n_samples {
                if i != j {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }
                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    q[[i, j]] = q_ij;
                }
            }
        }

        // Normalize Q matrix
        sum_q = sum_q.max(MACHINE_EPSILON);
        q.mapv_inplace(|x| x / sum_q);

        // Add attractive forces to gradient
        for i in 0..n_samples {
            for j in 0..n_samples {
                if i != j && p[[i, j]] > MACHINE_EPSILON {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }

                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    let factor = 4.0 * p[[i, j]] * q_ij;

                    for k in 0..n_components {
                        grad[[i, k]] -= factor * (embedding[[i, k]] - embedding[[j, k]]);
                    }
                }
            }
        }

        // Compute KL divergence
        let mut kl_divergence = 0.0;
        for i in 0..n_samples {
            for j in 0..n_samples {
                if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                    kl_divergence += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                }
            }
        }

        Ok((kl_divergence, grad))
    }

    /// Update embedding using gradient descent with momentum and adaptive gains
    #[allow(clippy::too_many_arguments)]
    fn gradient_update(
        &self,
        embedding: &mut Array2<f64>,
        update: &mut Array2<f64>,
        gains: &mut Array2<f64>,
        grad: &Array2<f64>,
        momentum: f64,
        learning_rate: Option<f64>,
    ) -> Result<()> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];
        let eta = learning_rate.unwrap_or(self.learning_rate);

        // Update gains and momentum
        for i in 0..n_samples {
            for j in 0..n_components {
                let same_sign = update[[i, j]] * grad[[i, j]] > 0.0;

                if same_sign {
                    gains[[i, j]] *= 0.8;
                } else {
                    gains[[i, j]] += 0.2;
                }
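
                // Delta-bar-delta-style gains: `update` still holds the previous
                // step (roughly -eta * gain * previous gradient), so a matching
                // sign here means the gradient direction has flipped and the gain
                // is damped; otherwise progress is consistent and the gain grows.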

                // Ensure minimum gain
                gains[[i, j]] = gains[[i, j]].max(0.01);

                // Update with momentum and adaptive learning rate
                update[[i, j]] = momentum * update[[i, j]] - eta * gains[[i, j]] * grad[[i, j]];
                embedding[[i, j]] += update[[i, j]];
            }
        }

        Ok(())
    }

    /// Returns the embedding after fitting
    pub fn embedding(&self) -> Option<&Array2<f64>> {
        self.embedding_.as_ref()
    }

    /// Returns the KL divergence after optimization
    pub fn kl_divergence(&self) -> Option<f64> {
        self.kl_divergence_
    }

    /// Returns the number of iterations run
    pub fn n_iter(&self) -> Option<usize> {
        self.n_iter_
    }
}

/// Calculate trustworthiness score for a dimensionality reduction
///
/// Trustworthiness measures to what extent the local structure is retained when
/// projecting data from the original space to the embedding space.
///
/// # Arguments
/// * `x` - Original data, shape (n_samples, n_features)
/// * `x_embedded` - Embedded data, shape (n_samples, n_components)
/// * `n_neighbors` - Number of neighbors to consider
/// * `metric` - Metric to use (currently only 'euclidean' is implemented)
///
/// # Returns
/// * `Result<f64>` - Trustworthiness score between 0.0 and 1.0
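///
/// Each embedded-space neighbor `j` of a point `i` that is not among the
/// `n_neighbors` nearest neighbors of `i` in the original space is penalized
/// by its original-space rank, giving
///
/// T(k) = 1 - 2 / (n*k*(2n - 3k - 1)) * sum_i sum_{j in U_k(i)} (r(i, j) - k),
///
/// where `U_k(i)` is the set of such points and `r(i, j)` is the rank of `j`
/// by distance from `i` in the original space.
///
/// # Examples
///
/// A minimal sketch (assuming this function is re-exported from the crate's
/// `reduction` module):
///
/// ```ignore
/// use ndarray::arr2;
/// use scirs2_transform::reduction::trustworthiness;
///
/// let x = arr2(&[[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]]);
/// // The identity embedding preserves every neighborhood, so T = 1.0
/// let t = trustworthiness(&x, &x, 2, "euclidean").unwrap();
/// assert!((t - 1.0).abs() < 1e-10);
/// ```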
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
pub fn trustworthiness<S1, S2>(
    x: &ArrayBase<S1, Ix2>,
    x_embedded: &ArrayBase<S2, Ix2>,
    n_neighbors: usize,
    metric: &str,
) -> Result<f64>
where
    S1: Data,
    S2: Data,
    S1::Elem: Float + NumCast,
    S2::Elem: Float + NumCast,
{
    let x_f64 = x.mapv(|x| num_traits::cast::<S1::Elem, f64>(x).unwrap_or(0.0));
    let x_embedded_f64 = x_embedded.mapv(|x| num_traits::cast::<S2::Elem, f64>(x).unwrap_or(0.0));

    let n_samples = x_f64.shape()[0];

    if n_neighbors >= n_samples / 2 {
        return Err(TransformError::InvalidInput(format!(
            "n_neighbors ({}) should be less than n_samples / 2 ({})",
            n_neighbors,
            n_samples / 2
        )));
    }

    if metric != "euclidean" {
        return Err(TransformError::InvalidInput(format!(
            "Metric '{metric}' not implemented. Currently only 'euclidean' is supported"
        )));
    }

    // Compute pairwise distances in the original space
    let mut dist_x = Array2::zeros((n_samples, n_samples));
    for i in 0..n_samples {
        for j in 0..n_samples {
            if i == j {
                dist_x[[i, j]] = f64::INFINITY; // Set self-distance to infinity
                continue;
            }

            let mut d_squared = 0.0;
            for k in 0..x_f64.shape()[1] {
                let diff = x_f64[[i, k]] - x_f64[[j, k]];
                d_squared += diff * diff;
            }
            dist_x[[i, j]] = d_squared.sqrt();
        }
    }

    // Compute pairwise distances in the embedded space
    let mut dist_embedded = Array2::zeros((n_samples, n_samples));
    for i in 0..n_samples {
        for j in 0..n_samples {
            if i == j {
                dist_embedded[[i, j]] = f64::INFINITY; // Set self-distance to infinity
                continue;
            }

            let mut d_squared = 0.0;
            for k in 0..x_embedded_f64.shape()[1] {
                let diff = x_embedded_f64[[i, k]] - x_embedded_f64[[j, k]];
                d_squared += diff * diff;
            }
            dist_embedded[[i, j]] = d_squared.sqrt();
        }
    }

    // For each point, find the n_neighbors nearest neighbors in the original space
    let mut nn_orig = Array2::<usize>::zeros((n_samples, n_neighbors));
    for i in 0..n_samples {
        // Get the indices of the sorted distances
        let row = dist_x.row(i).to_owned();
        let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // The self-distance is infinite and therefore sorts last, so the first
        // n_neighbors entries are exactly the nearest neighbors
        for (j, &(idx, _)) in pairs.iter().enumerate().take(n_neighbors) {
            nn_orig[[i, j]] = idx;
        }
    }

    // For each point, find the n_neighbors nearest neighbors in the embedded space
    let mut nn_embedded = Array2::<usize>::zeros((n_samples, n_neighbors));
    for i in 0..n_samples {
        // Get the indices of the sorted distances
        let row = dist_embedded.row(i).to_owned();
        let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // As above, the infinite self-distance sorts last, so the first
        // n_neighbors entries are taken directly
        for (j, &(idx, _)) in pairs.iter().enumerate().take(n_neighbors) {
            nn_embedded[[i, j]] = idx;
        }
    }

    // Calculate the trustworthiness score
    let mut t = 0.0;
    for i in 0..n_samples {
        for &j in nn_embedded.row(i).iter() {
            // Check if j is not in the n_neighbors nearest neighbors in the original space
            let is_not_neighbor = !nn_orig.row(i).iter().any(|&nn| nn == j);

            if is_not_neighbor {
                // Find the rank of j in the original space
                let row = dist_x.row(i).to_owned();
                let mut pairs: Vec<(usize, f64)> =
                    row.iter().enumerate().map(|(idx, &d)| (idx, d)).collect();
                pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                // 1-based rank of j by distance from i; guard the subtraction so
                // a failed lookup cannot underflow the unsigned type
                let rank = pairs
                    .iter()
                    .position(|&(idx, _)| idx == j)
                    .map(|pos| pos + 1)
                    .unwrap_or(0);
                if rank > n_neighbors {
                    t += (rank - n_neighbors) as f64;
                }
            }
        }
    }

    // Normalize the trustworthiness score
    let n = n_samples as f64;
    let k = n_neighbors as f64;
    let normalizer = 2.0 / (n * k * (2.0 * n - 3.0 * k - 1.0));
    let trustworthiness = 1.0 - normalizer * t;

    Ok(trustworthiness)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::arr2;

    #[test]
    fn test_tsne_simple() {
        // Create a simple dataset
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        // Initialize and fit t-SNE with the exact method
        let mut tsne_exact = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_exact = tsne_exact.fit_transform(&x).unwrap();

        // Check that the shape is correct
        assert_eq!(embedding_exact.shape(), &[8, 2]);

        // Check that groups are separated in the embedding space
        // Compute the average distance within each group
        let dist_group1 = average_pairwise_distance(&embedding_exact.slice(ndarray::s![0..4, ..]));
        let dist_group2 = average_pairwise_distance(&embedding_exact.slice(ndarray::s![4..8, ..]));

        // Compute the average distance between groups
        let dist_between = average_intergroup_distance(
            &embedding_exact.slice(ndarray::s![0..4, ..]),
            &embedding_exact.slice(ndarray::s![4..8, ..]),
        );

        // The between-group distance should be larger than the within-group distances
        assert!(dist_between > dist_group1);
        assert!(dist_between > dist_group2);
    }

    #[test]
    fn test_tsne_barnes_hut() {
        // Create a simple dataset
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        // Initialize and fit t-SNE with the Barnes-Hut method
        let mut tsne_bh = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("barnes_hut")
            .with_angle(0.5)
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_bh = tsne_bh.fit_transform(&x).unwrap();

        // Check that the shape is correct
        assert_eq!(embedding_bh.shape(), &[8, 2]);

        // Barnes-Hut is approximate, so only check basic properties
        assert!(embedding_bh.iter().all(|&x| x.is_finite()));

        // Check that the embedding has some spread (not all points collapsed to the same location)
        let min_val = embedding_bh.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_val = embedding_bh
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_val - min_val > 1e-6,
            "Embedding should have some spread"
        );

        // Check that KL divergence was computed
        assert!(tsne_bh.kl_divergence().is_some());

        // Due to the nature of the approximation, the KL divergence might not
        // always be finite, so just report which case occurred
        let kl_div = tsne_bh.kl_divergence().unwrap();
        if !kl_div.is_finite() {
            println!(
                "Barnes-Hut KL divergence: {} (non-finite, which is acceptable for approximation)",
                kl_div
            );
        } else {
            println!("Barnes-Hut KL divergence: {} (finite)", kl_div);
        }
    }

    #[test]
    fn test_tsne_multicore() {
        // Create a simple dataset
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        // Initialize and fit t-SNE with multicore enabled
        let mut tsne_multicore = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_n_jobs(-1) // Use all cores
            .with_random_state(42)
            .with_max_iter(100) // Shorter for testing
            .with_verbose(false);

        let embedding_multicore = tsne_multicore.fit_transform(&x).unwrap();

        // Check that the shape is correct
        assert_eq!(embedding_multicore.shape(), &[8, 2]);

        // Test basic functionality - multicore should produce valid results
        assert!(embedding_multicore.iter().all(|&x| x.is_finite()));

        // Check that the embedding has some spread (more lenient for short iterations)
        let min_val = embedding_multicore
            .iter()
            .cloned()
            .fold(f64::INFINITY, f64::min);
        let max_val = embedding_multicore
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_val - min_val > 1e-12,
            "Embedding should have some spread, got range: {}",
            max_val - min_val
        );

        // Test single-core vs multicore consistency
        let mut tsne_singlecore = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_n_jobs(1) // Single core
            .with_random_state(42)
            .with_max_iter(100)
            .with_verbose(false);

        let embedding_singlecore = tsne_singlecore.fit_transform(&x).unwrap();

        // Both should produce finite results (exact numerical match is not expected due to randomness)
        assert!(embedding_multicore.iter().all(|&x| x.is_finite()));
        assert!(embedding_singlecore.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_tsne_3d_barnes_hut() {
        // Create a simple 3D dataset
        let x = arr2(&[
            [0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0],
            [5.0, 5.0, 5.0],
            [6.0, 5.0, 5.0],
            [5.0, 6.0, 5.0],
            [6.0, 6.0, 5.0],
        ]);

        // Initialize and fit t-SNE with the Barnes-Hut method for 3D
        let mut tsne_3d = TSNE::new()
            .with_n_components(3)
            .with_perplexity(2.0)
            .with_method("barnes_hut")
            .with_angle(0.5)
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_3d = tsne_3d.fit_transform(&x).unwrap();

        // Check that the shape is correct
        assert_eq!(embedding_3d.shape(), &[8, 3]);

        // Test basic functionality - should not panic
        assert!(embedding_3d.iter().all(|&x| x.is_finite()));
    }

    // Helper function to compute average pairwise distance within a group
    fn average_pairwise_distance(points: &ArrayBase<ndarray::ViewRepr<&f64>, Ix2>) -> f64 {
        let n = points.shape()[0];
        let mut total_dist = 0.0;
        let mut count = 0;

        for i in 0..n {
            for j in i + 1..n {
                let mut dist_squared = 0.0;
                for k in 0..points.shape()[1] {
                    let diff = points[[i, k]] - points[[j, k]];
                    dist_squared += diff * diff;
                }
                total_dist += dist_squared.sqrt();
                count += 1;
            }
        }

        if count > 0 {
            total_dist / count as f64
        } else {
            0.0
        }
    }

    // Helper function to compute average distance between two groups
    fn average_intergroup_distance(
        group1: &ArrayBase<ndarray::ViewRepr<&f64>, Ix2>,
        group2: &ArrayBase<ndarray::ViewRepr<&f64>, Ix2>,
    ) -> f64 {
        let n1 = group1.shape()[0];
        let n2 = group2.shape()[0];
        let mut total_dist = 0.0;
        let mut count = 0;

        for i in 0..n1 {
            for j in 0..n2 {
                let mut dist_squared = 0.0;
                for k in 0..group1.shape()[1] {
                    let diff = group1[[i, k]] - group2[[j, k]];
                    dist_squared += diff * diff;
                }
                total_dist += dist_squared.sqrt();
                count += 1;
            }
        }

        if count > 0 {
            total_dist / count as f64
        } else {
            0.0
        }
    }
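
    // Illustrative sketch of the adaptive-gains update: with a zero initial
    // `update`, the sign test in `gradient_update` fails for every entry, so
    // each gain grows from 1.0 to 1.2 and the first step is -eta * gain * grad.
    #[test]
    fn test_gradient_update_single_step() {
        let tsne = TSNE::new();
        let mut embedding = arr2(&[[0.0, 0.0]]);
        let mut update = Array2::zeros((1, 2));
        let mut gains = Array2::ones((1, 2));
        let grad = arr2(&[[1.0, -2.0]]);

        tsne.gradient_update(
            &mut embedding,
            &mut update,
            &mut gains,
            &grad,
            0.5,
            Some(10.0),
        )
        .unwrap();

        // Step = -10.0 * 1.2 * grad, so the point moves against the gradient
        assert_abs_diff_eq!(embedding[[0, 0]], -12.0, epsilon = 1e-12);
        assert_abs_diff_eq!(embedding[[0, 1]], 24.0, epsilon = 1e-12);
        assert_abs_diff_eq!(gains[[0, 0]], 1.2, epsilon = 1e-12);
    }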

    #[test]
    fn test_trustworthiness() {
        // Create a simple dataset where we know the structure
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [5.0, 6.0],
            [6.0, 5.0],
            [6.0, 6.0],
        ]);

        // A perfect embedding would preserve all neighborhoods
        let perfect_embedding = x.clone();
        let t_perfect = trustworthiness(&x, &perfect_embedding, 3, "euclidean").unwrap();
        assert_abs_diff_eq!(t_perfect, 1.0, epsilon = 1e-10);

        // An embedding that rearranges the points along a line scrambles the
        // neighborhoods and should have lower trustworthiness
        let random_embedding = arr2(&[
            [0.9, 0.1],
            [0.8, 0.2],
            [0.7, 0.3],
            [0.6, 0.4],
            [0.5, 0.5],
            [0.4, 0.6],
            [0.3, 0.7],
            [0.2, 0.8],
        ]);

        let t_random = trustworthiness(&x, &random_embedding, 3, "euclidean").unwrap();
        assert!(t_random < 1.0);
    }
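
    // Illustrative sketch of `compute_gradient_exact` on a tiny triangle with
    // uniform joint probabilities: the KL divergence should be finite and
    // non-negative (up to floating-point noise) and the gradient finite.
    #[test]
    fn test_compute_gradient_exact_sketch() {
        let tsne = TSNE::new().with_method("exact").with_n_jobs(1);
        let embedding = arr2(&[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]);

        // Uniform P over the six off-diagonal pairs, summing to 1
        let mut p = Array2::from_elem((3, 3), 1.0 / 6.0);
        for i in 0..3 {
            p[[i, i]] = 0.0;
        }

        let (kl, grad) = tsne.compute_gradient_exact(&embedding, &p, 1.0).unwrap();

        assert!(kl.is_finite() && kl > -1e-12);
        assert_eq!(grad.shape(), &[3, 2]);
        assert!(grad.iter().all(|g| g.is_finite()));
    }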
}