scirs2_transform/reduction/tsne.rs

//! t-SNE (t-distributed Stochastic Neighbor Embedding) implementation
//!
//! This module provides an implementation of t-SNE, a technique for dimensionality
//! reduction particularly well-suited for visualization of high-dimensional data.
//!
//! t-SNE converts similarities between data points to joint probabilities and tries
//! to minimize the Kullback-Leibler divergence between the joint probabilities of
//! the low-dimensional embedding and the high-dimensional data.
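//!
//! # Example
//!
//! A minimal usage sketch (assuming this crate is consumed as `scirs2_transform`
//! and `TSNE` is re-exported from its `reduction` module):
//!
//! ```no_run
//! use scirs2_core::ndarray::Array2;
//! use scirs2_transform::reduction::TSNE; // import path assumed from this crate's layout
//!
//! // Embed 200 samples of 50-dimensional data into 2D for visualization
//! let data = Array2::<f64>::zeros((200, 50));
//! let mut tsne = TSNE::new().with_n_components(2).with_perplexity(30.0);
//! let embedding = tsne.fit_transform(&data).expect("t-SNE failed");
//! assert_eq!(embedding.shape(), &[200, 2]);
//! ```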

use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::parallel_ops::*;
use scirs2_core::random::Normal;
use scirs2_core::random::RandomExt;

use crate::error::{Result, TransformError};
use crate::reduction::PCA;

// Constants for numerical stability
const MACHINE_EPSILON: f64 = 1e-14;
const EPSILON: f64 = 1e-7;

/// Spatial tree data structure for Barnes-Hut approximation
#[derive(Debug, Clone)]
enum SpatialTree {
    QuadTree(QuadTreeNode),
    OctTree(OctTreeNode),
}

/// Node in a quadtree (for 2D embeddings)
#[derive(Debug, Clone)]
struct QuadTreeNode {
    /// Bounding box of this node
    x_min: f64,
    x_max: f64,
    y_min: f64,
    y_max: f64,
    /// Center of mass
    center_of_mass: Option<Array1<f64>>,
    /// Total mass (number of points)
    total_mass: f64,
    /// Point indices in this node (for leaf nodes)
    point_indices: Vec<usize>,
    /// Children nodes (SW, SE, NW, NE)
    children: Option<[Box<QuadTreeNode>; 4]>,
    /// Whether this is a leaf node
    is_leaf: bool,
}

/// Node in an octree (for 3D embeddings)
#[derive(Debug, Clone)]
struct OctTreeNode {
    /// Bounding box of this node
    x_min: f64,
    x_max: f64,
    y_min: f64,
    y_max: f64,
    z_min: f64,
    z_max: f64,
    /// Center of mass
    center_of_mass: Option<Array1<f64>>,
    /// Total mass (number of points)
    total_mass: f64,
    /// Point indices in this node (for leaf nodes)
    point_indices: Vec<usize>,
    /// Children nodes (8 octants)
    children: Option<[Box<OctTreeNode>; 8]>,
    /// Whether this is a leaf node
    is_leaf: bool,
}

impl SpatialTree {
    /// Create a new quadtree for 2D embeddings
    fn new_quadtree(embedding: &Array2<f64>) -> Result<Self> {
        let n_samples = embedding.shape()[0];

        if embedding.shape()[1] != 2 {
            return Err(TransformError::InvalidInput(
                "QuadTree requires a 2D embedding".to_string(),
            ));
        }

        // Find bounding box
        let mut x_min = f64::INFINITY;
        let mut x_max = f64::NEG_INFINITY;
        let mut y_min = f64::INFINITY;
        let mut y_max = f64::NEG_INFINITY;

        for i in 0..n_samples {
            let x = embedding[[i, 0]];
            let y = embedding[[i, 1]];
            x_min = x_min.min(x);
            x_max = x_max.max(x);
            y_min = y_min.min(y);
            y_max = y_max.max(y);
        }

        // Add small margin to avoid edge cases
        let margin = 0.01 * ((x_max - x_min) + (y_max - y_min));
        x_min -= margin;
        x_max += margin;
        y_min -= margin;
        y_max += margin;

        // Collect all point indices
        let point_indices: Vec<usize> = (0..n_samples).collect();

        // Create root node
        let mut root = QuadTreeNode {
            x_min,
            x_max,
            y_min,
            y_max,
            center_of_mass: None,
            total_mass: 0.0,
            point_indices,
            children: None,
            is_leaf: true,
        };

        // Build the tree
        root.build_tree(embedding)?;

        Ok(SpatialTree::QuadTree(root))
    }

    /// Create a new octree for 3D embeddings
    fn new_octree(embedding: &Array2<f64>) -> Result<Self> {
        let n_samples = embedding.shape()[0];

        if embedding.shape()[1] != 3 {
            return Err(TransformError::InvalidInput(
                "OctTree requires a 3D embedding".to_string(),
            ));
        }

        // Find bounding box
        let mut x_min = f64::INFINITY;
        let mut x_max = f64::NEG_INFINITY;
        let mut y_min = f64::INFINITY;
        let mut y_max = f64::NEG_INFINITY;
        let mut z_min = f64::INFINITY;
        let mut z_max = f64::NEG_INFINITY;

        for i in 0..n_samples {
            let x = embedding[[i, 0]];
            let y = embedding[[i, 1]];
            let z = embedding[[i, 2]];
            x_min = x_min.min(x);
            x_max = x_max.max(x);
            y_min = y_min.min(y);
            y_max = y_max.max(y);
            z_min = z_min.min(z);
            z_max = z_max.max(z);
        }

        // Add small margin to avoid edge cases
        let margin = 0.01 * ((x_max - x_min) + (y_max - y_min) + (z_max - z_min));
        x_min -= margin;
        x_max += margin;
        y_min -= margin;
        y_max += margin;
        z_min -= margin;
        z_max += margin;

        // Collect all point indices
        let point_indices: Vec<usize> = (0..n_samples).collect();

        // Create root node
        let mut root = OctTreeNode {
            x_min,
            x_max,
            y_min,
            y_max,
            z_min,
            z_max,
            center_of_mass: None,
            total_mass: 0.0,
            point_indices,
            children: None,
            is_leaf: true,
        };

        // Build the tree
        root.build_tree(embedding)?;

        Ok(SpatialTree::OctTree(root))
    }

    /// Compute forces on a point using Barnes-Hut approximation
    #[allow(clippy::too_many_arguments)]
    fn compute_forces(
        &self,
        point: &Array1<f64>,
        point_idx: usize,
        angle: f64,
        degrees_of_freedom: f64,
    ) -> Result<(Array1<f64>, f64)> {
        match self {
            SpatialTree::QuadTree(root) => {
                root.compute_forces_quad(point, point_idx, angle, degrees_of_freedom)
            }
            SpatialTree::OctTree(root) => {
                root.compute_forces_oct(point, point_idx, angle, degrees_of_freedom)
            }
        }
    }
}
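
// How the Barnes-Hut approximation below works: a tree cell of maximum side
// length `s` whose center of mass lies at distance `d` from the query point is
// summarized as a single point of mass `total_mass` whenever s / d < `angle`
// (the usual theta parameter). Smaller angles are more accurate but slower;
// the default of 0.5 gives the classic O(n log n) repulsive-force estimate.
// Each accepted cell contributes the Student-t kernel
//     q = (1 + d^2 / nu)^(-(nu + 1) / 2)
// both to the repulsive force and to the normalization term `sum_q`, where
// `nu` is `degrees_of_freedom`.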

impl QuadTreeNode {
    /// Build the quadtree recursively
    fn build_tree(&mut self, embedding: &Array2<f64>) -> Result<()> {
        if self.point_indices.len() <= 1 {
            // Leaf node with 0 or 1 points
            self.update_center_of_mass(embedding)?;
            return Ok(());
        }

        // Split into 4 quadrants
        let x_mid = (self.x_min + self.x_max) / 2.0;
        let y_mid = (self.y_min + self.y_max) / 2.0;

        let mut quadrants: [Vec<usize>; 4] = [Vec::new(), Vec::new(), Vec::new(), Vec::new()];

        // Distribute points to quadrants
        for &idx in &self.point_indices {
            let x = embedding[[idx, 0]];
            let y = embedding[[idx, 1]];

            let quadrant = match (x >= x_mid, y >= y_mid) {
                (false, false) => 0, // SW
                (true, false) => 1,  // SE
                (false, true) => 2,  // NW
                (true, true) => 3,   // NE
            };

            quadrants[quadrant].push(idx);
        }

        // Create child nodes
        let mut children = [
            Box::new(QuadTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: self.y_min,
                y_max: y_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: quadrants[0].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(QuadTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: self.y_min,
                y_max: y_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: quadrants[1].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(QuadTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: y_mid,
                y_max: self.y_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: quadrants[2].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(QuadTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: y_mid,
                y_max: self.y_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: quadrants[3].clone(),
                children: None,
                is_leaf: true,
            }),
        ];

        // Recursively build children
        for child in &mut children {
            child.build_tree(embedding)?;
        }

        self.children = Some(children);
        self.is_leaf = false;
        self.point_indices.clear(); // Clear points as they are now in children
        self.update_center_of_mass(embedding)?;

        Ok(())
    }

    /// Update center of mass for this node
    fn update_center_of_mass(&mut self, embedding: &Array2<f64>) -> Result<()> {
        if self.is_leaf {
            // Leaf node: compute center of mass from points
            if self.point_indices.is_empty() {
                self.total_mass = 0.0;
                self.center_of_mass = None;
                return Ok(());
            }

            let mut com = Array1::zeros(2);
            for &idx in &self.point_indices {
                com[0] += embedding[[idx, 0]];
                com[1] += embedding[[idx, 1]];
            }

            self.total_mass = self.point_indices.len() as f64;
            com.mapv_inplace(|x| x / self.total_mass);
            self.center_of_mass = Some(com);
        } else {
            // Internal node: compute center of mass from children
            if let Some(ref children) = self.children {
                let mut com = Array1::zeros(2);
                let mut total_mass = 0.0;

                for child in children.iter() {
                    if let Some(ref child_com) = child.center_of_mass {
                        total_mass += child.total_mass;
                        for i in 0..2 {
                            com[i] += child_com[i] * child.total_mass;
                        }
                    }
                }

                if total_mass > 0.0 {
                    com.mapv_inplace(|x| x / total_mass);
                    self.center_of_mass = Some(com);
                    self.total_mass = total_mass;
                } else {
                    self.center_of_mass = None;
                    self.total_mass = 0.0;
                }
            }
        }

        Ok(())
    }

    /// Compute forces using Barnes-Hut approximation for quadtree
    #[allow(clippy::too_many_arguments)]
    fn compute_forces_quad(
        &self,
        point: &Array1<f64>,
        point_idx: usize,
        angle: f64,
        degrees_of_freedom: f64,
    ) -> Result<(Array1<f64>, f64)> {
        let mut force = Array1::zeros(2);
        let mut sum_q = 0.0;

        self.compute_forces_recursive_quad(
            point,
            point_idx,
            angle,
            degrees_of_freedom,
            &mut force,
            &mut sum_q,
        )?;

        Ok((force, sum_q))
    }

    /// Recursive force computation for quadtree
    #[allow(clippy::too_many_arguments)]
    fn compute_forces_recursive_quad(
        &self,
        point: &Array1<f64>,
        _point_idx: usize,
        angle: f64,
        degrees_of_freedom: f64,
        force: &mut Array1<f64>,
        sum_q: &mut f64,
    ) -> Result<()> {
        if let Some(ref com) = self.center_of_mass {
            if self.total_mass == 0.0 {
                return Ok(());
            }

            // Compute distance to center of mass
            let dx = point[0] - com[0];
            let dy = point[1] - com[1];
            let dist_squared = dx * dx + dy * dy;

            if dist_squared < MACHINE_EPSILON {
                return Ok(());
            }

            // Check if we can use this node's center of mass (Barnes-Hut criterion)
            let node_size = (self.x_max - self.x_min).max(self.y_max - self.y_min);
            let distance = dist_squared.sqrt();

            if self.is_leaf || (node_size / distance) < angle {
                // Use center of mass approximation
                let q_factor = (1.0 + dist_squared / degrees_of_freedom)
                    .powf(-(degrees_of_freedom + 1.0) / 2.0);

                *sum_q += self.total_mass * q_factor;

                let force_factor =
                    (degrees_of_freedom + 1.0) * self.total_mass * q_factor / degrees_of_freedom;
                force[0] += force_factor * dx;
                force[1] += force_factor * dy;
            } else if let Some(ref children) = self.children {
                // Recursively compute forces from children
                for child in children.iter() {
                    child.compute_forces_recursive_quad(
                        point,
                        _point_idx,
                        angle,
                        degrees_of_freedom,
                        force,
                        sum_q,
                    )?;
                }
            }
        }
        // Nodes without a center of mass are empty and contribute no repulsive
        // force; attractive forces are handled in the main gradient computation.

        Ok(())
    }
}

impl OctTreeNode {
    /// Build the octree recursively
    fn build_tree(&mut self, embedding: &Array2<f64>) -> Result<()> {
        if self.point_indices.len() <= 1 {
            // Leaf node with 0 or 1 points
            self.update_center_of_mass(embedding)?;
            return Ok(());
        }

        // Split into 8 octants
        let x_mid = (self.x_min + self.x_max) / 2.0;
        let y_mid = (self.y_min + self.y_max) / 2.0;
        let z_mid = (self.z_min + self.z_max) / 2.0;

        let mut octants: [Vec<usize>; 8] = [
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
        ];

        // Distribute points to octants
        for &idx in &self.point_indices {
            let x = embedding[[idx, 0]];
            let y = embedding[[idx, 1]];
            let z = embedding[[idx, 2]];

            let octant = match (x >= x_mid, y >= y_mid, z >= z_mid) {
                (false, false, false) => 0,
                (true, false, false) => 1,
                (false, true, false) => 2,
                (true, true, false) => 3,
                (false, false, true) => 4,
                (true, false, true) => 5,
                (false, true, true) => 6,
                (true, true, true) => 7,
            };

            octants[octant].push(idx);
        }

        // Create child nodes
        let mut children = [
            Box::new(OctTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: self.y_min,
                y_max: y_mid,
                z_min: self.z_min,
                z_max: z_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[0].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: self.y_min,
                y_max: y_mid,
                z_min: self.z_min,
                z_max: z_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[1].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: y_mid,
                y_max: self.y_max,
                z_min: self.z_min,
                z_max: z_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[2].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: y_mid,
                y_max: self.y_max,
                z_min: self.z_min,
                z_max: z_mid,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[3].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: self.y_min,
                y_max: y_mid,
                z_min: z_mid,
                z_max: self.z_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[4].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: self.y_min,
                y_max: y_mid,
                z_min: z_mid,
                z_max: self.z_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[5].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: self.x_min,
                x_max: x_mid,
                y_min: y_mid,
                y_max: self.y_max,
                z_min: z_mid,
                z_max: self.z_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[6].clone(),
                children: None,
                is_leaf: true,
            }),
            Box::new(OctTreeNode {
                x_min: x_mid,
                x_max: self.x_max,
                y_min: y_mid,
                y_max: self.y_max,
                z_min: z_mid,
                z_max: self.z_max,
                center_of_mass: None,
                total_mass: 0.0,
                point_indices: octants[7].clone(),
                children: None,
                is_leaf: true,
            }),
        ];

        // Recursively build children
        for child in &mut children {
            child.build_tree(embedding)?;
        }

        self.children = Some(children);
        self.is_leaf = false;
        self.point_indices.clear();
        self.update_center_of_mass(embedding)?;

        Ok(())
    }

    /// Update center of mass for this octree node
    fn update_center_of_mass(&mut self, embedding: &Array2<f64>) -> Result<()> {
        if self.is_leaf {
            if self.point_indices.is_empty() {
                self.total_mass = 0.0;
                self.center_of_mass = None;
                return Ok(());
            }

            let mut com = Array1::zeros(3);
            for &idx in &self.point_indices {
                com[0] += embedding[[idx, 0]];
                com[1] += embedding[[idx, 1]];
                com[2] += embedding[[idx, 2]];
            }

            self.total_mass = self.point_indices.len() as f64;
            com.mapv_inplace(|x| x / self.total_mass);
            self.center_of_mass = Some(com);
        } else if let Some(ref children) = self.children {
            let mut com = Array1::zeros(3);
            let mut total_mass = 0.0;

            for child in children.iter() {
                if let Some(ref child_com) = child.center_of_mass {
                    total_mass += child.total_mass;
                    for i in 0..3 {
                        com[i] += child_com[i] * child.total_mass;
                    }
                }
            }

            if total_mass > 0.0 {
                com.mapv_inplace(|x| x / total_mass);
                self.center_of_mass = Some(com);
                self.total_mass = total_mass;
            } else {
                self.center_of_mass = None;
                self.total_mass = 0.0;
            }
        }

        Ok(())
    }

    /// Compute forces using Barnes-Hut approximation for octree
    #[allow(clippy::too_many_arguments)]
    fn compute_forces_oct(
        &self,
        point: &Array1<f64>,
        point_idx: usize,
        angle: f64,
        degrees_of_freedom: f64,
    ) -> Result<(Array1<f64>, f64)> {
        let mut force = Array1::zeros(3);
        let mut sum_q = 0.0;

        self.compute_forces_recursive_oct(
            point,
            point_idx,
            angle,
            degrees_of_freedom,
            &mut force,
            &mut sum_q,
        )?;

        Ok((force, sum_q))
    }

    /// Recursive force computation for octree
    #[allow(clippy::too_many_arguments)]
    fn compute_forces_recursive_oct(
        &self,
        point: &Array1<f64>,
        _point_idx: usize,
        angle: f64,
        degrees_of_freedom: f64,
        force: &mut Array1<f64>,
        sum_q: &mut f64,
    ) -> Result<()> {
        if let Some(ref com) = self.center_of_mass {
            if self.total_mass == 0.0 {
                return Ok(());
            }

            let dx = point[0] - com[0];
            let dy = point[1] - com[1];
            let dz = point[2] - com[2];
            let dist_squared = dx * dx + dy * dy + dz * dz;

            if dist_squared < MACHINE_EPSILON {
                return Ok(());
            }

            let node_size = (self.x_max - self.x_min)
                .max(self.y_max - self.y_min)
                .max(self.z_max - self.z_min);
            let distance = dist_squared.sqrt();

            if self.is_leaf || (node_size / distance) < angle {
                let q_factor = (1.0 + dist_squared / degrees_of_freedom)
                    .powf(-(degrees_of_freedom + 1.0) / 2.0);

                *sum_q += self.total_mass * q_factor;

                let force_factor =
                    (degrees_of_freedom + 1.0) * self.total_mass * q_factor / degrees_of_freedom;
                force[0] += force_factor * dx;
                force[1] += force_factor * dy;
                force[2] += force_factor * dz;
            } else if let Some(ref children) = self.children {
                for child in children.iter() {
                    child.compute_forces_recursive_oct(
                        point,
                        _point_idx,
                        angle,
                        degrees_of_freedom,
                        force,
                        sum_q,
                    )?;
                }
            }
        }

        Ok(())
    }
}

/// t-SNE (t-distributed Stochastic Neighbor Embedding) for dimensionality reduction
///
/// t-SNE is a nonlinear dimensionality reduction technique well-suited for
/// embedding high-dimensional data for visualization in a low-dimensional space
/// (typically 2D or 3D). It models each high-dimensional object by a two- or
/// three-dimensional point in such a way that similar objects are modeled by
/// nearby points and dissimilar objects are modeled by distant points with
/// high probability.
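///
/// # Examples
///
/// A builder-style configuration sketch (the values shown are the defaults and
/// are illustrative only; the import path is assumed from this crate's layout):
///
/// ```no_run
/// use scirs2_transform::reduction::TSNE; // path assumed
///
/// let tsne = TSNE::new()
///     .with_n_components(2)
///     .with_perplexity(30.0)
///     .with_learning_rate(200.0)
///     .with_method("barnes_hut")
///     .with_angle(0.5);
/// ```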
pub struct TSNE {
    /// Number of components in the embedded space
    n_components: usize,
    /// Perplexity parameter that balances attention between local and global structure
    perplexity: f64,
    /// Weight of early exaggeration phase
    early_exaggeration: f64,
    /// Learning rate for optimization
    learning_rate: f64,
    /// Maximum number of iterations
    max_iter: usize,
    /// Maximum iterations without progress before early stopping
    n_iter_without_progress: usize,
    /// Minimum gradient norm for convergence
    min_grad_norm: f64,
    /// Method to compute pairwise distances
    metric: String,
    /// Method to perform dimensionality reduction
    method: String,
    /// Initialization method
    init: String,
    /// Angle for Barnes-Hut approximation
    angle: f64,
    /// Whether to use multicore processing
    n_jobs: i32,
    /// Verbosity level
    verbose: bool,
    /// Random state for reproducibility
    random_state: Option<u64>,
    /// The embedding vectors
    embedding_: Option<Array2<f64>>,
    /// KL divergence after optimization
    kl_divergence_: Option<f64>,
    /// Total number of iterations run
    n_iter_: Option<usize>,
    /// Effective learning rate used
    learning_rate_: Option<f64>,
}

impl Default for TSNE {
    fn default() -> Self {
        Self::new()
    }
}

impl TSNE {
    /// Creates a new t-SNE instance with default parameters
    pub fn new() -> Self {
        TSNE {
            n_components: 2,
            perplexity: 30.0,
            early_exaggeration: 12.0,
            learning_rate: 200.0,
            max_iter: 1000,
            n_iter_without_progress: 300,
            min_grad_norm: 1e-7,
            metric: "euclidean".to_string(),
            method: "barnes_hut".to_string(),
            init: "pca".to_string(),
            angle: 0.5,
            n_jobs: -1, // Use all available cores by default
            verbose: false,
            random_state: None,
            embedding_: None,
            kl_divergence_: None,
            n_iter_: None,
            learning_rate_: None,
        }
    }

    /// Sets the number of components in the embedded space
    pub fn with_n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Sets the perplexity parameter
    pub fn with_perplexity(mut self, perplexity: f64) -> Self {
        self.perplexity = perplexity;
        self
    }

    /// Sets the early exaggeration factor
    pub fn with_early_exaggeration(mut self, early_exaggeration: f64) -> Self {
        self.early_exaggeration = early_exaggeration;
        self
    }

    /// Sets the learning rate for gradient descent
    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Sets the maximum number of iterations
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the number of iterations without progress before early stopping
    pub fn with_n_iter_without_progress(mut self, n_iter_without_progress: usize) -> Self {
        self.n_iter_without_progress = n_iter_without_progress;
        self
    }

    /// Sets the minimum gradient norm for convergence
    pub fn with_min_grad_norm(mut self, min_grad_norm: f64) -> Self {
        self.min_grad_norm = min_grad_norm;
        self
    }

    /// Sets the metric for pairwise distance computation
    ///
    /// Supported metrics:
    /// - "euclidean": Euclidean distance (L2 norm) - default
    /// - "manhattan": Manhattan distance (L1 norm)
    /// - "cosine": Cosine distance (1 - cosine similarity)
    /// - "chebyshev": Chebyshev distance (maximum coordinate difference)
    pub fn with_metric(mut self, metric: &str) -> Self {
        self.metric = metric.to_string();
        self
    }

    /// Sets the method for dimensionality reduction
    pub fn with_method(mut self, method: &str) -> Self {
        self.method = method.to_string();
        self
    }

    /// Sets the initialization method
    pub fn with_init(mut self, init: &str) -> Self {
        self.init = init.to_string();
        self
    }

    /// Sets the angle for Barnes-Hut approximation
    pub fn with_angle(mut self, angle: f64) -> Self {
        self.angle = angle;
        self
    }

    /// Sets the number of parallel jobs to run
    /// * n_jobs = -1: Use all available cores
    /// * n_jobs = 1: Use single-core (disable multicore)
    /// * n_jobs > 1: Use specific number of cores
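    ///
    /// For example (a sketch; import path assumed from this crate's layout):
    ///
    /// ```no_run
    /// use scirs2_transform::reduction::TSNE; // path assumed
    ///
    /// // Disable multithreading, e.g. for deterministic profiling
    /// let tsne = TSNE::new().with_n_jobs(1);
    /// ```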
    pub fn with_n_jobs(mut self, n_jobs: i32) -> Self {
        self.n_jobs = n_jobs;
        self
    }

    /// Sets the verbosity level
    pub fn with_verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Sets the random state for reproducibility
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Fit t-SNE to input data and transform it to the embedded space
    ///
    /// # Arguments
    /// * `x` - Input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - Embedding of the training data, shape (n_samples, n_components)
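    ///
    /// # Examples
    ///
    /// A minimal sketch (import path assumed from this crate's layout):
    ///
    /// ```no_run
    /// use scirs2_core::ndarray::Array2;
    /// use scirs2_transform::reduction::TSNE; // path assumed
    ///
    /// // Note: perplexity must be smaller than the number of samples
    /// let x = Array2::<f64>::zeros((100, 10));
    /// let mut tsne = TSNE::new().with_perplexity(15.0);
    /// let embedding = tsne.fit_transform(&x).expect("t-SNE failed");
    /// assert_eq!(embedding.shape(), &[100, 2]);
    /// ```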
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        // Input validation
        if n_samples == 0 || n_features == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }

        if self.perplexity >= n_samples as f64 {
            return Err(TransformError::InvalidInput(format!(
                "perplexity ({}) must be less than n_samples ({})",
                self.perplexity, n_samples
            )));
        }

        if self.method == "barnes_hut" && self.n_components > 3 {
            return Err(TransformError::InvalidInput(
                "'n_components' should be less than or equal to 3 for the barnes_hut algorithm"
                    .to_string(),
            ));
        }

        // Record the effective learning rate (no "auto" mode: the configured value is used as-is)
        self.learning_rate_ = Some(self.learning_rate);

        // Initialize embedding
        let x_embedded = self.initialize_embedding(&x_f64)?;

        // Compute pairwise affinities (P)
        let p = self.compute_pairwise_affinities(&x_f64)?;

        // Run t-SNE optimization
        let (embedding, kl_divergence, n_iter) =
            self.tsne_optimization(p, x_embedded, n_samples)?;

        self.embedding_ = Some(embedding.clone());
        self.kl_divergence_ = Some(kl_divergence);
        self.n_iter_ = Some(n_iter);

        Ok(embedding)
    }

    /// Initialize embedding either with PCA or random
    fn initialize_embedding(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let n_samples = x.shape()[0];

        if self.init == "pca" {
            let n_components = self.n_components.min(x.shape()[1]);
            let mut pca = PCA::new(n_components, true, false);
            let mut x_embedded = pca.fit_transform(x)?;

            // Scale PCA initialization
            let std_dev = (x_embedded.column(0).map(|&x| x * x).sum() / (n_samples as f64)).sqrt();
            if std_dev > 0.0 {
                x_embedded.mapv_inplace(|x| x / std_dev * 1e-4);
            }

            Ok(x_embedded)
        } else if self.init == "random" {
            // Random initialization from a normal distribution with small variance.
            // NOTE: `random_state` is currently not honored here; a thread-local RNG is used.
            use scirs2_core::random::{thread_rng, Distribution};
            let normal = Normal::new(0.0, 1e-4).unwrap();
            let mut rng = thread_rng();

            // Use simple random initialization
            let data: Vec<f64> = (0..(n_samples * self.n_components))
                .map(|_| normal.sample(&mut rng))
                .collect();
            Ok(Array2::from_shape_vec((n_samples, self.n_components), data).unwrap())
        } else {
            Err(TransformError::InvalidInput(format!(
                "Initialization method '{}' not recognized",
                self.init
            )))
        }
    }

    /// Compute pairwise affinities with perplexity-based normalization
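    ///
    /// Conditional affinities are symmetrized as P + Pᵀ and renormalized to sum
    /// to 1; since each conditional row sums to 1, this matches the standard
    /// p_ij = (p_{j|i} + p_{i|j}) / (2n), up to the `MACHINE_EPSILON` clamping.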
    fn compute_pairwise_affinities(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let _n_samples = x.shape()[0];

        // Compute pairwise distances
        let distances = self.compute_pairwise_distances(x)?;

        // Convert distances to affinities using binary search for sigma
        let p = self.distances_to_affinities(&distances)?;

        // Symmetrize and normalize the affinity matrix
        let mut p_symmetric = &p + &p.t();

        // Normalize
        let p_sum = p_symmetric.sum();
        if p_sum > 0.0 {
            p_symmetric.mapv_inplace(|x| x.max(MACHINE_EPSILON) / p_sum);
        }

        Ok(p_symmetric)
    }

    /// Compute pairwise distances with optional multicore support
    fn compute_pairwise_distances(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let n_samples = x.shape()[0];
        let mut distances = Array2::zeros((n_samples, n_samples));

        match self.metric.as_str() {
            "euclidean" => {
                if self.n_jobs == 1 {
                    // Single-core computation
                    for i in 0..n_samples {
                        for j in i + 1..n_samples {
                            let mut dist_squared = 0.0;
                            for k in 0..x.shape()[1] {
                                let diff = x[[i, k]] - x[[j, k]];
                                dist_squared += diff * diff;
                            }
                            distances[[i, j]] = dist_squared;
                            distances[[j, i]] = dist_squared;
                        }
                    }
                } else {
                    // Multi-core computation
                    let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                        .collect();

                    let n_features = x.shape()[1];
                    let squared_distances: Vec<f64> = upper_triangle_indices
                        .par_iter()
                        .map(|&(i, j)| {
                            let mut dist_squared = 0.0;
                            for k in 0..n_features {
                                let diff = x[[i, k]] - x[[j, k]];
                                dist_squared += diff * diff;
                            }
                            dist_squared
                        })
                        .collect();

                    // Fill the distance matrix
                    for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                        distances[[i, j]] = squared_distances[idx];
                        distances[[j, i]] = squared_distances[idx];
                    }
                }
            }
            "manhattan" => {
                if self.n_jobs == 1 {
                    // Single-core Manhattan distance computation
                    for i in 0..n_samples {
                        for j in i + 1..n_samples {
                            let mut dist = 0.0;
                            for k in 0..x.shape()[1] {
                                dist += (x[[i, k]] - x[[j, k]]).abs();
                            }
                            distances[[i, j]] = dist;
                            distances[[j, i]] = dist;
                        }
                    }
                } else {
                    // Multi-core Manhattan distance computation
                    let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                        .collect();

                    let n_features = x.shape()[1];
                    let manhattan_distances: Vec<f64> = upper_triangle_indices
                        .par_iter()
                        .map(|&(i, j)| {
                            let mut dist = 0.0;
                            for k in 0..n_features {
                                dist += (x[[i, k]] - x[[j, k]]).abs();
                            }
                            dist
                        })
                        .collect();

                    // Fill the distance matrix
                    for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                        distances[[i, j]] = manhattan_distances[idx];
                        distances[[j, i]] = manhattan_distances[idx];
                    }
                }
            }
            "cosine" => {
                // First normalize all vectors for cosine distance computation
                let mut normalized_x = Array2::zeros((n_samples, x.shape()[1]));
                for i in 0..n_samples {
                    let row = x.row(i);
                    let norm = row.iter().map(|v| v * v).sum::<f64>().sqrt();
                    if norm > EPSILON {
                        for j in 0..x.shape()[1] {
                            normalized_x[[i, j]] = x[[i, j]] / norm;
                        }
                    } else {
                        // Handle zero vectors
                        for j in 0..x.shape()[1] {
                            normalized_x[[i, j]] = 0.0;
                        }
                    }
                }

                if self.n_jobs == 1 {
                    // Single-core cosine distance computation
                    for i in 0..n_samples {
                        for j in i + 1..n_samples {
                            let mut dot_product = 0.0;
                            for k in 0..x.shape()[1] {
                                dot_product += normalized_x[[i, k]] * normalized_x[[j, k]];
                            }
                            // Cosine distance = 1 - cosine similarity
                            let cosine_dist = 1.0 - dot_product.clamp(-1.0, 1.0);
                            distances[[i, j]] = cosine_dist;
                            distances[[j, i]] = cosine_dist;
                        }
                    }
                } else {
                    // Multi-core cosine distance computation
                    let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                        .collect();

                    let n_features = x.shape()[1];
                    let cosine_distances: Vec<f64> = upper_triangle_indices
                        .par_iter()
                        .map(|&(i, j)| {
                            let mut dot_product = 0.0;
                            for k in 0..n_features {
                                dot_product += normalized_x[[i, k]] * normalized_x[[j, k]];
                            }
                            // Cosine distance = 1 - cosine similarity
                            1.0 - dot_product.clamp(-1.0, 1.0)
                        })
                        .collect();

                    // Fill the distance matrix
                    for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                        distances[[i, j]] = cosine_distances[idx];
                        distances[[j, i]] = cosine_distances[idx];
                    }
                }
            }
            "chebyshev" => {
                if self.n_jobs == 1 {
                    // Single-core Chebyshev distance computation
                    for i in 0..n_samples {
                        for j in i + 1..n_samples {
                            let mut max_dist = 0.0;
                            for k in 0..x.shape()[1] {
                                let diff = (x[[i, k]] - x[[j, k]]).abs();
                                max_dist = max_dist.max(diff);
                            }
                            distances[[i, j]] = max_dist;
                            distances[[j, i]] = max_dist;
                        }
                    }
                } else {
                    // Multi-core Chebyshev distance computation
                    let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                        .collect();

                    let n_features = x.shape()[1];
                    let chebyshev_distances: Vec<f64> = upper_triangle_indices
                        .par_iter()
                        .map(|&(i, j)| {
                            let mut max_dist = 0.0;
                            for k in 0..n_features {
                                let diff = (x[[i, k]] - x[[j, k]]).abs();
                                max_dist = max_dist.max(diff);
                            }
                            max_dist
                        })
                        .collect();

                    // Fill the distance matrix
                    for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                        distances[[i, j]] = chebyshev_distances[idx];
                        distances[[j, i]] = chebyshev_distances[idx];
                    }
                }
            }
            _ => {
                return Err(TransformError::InvalidInput(format!(
                    "Metric '{}' not implemented. Supported metrics are: 'euclidean', 'manhattan', 'cosine', 'chebyshev'",
                    self.metric
                )));
            }
        }

        Ok(distances)
    }

    /// Convert distances to affinities using perplexity-based normalization with optional multicore support
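    ///
    /// For each point i, a binary search tunes the precision beta_i = 1 / (2 * sigma_i^2)
    /// until the Shannon entropy H(P_i) = -sum_j p_{j|i} ln p_{j|i} of the conditional
    /// distribution equals ln(perplexity), i.e. until the effective number of
    /// neighbors exp(H(P_i)) matches the configured perplexity.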
    fn distances_to_affinities(&self, distances: &Array2<f64>) -> Result<Array2<f64>> {
        let n_samples = distances.shape()[0];
        let mut p = Array2::zeros((n_samples, n_samples));
        // Target entropy in nats: perplexity = exp(H(P_i)), hence H = ln(perplexity)
        let target = self.perplexity.ln();

        if self.n_jobs == 1 {
            // Single-core computation (original implementation)
            for i in 0..n_samples {
                let mut beta_min = -f64::INFINITY;
                let mut beta_max = f64::INFINITY;
                let mut beta = 1.0;

                // Get all distances from point i except self-distance (which is 0)
                let distances_i = distances.row(i).to_owned();

                // Binary search for beta
                for _ in 0..50 {
                    // Usually converges within 50 iterations
                    // Compute conditional probabilities with current beta
                    let mut sum_pi = 0.0;
                    let mut h = 0.0;

                    for j in 0..n_samples {
                        if i == j {
                            p[[i, j]] = 0.0;
                            continue;
                        }

                        let p_ij = (-beta * distances_i[j]).exp();
                        p[[i, j]] = p_ij;
                        sum_pi += p_ij;
                    }

                    // Normalize probabilities and compute entropy
                    if sum_pi > 0.0 {
                        for j in 0..n_samples {
                            if i == j {
                                continue;
                            }

                            p[[i, j]] /= sum_pi;

                            // Compute entropy
                            if p[[i, j]] > MACHINE_EPSILON {
                                h -= p[[i, j]] * p[[i, j]].ln();
                            }
                        }
                    }

                    // Adjust beta based on entropy difference from target
                    let h_diff = h - target;

                    if h_diff.abs() < EPSILON {
                        break; // Converged
                    }

                    // Update beta using binary search
                    if h_diff > 0.0 {
                        beta_min = beta;
                        if beta_max == f64::INFINITY {
                            beta *= 2.0;
                        } else {
                            beta = (beta + beta_max) / 2.0;
                        }
                    } else {
                        beta_max = beta;
                        if beta_min == -f64::INFINITY {
                            beta /= 2.0;
                        } else {
                            beta = (beta + beta_min) / 2.0;
                        }
                    }
                }
            }
        } else {
            // Multi-core computation of conditional probabilities for each point
            let prob_rows: Vec<Vec<f64>> = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut beta_min = -f64::INFINITY;
                    let mut beta_max = f64::INFINITY;
                    let mut beta = 1.0;

                    // Get all distances from point i except self-distance (which is 0)
                    let distances_i: Vec<f64> = (0..n_samples).map(|j| distances[[i, j]]).collect();
                    let mut p_row = vec![0.0; n_samples];

                    // Binary search for beta
                    for _ in 0..50 {
                        // Usually converges within 50 iterations
                        // Compute conditional probabilities with current beta
                        let mut sum_pi = 0.0;
                        let mut h = 0.0;

                        for j in 0..n_samples {
                            if i == j {
                                p_row[j] = 0.0;
                                continue;
                            }

                            let p_ij = (-beta * distances_i[j]).exp();
                            p_row[j] = p_ij;
                            sum_pi += p_ij;
                        }

                        // Normalize probabilities and compute entropy
                        if sum_pi > 0.0 {
                            for (j, prob) in p_row.iter_mut().enumerate().take(n_samples) {
                                if i == j {
                                    continue;
                                }

                                *prob /= sum_pi;

                                // Compute entropy
                                if *prob > MACHINE_EPSILON {
                                    h -= *prob * prob.ln();
                                }
                            }
                        }

                        // Adjust beta based on entropy difference from target
                        let h_diff = h - target;

                        if h_diff.abs() < EPSILON {
                            break; // Converged
                        }

                        // Update beta using binary search
                        if h_diff > 0.0 {
                            beta_min = beta;
                            if beta_max == f64::INFINITY {
                                beta *= 2.0;
                            } else {
                                beta = (beta + beta_max) / 2.0;
                            }
                        } else {
                            beta_max = beta;
                            if beta_min == -f64::INFINITY {
                                beta /= 2.0;
                            } else {
                                beta = (beta + beta_min) / 2.0;
                            }
                        }
                    }

                    p_row
                })
                .collect();

            // Copy results back to the main matrix
            for (i, row) in prob_rows.iter().enumerate() {
                for (j, &val) in row.iter().enumerate() {
                    p[[i, j]] = val;
                }
            }
        }

        Ok(p)
    }

    /// Main t-SNE optimization loop using gradient descent
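    ///
    /// Runs two phases: an early-exaggeration phase (250 iterations, momentum 0.5,
    /// with P scaled by `early_exaggeration`) followed by the final phase
    /// (momentum 0.8, unscaled P). Convergence is checked every 50 iterations
    /// against both `n_iter_without_progress` and `min_grad_norm`.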
    #[allow(clippy::too_many_arguments)]
    fn tsne_optimization(
        &self,
        p: Array2<f64>,
        initial_embedding: Array2<f64>,
        n_samples: usize,
    ) -> Result<(Array2<f64>, f64, usize)> {
        let n_components = self.n_components;
        let degrees_of_freedom = (n_components - 1).max(1) as f64;

        // Initialize variables for optimization
        let mut embedding = initial_embedding;
        let mut update = Array2::zeros((n_samples, n_components));
        let mut gains = Array2::ones((n_samples, n_components));
        let mut error = f64::INFINITY;
        let mut best_error = f64::INFINITY;
        let mut best_iter = 0;
        let mut iter = 0;

        // Exploration phase with early exaggeration
        let exploration_n_iter = 250;
        let n_iter_check = 50;

        // Apply early exaggeration
        let p_early = &p * self.early_exaggeration;

        if self.verbose {
            println!("[t-SNE] Starting optimization with early exaggeration phase...");
        }

        // Early exaggeration phase
        for i in 0..exploration_n_iter {
            // Compute gradient and error for early exaggeration phase
            let (curr_error, grad) = if self.method == "barnes_hut" {
                self.compute_gradient_barnes_hut(&embedding, &p_early, degrees_of_freedom)?
            } else {
                self.compute_gradient_exact(&embedding, &p_early, degrees_of_freedom)?
            };

            // Perform gradient update with momentum and gains
            self.gradient_update(
                &mut embedding,
                &mut update,
                &mut gains,
                &grad,
                0.5,
                self.learning_rate_,
            )?;

            // Check for convergence
            if (i + 1) % n_iter_check == 0 {
                if self.verbose {
                    println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
                }

                if curr_error < best_error {
                    best_error = curr_error;
                    best_iter = i;
                } else if i - best_iter > self.n_iter_without_progress {
                    if self.verbose {
                        println!("[t-SNE] Early convergence at iteration {}", i + 1);
                    }
                    break;
                }

                // Check gradient norm
                let grad_norm = grad.mapv(|x| x * x).sum().sqrt();
                if grad_norm < self.min_grad_norm {
                    if self.verbose {
                        println!("[t-SNE] Gradient norm {} below threshold, stopping optimization at iteration {}",
                                grad_norm, i + 1);
                    }
                    break;
                }
            }

            iter = i;
        }

        if self.verbose {
            println!("[t-SNE] Completed early exaggeration phase, starting final optimization...");
        }

        // Final optimization phase without early exaggeration
        for i in iter + 1..self.max_iter {
            // Compute gradient and error for normal phase
            let (curr_error, grad) = if self.method == "barnes_hut" {
                self.compute_gradient_barnes_hut(&embedding, &p, degrees_of_freedom)?
            } else {
                self.compute_gradient_exact(&embedding, &p, degrees_of_freedom)?
            };
            error = curr_error;

            // Perform gradient update with momentum and gains
            self.gradient_update(
                &mut embedding,
                &mut update,
                &mut gains,
                &grad,
                0.8,
                self.learning_rate_,
            )?;

            // Check for convergence
            if (i + 1) % n_iter_check == 0 {
                if self.verbose {
                    println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
                }

                if curr_error < best_error {
                    best_error = curr_error;
                    best_iter = i;
                } else if i - best_iter > self.n_iter_without_progress {
                    if self.verbose {
                        println!("[t-SNE] Stopping optimization at iteration {}", i + 1);
                    }
                    break;
                }

                // Check gradient norm
                let grad_norm = grad.mapv(|x| x * x).sum().sqrt();
                if grad_norm < self.min_grad_norm {
                    if self.verbose {
                        println!("[t-SNE] Gradient norm {} below threshold, stopping optimization at iteration {}",
                                grad_norm, i + 1);
                    }
                    break;
                }
            }

            iter = i;
        }

        if self.verbose {
            println!(
                "[t-SNE] Optimization finished after {} iterations with error {:.7}",
                iter + 1,
                error
            );
        }

        Ok((embedding, error, iter + 1))
    }

    /// Compute gradient and error for exact t-SNE with optional multicore support
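    ///
    /// Low-dimensional affinities use a Student's t-kernel with
    /// `degrees_of_freedom` (written `a` below):
    ///
    ///   w_ij = (1 + ||y_i - y_j||^2 / a)^(-(a + 1) / 2),   q_ij = w_ij / Z,
    ///
    /// where Z sums w_ij over all pairs. The returned error is the KL
    /// divergence KL(P || Q) = sum_ij p_ij * ln(p_ij / q_ij), and the gradient
    /// accumulates, for every pair, a term proportional to
    /// (p_ij - q_ij) * w_ij * (y_i - y_j).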
    fn compute_gradient_exact(
        &self,
        embedding: &Array2<f64>,
        p: &Array2<f64>,
        degrees_of_freedom: f64,
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];

        if self.n_jobs == 1 {
            // Single-core computation
            let mut dist = Array2::zeros((n_samples, n_samples));
            for i in 0..n_samples {
                for j in i + 1..n_samples {
                    let mut d_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        d_squared += diff * diff;
                    }

                    // Convert squared distance to the t-distribution kernel value
                    let q_ij = (1.0 + d_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    dist[[i, j]] = q_ij;
                    dist[[j, i]] = q_ij;
                }
            }

            // Set diagonal to zero (no self-affinity)
            for i in 0..n_samples {
                dist[[i, i]] = 0.0;
            }

            // Normalize Q matrix
            let sum_q = dist.sum().max(MACHINE_EPSILON);
            let q = &dist / sum_q;

            // Compute KL divergence
            let mut kl_divergence = 0.0;
            for i in 0..n_samples {
                for j in 0..n_samples {
                    if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                        kl_divergence += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                    }
                }
            }

            // Compute gradient
            let mut grad = Array2::zeros((n_samples, n_components));
            let factor =
                4.0 * (degrees_of_freedom + 1.0) / (degrees_of_freedom * (sum_q.powf(2.0)));

            for i in 0..n_samples {
                for j in 0..n_samples {
                    if i != j {
                        let p_q_diff = p[[i, j]] - q[[i, j]];
                        for k in 0..n_components {
                            grad[[i, k]] += factor
                                * p_q_diff
                                * dist[[i, j]]
                                * (embedding[[i, k]] - embedding[[j, k]]);
                        }
                    }
                }
            }

            Ok((kl_divergence, grad))
        } else {
            // Multi-core computation over the upper triangle of the pairwise matrix
            let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
                .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
                .collect();

            let q_values: Vec<f64> = upper_triangle_indices
                .par_iter()
                .map(|&(i, j)| {
                    let mut d_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        d_squared += diff * diff;
                    }

                    // Convert squared distance to the t-distribution kernel value
                    (1.0 + d_squared / degrees_of_freedom).powf(-(degrees_of_freedom + 1.0) / 2.0)
                })
                .collect();

            // Fill the symmetric affinity matrix
            let mut dist = Array2::zeros((n_samples, n_samples));
            for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
                let q_val = q_values[idx];
                dist[[i, j]] = q_val;
                dist[[j, i]] = q_val;
            }

            // Set diagonal to zero (no self-affinity)
            for i in 0..n_samples {
                dist[[i, i]] = 0.0;
            }

            // Normalize Q matrix
            let sum_q = dist.sum().max(MACHINE_EPSILON);
            let q = &dist / sum_q;

            // Parallel computation of KL divergence
            let kl_divergence: f64 = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut local_kl = 0.0;
                    for j in 0..n_samples {
                        if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                            local_kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                        }
                    }
                    local_kl
                })
                .sum();

            // Parallel computation of the gradient, one row per task
            let factor =
                4.0 * (degrees_of_freedom + 1.0) / (degrees_of_freedom * (sum_q.powf(2.0)));

            let grad_rows: Vec<Vec<f64>> = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut grad_row = vec![0.0; n_components];
                    for j in 0..n_samples {
                        if i != j {
                            let p_q_diff = p[[i, j]] - q[[i, j]];
                            for k in 0..n_components {
                                grad_row[k] += factor
                                    * p_q_diff
                                    * dist[[i, j]]
                                    * (embedding[[i, k]] - embedding[[j, k]]);
                            }
                        }
                    }
                    grad_row
                })
                .collect();

            // Convert gradient rows back to an array
            let mut grad = Array2::zeros((n_samples, n_components));
            for (i, row) in grad_rows.iter().enumerate() {
                for (k, &val) in row.iter().enumerate() {
                    grad[[i, k]] = val;
                }
            }

            Ok((kl_divergence, grad))
        }
    }

    /// Compute gradient and error using Barnes-Hut approximation
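    ///
    /// The gradient is split into attractive forces driven by the input
    /// affinities P and repulsive forces driven by the embedding's own
    /// affinities Q. Only the repulsive part is approximated by the spatial
    /// tree, controlled by `self.angle` (the Barnes-Hut theta parameter).
    /// Note that this implementation still materializes the dense Q matrix to
    /// report the KL divergence, so each call remains O(n^2) in time and memory.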
    fn compute_gradient_barnes_hut(
        &self,
        embedding: &Array2<f64>,
        p: &Array2<f64>,
        degrees_of_freedom: f64,
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];

        // Build spatial tree for Barnes-Hut approximation
        let tree = if n_components == 2 {
            SpatialTree::new_quadtree(embedding)?
        } else if n_components == 3 {
            SpatialTree::new_octree(embedding)?
        } else {
            return Err(TransformError::InvalidInput(
                "Barnes-Hut approximation only supports 2D and 3D embeddings".to_string(),
            ));
        };

        // Compute Q matrix and gradient using Barnes-Hut
        let mut q = Array2::zeros((n_samples, n_samples));
        let mut grad = Array2::zeros((n_samples, n_components));
        let mut sum_q = 0.0;

        // For each point, approximate the repulsive forces with the tree
        for i in 0..n_samples {
            let point = embedding.row(i).to_owned();
            let (repulsive_force, q_sum) =
                tree.compute_forces(&point, i, self.angle, degrees_of_freedom)?;

            sum_q += q_sum;

            // Add repulsive forces to gradient
            for j in 0..n_components {
                grad[[i, j]] += repulsive_force[j];
            }

            // Fill the (unnormalized) Q matrix for the KL divergence report
            for j in 0..n_samples {
                if i != j {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }
                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    q[[i, j]] = q_ij;
                }
            }
        }

        // Normalize Q matrix
        sum_q = sum_q.max(MACHINE_EPSILON);
        q.mapv_inplace(|x| x / sum_q);

        // Add attractive forces to gradient
        for i in 0..n_samples {
            for j in 0..n_samples {
                if i != j && p[[i, j]] > MACHINE_EPSILON {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }

                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    let factor = 4.0 * p[[i, j]] * q_ij;

                    for k in 0..n_components {
                        grad[[i, k]] -= factor * (embedding[[i, k]] - embedding[[j, k]]);
                    }
                }
            }
        }

        // Compute KL divergence
        let mut kl_divergence = 0.0;
        for i in 0..n_samples {
            for j in 0..n_samples {
                if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                    kl_divergence += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                }
            }
        }

        Ok((kl_divergence, grad))
    }

    /// Update embedding using gradient descent with momentum and adaptive gains
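    ///
    /// Per-coordinate gains follow the classic t-SNE scheme: a gain shrinks
    /// (x 0.8) when the gradient has the same sign as the current update and
    /// grows (+ 0.2) when the signs disagree, floored at 0.01. The velocity is
    /// then `update = momentum * update - eta * gain * grad` and is added to
    /// the embedding.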
    #[allow(clippy::too_many_arguments)]
    fn gradient_update(
        &self,
        embedding: &mut Array2<f64>,
        update: &mut Array2<f64>,
        gains: &mut Array2<f64>,
        grad: &Array2<f64>,
        momentum: f64,
        learning_rate: Option<f64>,
    ) -> Result<()> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];
        let eta = learning_rate.unwrap_or(self.learning_rate);

        // Update gains, then apply the momentum step
        for i in 0..n_samples {
            for j in 0..n_components {
                let same_sign = update[[i, j]] * grad[[i, j]] > 0.0;

                if same_sign {
                    gains[[i, j]] *= 0.8;
                } else {
                    gains[[i, j]] += 0.2;
                }

                // Ensure minimum gain
                gains[[i, j]] = gains[[i, j]].max(0.01);

                // Update with momentum and adaptive learning rate
                update[[i, j]] = momentum * update[[i, j]] - eta * gains[[i, j]] * grad[[i, j]];
                embedding[[i, j]] += update[[i, j]];
            }
        }

        Ok(())
    }

    /// Returns the embedding after fitting
    pub fn embedding(&self) -> Option<&Array2<f64>> {
        self.embedding_.as_ref()
    }

    /// Returns the KL divergence after optimization
    pub fn kl_divergence(&self) -> Option<f64> {
        self.kl_divergence_
    }

    /// Returns the number of iterations run
    pub fn n_iter(&self) -> Option<usize> {
        self.n_iter_
    }
}

/// Calculate trustworthiness score for a dimensionality reduction
///
/// Trustworthiness measures to what extent the local structure is retained when
/// projecting data from the original space to the embedding space.
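///
/// For each point `i`, let `U_i` be the set of points that are among its
/// `n_neighbors` nearest neighbors in the embedding but not in the original
/// space, and let `r(i, j)` be the rank of `j` among `i`'s neighbors in the
/// original space. With `n` samples and `k = n_neighbors`, the score is
///
///   T(k) = 1 - 2 / (n * k * (2n - 3k - 1)) * sum_i sum_{j in U_i} (r(i, j) - k)
///
/// so a score of 1.0 means every embedded neighborhood is also an original
/// neighborhood.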
///
/// # Arguments
/// * `x` - Original data, shape (n_samples, n_features)
/// * `x_embedded` - Embedded data, shape (n_samples, n_components)
/// * `n_neighbors` - Number of neighbors to consider
/// * `metric` - Metric to use (currently only 'euclidean' is implemented)
///
/// # Returns
/// * `Result<f64>` - Trustworthiness score between 0.0 and 1.0
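///
/// # Examples
///
/// A minimal sketch (marked `ignore` because the import path depends on how
/// this crate re-exports the function): an embedding identical to the input
/// preserves every neighborhood, so the score is exactly 1.0.
///
/// ```ignore
/// use scirs2_core::ndarray::arr2;
///
/// let x = arr2(&[[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]]);
/// let score = trustworthiness(&x, &x, 2, "euclidean").unwrap();
/// assert!((score - 1.0).abs() < 1e-10);
/// ```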
#[allow(dead_code)]
pub fn trustworthiness<S1, S2>(
    x: &ArrayBase<S1, Ix2>,
    x_embedded: &ArrayBase<S2, Ix2>,
    n_neighbors: usize,
    metric: &str,
) -> Result<f64>
where
    S1: Data,
    S2: Data,
    S1::Elem: Float + NumCast,
    S2::Elem: Float + NumCast,
{
    let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
    let x_embedded_f64 = x_embedded.mapv(|x| NumCast::from(x).unwrap_or(0.0));

    let n_samples = x_f64.shape()[0];

    if n_neighbors >= n_samples / 2 {
        return Err(TransformError::InvalidInput(format!(
            "n_neighbors ({}) should be less than n_samples / 2 ({})",
            n_neighbors,
            n_samples / 2
        )));
    }

    if metric != "euclidean" {
        return Err(TransformError::InvalidInput(format!(
            "Metric '{metric}' not implemented. Currently only 'euclidean' is supported"
        )));
    }

    // Compute pairwise distances in the original space
    let mut dist_x = Array2::zeros((n_samples, n_samples));
    for i in 0..n_samples {
        for j in 0..n_samples {
            if i == j {
                dist_x[[i, j]] = f64::INFINITY; // Self-distance sorts last
                continue;
            }

            let mut d_squared = 0.0;
            for k in 0..x_f64.shape()[1] {
                let diff = x_f64[[i, k]] - x_f64[[j, k]];
                d_squared += diff * diff;
            }
            dist_x[[i, j]] = d_squared.sqrt();
        }
    }

    // Compute pairwise distances in the embedded space
    let mut dist_embedded = Array2::zeros((n_samples, n_samples));
    for i in 0..n_samples {
        for j in 0..n_samples {
            if i == j {
                dist_embedded[[i, j]] = f64::INFINITY; // Self-distance sorts last
                continue;
            }

            let mut d_squared = 0.0;
            for k in 0..x_embedded_f64.shape()[1] {
                let diff = x_embedded_f64[[i, k]] - x_embedded_f64[[j, k]];
                d_squared += diff * diff;
            }
            dist_embedded[[i, j]] = d_squared.sqrt();
        }
    }

    // For each point, find the n_neighbors nearest neighbors in the original space.
    // Self-distances were set to infinity above, so a point sorts last in its own
    // row and never appears among its own nearest neighbors.
    let mut nn_orig = Array2::<usize>::zeros((n_samples, n_neighbors));
    for i in 0..n_samples {
        // Get the indices of the sorted distances
        let row = dist_x.row(i).to_owned();
        let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        for (j, &(idx, _)) in pairs.iter().enumerate().take(n_neighbors) {
            nn_orig[[i, j]] = idx;
        }
    }

    // For each point, find the n_neighbors nearest neighbors in the embedded space
    let mut nn_embedded = Array2::<usize>::zeros((n_samples, n_neighbors));
    for i in 0..n_samples {
        // Get the indices of the sorted distances
        let row = dist_embedded.row(i).to_owned();
        let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        for (j, &(idx, _)) in pairs.iter().enumerate().take(n_neighbors) {
            nn_embedded[[i, j]] = idx;
        }
    }

    // Accumulate rank penalties for points that are neighbors in the embedding
    // but not in the original space
    let mut t = 0.0;
    for i in 0..n_samples {
        for &j in nn_embedded.row(i).iter() {
            // Check whether j is among the n_neighbors nearest neighbors in the original space
            let is_not_neighbor = !nn_orig.row(i).iter().any(|&nn| nn == j);

            if is_not_neighbor {
                // Find the rank of j in the original space
                let row = dist_x.row(i).to_owned();
                let mut pairs: Vec<(usize, f64)> =
                    row.iter().enumerate().map(|(idx, &d)| (idx, d)).collect();
                pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                // `position` is the 0-based rank of j, so its 1-based rank is
                // position + 1. Since j is not among the first n_neighbors
                // entries, the subtraction below cannot underflow.
                if let Some(position) = pairs.iter().position(|&(idx, _)| idx == j) {
                    t += (position + 1 - n_neighbors) as f64;
                }
            }
        }
    }

    // Normalize the trustworthiness score
    let n = n_samples as f64;
    let k = n_neighbors as f64;
    let normalizer = 2.0 / (n * k * (2.0 * n - 3.0 * k - 1.0));
    let trustworthiness = 1.0 - normalizer * t;

    Ok(trustworthiness)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::arr2;

    #[test]
    fn test_tsne_simple() {
        // Create a simple dataset with two well-separated groups
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        // Initialize and fit t-SNE with the exact method
        let mut tsne_exact = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_exact = tsne_exact.fit_transform(&x).unwrap();

        // Check that the shape is correct
        assert_eq!(embedding_exact.shape(), &[8, 2]);

        // Check that groups are separated in the embedding space:
        // compute the average distance within each group
        let dist_group1 =
            average_pairwise_distance(&embedding_exact.slice(scirs2_core::ndarray::s![0..4, ..]));
        let dist_group2 =
            average_pairwise_distance(&embedding_exact.slice(scirs2_core::ndarray::s![4..8, ..]));

        // Compute the average distance between groups
        let dist_between = average_intergroup_distance(
            &embedding_exact.slice(scirs2_core::ndarray::s![0..4, ..]),
            &embedding_exact.slice(scirs2_core::ndarray::s![4..8, ..]),
        );

        // The between-group distance should be larger than the within-group distances
        assert!(dist_between > dist_group1);
        assert!(dist_between > dist_group2);
    }

    #[test]
    fn test_tsne_barnes_hut() {
        // Create a simple dataset
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        // Initialize and fit t-SNE with the Barnes-Hut method
        let mut tsne_bh = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("barnes_hut")
            .with_angle(0.5)
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_bh = tsne_bh.fit_transform(&x).unwrap();

        // Check that the shape is correct
        assert_eq!(embedding_bh.shape(), &[8, 2]);

        // Barnes-Hut is approximate, so only check basic properties
        assert!(embedding_bh.iter().all(|&x| x.is_finite()));

        // Check that the embedding has some spread (not all points collapsed to the same location)
        let min_val = embedding_bh.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_val = embedding_bh
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_val - min_val > 1e-6,
            "Embedding should have some spread"
        );

        // Check that a KL divergence was computed
        assert!(tsne_bh.kl_divergence().is_some());

        // Due to the nature of the approximation, the Barnes-Hut KL divergence
        // might not always be finite, so we only report which case occurred
        let kl_div = tsne_bh.kl_divergence().unwrap();
        if !kl_div.is_finite() {
            println!(
                "Barnes-Hut KL divergence: {} (non-finite, which is acceptable for the approximation)",
                kl_div
            );
        } else {
            println!("Barnes-Hut KL divergence: {} (finite)", kl_div);
        }
    }

    #[test]
    fn test_tsne_multicore() {
        // Create a simple dataset
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [5.0, 6.0],
            [6.0, 6.0],
        ]);

        // Initialize and fit t-SNE with multicore enabled
        let mut tsne_multicore = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_n_jobs(-1) // Use all cores
            .with_random_state(42)
            .with_max_iter(100) // Shorter for testing
            .with_verbose(false);

        let embedding_multicore = tsne_multicore.fit_transform(&x).unwrap();

        // Check that the shape is correct
        assert_eq!(embedding_multicore.shape(), &[8, 2]);

        // Multicore should produce valid (finite) results
        assert!(embedding_multicore.iter().all(|&x| x.is_finite()));

        // Check that the embedding has some spread (more lenient for short runs)
        let min_val = embedding_multicore
            .iter()
            .cloned()
            .fold(f64::INFINITY, f64::min);
        let max_val = embedding_multicore
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_val - min_val > 1e-12,
            "Embedding should have some spread, got range: {}",
            max_val - min_val
        );

        // Test single-core vs multicore consistency
        let mut tsne_singlecore = TSNE::new()
            .with_n_components(2)
            .with_perplexity(2.0)
            .with_method("exact")
            .with_n_jobs(1) // Single core
            .with_random_state(42)
            .with_max_iter(100)
            .with_verbose(false);

        let embedding_singlecore = tsne_singlecore.fit_transform(&x).unwrap();

        // Both should produce finite results (an exact numerical match is not
        // expected because of randomized initialization)
        assert!(embedding_multicore.iter().all(|&x| x.is_finite()));
        assert!(embedding_singlecore.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_tsne_3d_barnes_hut() {
        // Create a simple 3D dataset
        let x = arr2(&[
            [0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0],
            [5.0, 5.0, 5.0],
            [6.0, 5.0, 5.0],
            [5.0, 6.0, 5.0],
            [6.0, 6.0, 5.0],
        ]);

        // Initialize and fit t-SNE with the Barnes-Hut method for 3D output
        let mut tsne_3d = TSNE::new()
            .with_n_components(3)
            .with_perplexity(2.0)
            .with_method("barnes_hut")
            .with_angle(0.5)
            .with_random_state(42)
            .with_max_iter(250)
            .with_verbose(false);

        let embedding_3d = tsne_3d.fit_transform(&x).unwrap();

        // Check that the shape is correct
        assert_eq!(embedding_3d.shape(), &[8, 3]);

        // Basic sanity: the run should finish with finite coordinates
        assert!(embedding_3d.iter().all(|&x| x.is_finite()));
    }

    // Helper function to compute the average pairwise distance within a group
    fn average_pairwise_distance(
        points: &ArrayBase<scirs2_core::ndarray::ViewRepr<&f64>, Ix2>,
    ) -> f64 {
        let n = points.shape()[0];
        let mut total_dist = 0.0;
        let mut count = 0;

        for i in 0..n {
            for j in i + 1..n {
                let mut dist_squared = 0.0;
                for k in 0..points.shape()[1] {
                    let diff = points[[i, k]] - points[[j, k]];
                    dist_squared += diff * diff;
                }
                total_dist += dist_squared.sqrt();
                count += 1;
            }
        }

        if count > 0 {
            total_dist / count as f64
        } else {
            0.0
        }
    }

    // Helper function to compute the average distance between two groups
    fn average_intergroup_distance(
        group1: &ArrayBase<scirs2_core::ndarray::ViewRepr<&f64>, Ix2>,
        group2: &ArrayBase<scirs2_core::ndarray::ViewRepr<&f64>, Ix2>,
    ) -> f64 {
        let n1 = group1.shape()[0];
        let n2 = group2.shape()[0];
        let mut total_dist = 0.0;
        let mut count = 0;

        for i in 0..n1 {
            for j in 0..n2 {
                let mut dist_squared = 0.0;
                for k in 0..group1.shape()[1] {
                    let diff = group1[[i, k]] - group2[[j, k]];
                    dist_squared += diff * diff;
                }
                total_dist += dist_squared.sqrt();
                count += 1;
            }
        }

        if count > 0 {
            total_dist / count as f64
        } else {
            0.0
        }
    }

    #[test]
    fn test_trustworthiness() {
        // Create a simple dataset where we know the structure
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [5.0, 5.0],
            [5.0, 6.0],
            [6.0, 5.0],
            [6.0, 6.0],
        ]);

        // A perfect embedding preserves all neighborhoods
        let perfect_embedding = x.clone();
        let t_perfect = trustworthiness(&x, &perfect_embedding, 3, "euclidean").unwrap();
        assert_abs_diff_eq!(t_perfect, 1.0, epsilon = 1e-10);

        // An embedding that scrambles the neighborhoods should score lower
        let random_embedding = arr2(&[
            [0.9, 0.1],
            [0.8, 0.2],
            [0.7, 0.3],
            [0.6, 0.4],
            [0.5, 0.5],
            [0.4, 0.6],
            [0.3, 0.7],
            [0.2, 0.8],
        ]);

        let t_random = trustworthiness(&x, &random_embedding, 3, "euclidean").unwrap();
        assert!(t_random < 1.0);
    }
}