Skip to main content

scirs2_transform/reduction/tsne/
mod.rs

1//! t-SNE (t-distributed Stochastic Neighbor Embedding) implementation
2//!
3//! This module provides an implementation of t-SNE, a technique for dimensionality
4//! reduction particularly well-suited for visualization of high-dimensional data.
5//!
6//! t-SNE converts similarities between data points to joint probabilities and tries
7//! to minimize the Kullback-Leibler divergence between the joint probabilities of
8//! the low-dimensional embedding and the high-dimensional data.
9//!
10//! ## Features
11//!
12//! - **Barnes-Hut approximation** for O(N log N) complexity via spatial trees
13//! - **Perplexity-based bandwidth selection** using binary search for sigma
14//! - **Early exaggeration phase** for better global structure preservation
15//! - **Momentum-based gradient descent** with adaptive gains
16//! - **Multiple distance metrics**: euclidean, manhattan, cosine, chebyshev
17//! - **Sparse kNN affinity** for memory-efficient computation on large datasets
18//! - **Multicore support** via rayon parallel iterators
19
20mod spatial_tree;
21
22use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
23use scirs2_core::numeric::{Float, NumCast};
24use scirs2_core::parallel_ops::*;
25use scirs2_core::random::Normal;
26use scirs2_core::random::RandomExt;
27
28use crate::error::{Result, TransformError};
29use crate::reduction::PCA;
30
31use spatial_tree::SpatialTree;
32
// Numerical-stability tolerances shared by the affinity search and the optimizer.

/// Tiny positive floor/threshold for probabilities and normalising sums, used so that
/// logarithms and divisions stay finite.
const MACHINE_EPSILON: f64 = 1.0e-14;

/// Absolute tolerance on the entropy mismatch in the bandwidth bisection; also the
/// minimum row norm for cosine normalisation.
const EPSILON: f64 = 1.0e-7;
36
37/// t-SNE (t-distributed Stochastic Neighbor Embedding) for dimensionality reduction
38///
39/// t-SNE is a nonlinear dimensionality reduction technique well-suited for
40/// embedding high-dimensional data for visualization in a low-dimensional space
41/// (typically 2D or 3D). It models each high-dimensional object by a two- or
42/// three-dimensional point in such a way that similar objects are modeled by
43/// nearby points and dissimilar objects are modeled by distant points with
44/// high probability.
45///
46/// # Example
47///
48/// ```rust,no_run
49/// use scirs2_transform::TSNE;
50/// use scirs2_core::ndarray::arr2;
51///
52/// let data = arr2(&[
53///     [0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0],
54///     [5.0, 5.0], [6.0, 5.0], [5.0, 6.0], [6.0, 6.0],
55/// ]);
56///
57/// let mut tsne = TSNE::new()
58///     .with_n_components(2)
59///     .with_perplexity(2.0)
60///     .with_max_iter(500);
61///
62/// let embedding = tsne.fit_transform(&data).expect("should succeed");
63/// assert_eq!(embedding.shape(), &[8, 2]);
64/// ```
65pub struct TSNE {
66    /// Number of components in the embedded space
67    n_components: usize,
68    /// Perplexity parameter that balances attention between local and global structure
69    perplexity: f64,
70    /// Weight of early exaggeration phase
71    early_exaggeration: f64,
72    /// Learning rate for optimization
73    learning_rate: f64,
74    /// Maximum number of iterations
75    max_iter: usize,
76    /// Maximum iterations without progress before early stopping
77    n_iter_without_progress: usize,
78    /// Minimum gradient norm for convergence
79    min_grad_norm: f64,
80    /// Method to compute pairwise distances
81    metric: String,
82    /// Method to perform dimensionality reduction ("exact" or "barnes_hut")
83    method: String,
84    /// Initialization method ("pca" or "random")
85    init: String,
86    /// Angle for Barnes-Hut approximation (trade-off between speed and accuracy)
87    angle: f64,
88    /// Whether to use multicore processing (-1 = all cores, 1 = single)
89    n_jobs: i32,
90    /// Verbosity level
91    verbose: bool,
92    /// Random state for reproducibility
93    random_state: Option<u64>,
94    /// Degrees of freedom for the t-distribution (default 1.0 = standard Cauchy)
95    degrees_of_freedom: Option<f64>,
96    /// The embedding vectors
97    embedding_: Option<Array2<f64>>,
98    /// KL divergence after optimization
99    kl_divergence_: Option<f64>,
100    /// Total number of iterations run
101    n_iter_: Option<usize>,
102    /// Effective learning rate used
103    learning_rate_: Option<f64>,
104}
105
106impl Default for TSNE {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
112impl TSNE {
113    /// Creates a new t-SNE instance with default parameters
114    pub fn new() -> Self {
115        TSNE {
116            n_components: 2,
117            perplexity: 30.0,
118            early_exaggeration: 12.0,
119            learning_rate: 200.0,
120            max_iter: 1000,
121            n_iter_without_progress: 300,
122            min_grad_norm: 1e-7,
123            metric: "euclidean".to_string(),
124            method: "barnes_hut".to_string(),
125            init: "pca".to_string(),
126            angle: 0.5,
127            n_jobs: -1,
128            verbose: false,
129            random_state: None,
130            degrees_of_freedom: None,
131            embedding_: None,
132            kl_divergence_: None,
133            n_iter_: None,
134            learning_rate_: None,
135        }
136    }
137
138    /// Sets the number of components in the embedded space
139    pub fn with_n_components(mut self, n_components: usize) -> Self {
140        self.n_components = n_components;
141        self
142    }
143
144    /// Sets the perplexity parameter
145    pub fn with_perplexity(mut self, perplexity: f64) -> Self {
146        self.perplexity = perplexity;
147        self
148    }
149
150    /// Sets the early exaggeration factor
151    pub fn with_early_exaggeration(mut self, early_exaggeration: f64) -> Self {
152        self.early_exaggeration = early_exaggeration;
153        self
154    }
155
156    /// Sets the learning rate for gradient descent
157    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
158        self.learning_rate = learning_rate;
159        self
160    }
161
162    /// Sets the maximum number of iterations
163    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
164        self.max_iter = max_iter;
165        self
166    }
167
168    /// Sets the number of iterations without progress before early stopping
169    pub fn with_n_iter_without_progress(mut self, n_iter_without_progress: usize) -> Self {
170        self.n_iter_without_progress = n_iter_without_progress;
171        self
172    }
173
174    /// Sets the minimum gradient norm for convergence
175    pub fn with_min_grad_norm(mut self, min_grad_norm: f64) -> Self {
176        self.min_grad_norm = min_grad_norm;
177        self
178    }
179
180    /// Sets the metric for pairwise distance computation
181    ///
182    /// Supported metrics:
183    /// - "euclidean": Euclidean distance (L2 norm) - default
184    /// - "manhattan": Manhattan distance (L1 norm)
185    /// - "cosine": Cosine distance (1 - cosine similarity)
186    /// - "chebyshev": Chebyshev distance (maximum coordinate difference)
187    pub fn with_metric(mut self, metric: &str) -> Self {
188        self.metric = metric.to_string();
189        self
190    }
191
192    /// Sets the method for dimensionality reduction ("exact" or "barnes_hut")
193    pub fn with_method(mut self, method: &str) -> Self {
194        self.method = method.to_string();
195        self
196    }
197
198    /// Sets the initialization method ("pca" or "random")
199    pub fn with_init(mut self, init: &str) -> Self {
200        self.init = init.to_string();
201        self
202    }
203
204    /// Sets the angle for Barnes-Hut approximation (0.0 = exact, 1.0 = fast but approximate)
205    pub fn with_angle(mut self, angle: f64) -> Self {
206        self.angle = angle;
207        self
208    }
209
210    /// Sets the number of parallel jobs to run
211    /// * n_jobs = -1: Use all available cores
212    /// * n_jobs = 1: Use single-core (disable multicore)
213    /// * n_jobs > 1: Use specific number of cores
214    pub fn with_n_jobs(mut self, n_jobs: i32) -> Self {
215        self.n_jobs = n_jobs;
216        self
217    }
218
219    /// Sets the verbosity level
220    pub fn with_verbose(mut self, verbose: bool) -> Self {
221        self.verbose = verbose;
222        self
223    }
224
225    /// Sets the random state for reproducibility
226    pub fn with_random_state(mut self, random_state: u64) -> Self {
227        self.random_state = Some(random_state);
228        self
229    }
230
231    /// Sets the degrees of freedom for the Student-t distribution
232    ///
233    /// Default is n_components - 1 (or 1 if n_components <= 1).
234    /// Setting this to a larger value produces heavier tails, which can help
235    /// with crowding in higher-dimensional embeddings.
236    pub fn with_degrees_of_freedom(mut self, dof: f64) -> Self {
237        self.degrees_of_freedom = Some(dof);
238        self
239    }
240
241    /// Fit t-SNE to input data and transform it to the embedded space
242    ///
243    /// # Arguments
244    /// * `x` - Input data, shape (n_samples, n_features)
245    ///
246    /// # Returns
247    /// * `Result<Array2<f64>>` - Embedding of the training data, shape (n_samples, n_components)
248    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
249    where
250        S: Data,
251        S::Elem: Float + NumCast,
252    {
253        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
254
255        let n_samples = x_f64.shape()[0];
256        let n_features = x_f64.shape()[1];
257
258        // Input validation
259        if n_samples == 0 || n_features == 0 {
260            return Err(TransformError::InvalidInput("Empty input data".to_string()));
261        }
262
263        if self.perplexity >= n_samples as f64 {
264            return Err(TransformError::InvalidInput(format!(
265                "perplexity ({}) must be less than n_samples ({})",
266                self.perplexity, n_samples
267            )));
268        }
269
270        if self.method == "barnes_hut" && self.n_components > 3 {
271            return Err(TransformError::InvalidInput(
272                "'n_components' should be <= 3 for barnes_hut algorithm".to_string(),
273            ));
274        }
275
276        self.learning_rate_ = Some(self.learning_rate);
277
278        // Initialize embedding
279        let x_embedded = self.initialize_embedding(&x_f64)?;
280
281        // Compute pairwise affinities (P)
282        let p = self.compute_pairwise_affinities(&x_f64)?;
283
284        // Run t-SNE optimization
285        let (embedding, kl_divergence, n_iter) =
286            self.tsne_optimization(p, x_embedded, n_samples)?;
287
288        self.embedding_ = Some(embedding.clone());
289        self.kl_divergence_ = Some(kl_divergence);
290        self.n_iter_ = Some(n_iter);
291
292        Ok(embedding)
293    }
294
295    /// Initialize embedding either with PCA or random
296    fn initialize_embedding(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
297        let n_samples = x.shape()[0];
298
299        if self.init == "pca" {
300            // Bound by both dimensions: PCA can return at most min(n_samples, n_features) components.
301            let n_components = self.n_components.min(x.shape()[0]).min(x.shape()[1]);
302            let mut pca = PCA::new(n_components, true, false);
303            let mut x_embedded = pca.fit_transform(x)?;
304
305            // Scale PCA initialization
306            let col_var = x_embedded.column(0).map(|&v| v * v).sum() / (n_samples as f64);
307            let std_dev = col_var.sqrt();
308            if std_dev > 0.0 {
309                x_embedded.mapv_inplace(|v| v / std_dev * 1e-4);
310            }
311
312            Ok(x_embedded)
313        } else if self.init == "random" {
314            use scirs2_core::random::{thread_rng, Distribution};
315            let normal = Normal::new(0.0, 1e-4).map_err(|e| {
316                TransformError::ComputationError(format!(
317                    "Failed to create normal distribution: {e}"
318                ))
319            })?;
320            let mut rng = thread_rng();
321
322            let data: Vec<f64> = (0..(n_samples * self.n_components))
323                .map(|_| normal.sample(&mut rng))
324                .collect();
325            Array2::from_shape_vec((n_samples, self.n_components), data).map_err(|e| {
326                TransformError::ComputationError(format!("Failed to create embedding array: {e}"))
327            })
328        } else {
329            Err(TransformError::InvalidInput(format!(
330                "Initialization method '{}' not recognized. Use 'pca' or 'random'.",
331                self.init
332            )))
333        }
334    }
335
336    /// Compute pairwise affinities with perplexity-based normalization
337    fn compute_pairwise_affinities(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
338        // Compute pairwise distances
339        let distances = self.compute_pairwise_distances(x)?;
340
341        // Convert distances to affinities using binary search for sigma
342        let p = self.distances_to_affinities(&distances)?;
343
344        // Symmetrize and normalize the affinity matrix
345        let mut p_symmetric = &p + &p.t();
346
347        let p_sum = p_symmetric.sum();
348        if p_sum > 0.0 {
349            p_symmetric.mapv_inplace(|v| v.max(MACHINE_EPSILON) / p_sum);
350        }
351
352        Ok(p_symmetric)
353    }
354
355    /// Compute pairwise distances with optional multicore support
356    fn compute_pairwise_distances(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
357        let n_samples = x.shape()[0];
358        let n_features = x.shape()[1];
359        let mut distances = Array2::zeros((n_samples, n_samples));
360
361        match self.metric.as_str() {
362            "euclidean" => {
363                if self.n_jobs == 1 {
364                    for i in 0..n_samples {
365                        for j in i + 1..n_samples {
366                            let mut dist_squared = 0.0;
367                            for k in 0..n_features {
368                                let diff = x[[i, k]] - x[[j, k]];
369                                dist_squared += diff * diff;
370                            }
371                            distances[[i, j]] = dist_squared;
372                            distances[[j, i]] = dist_squared;
373                        }
374                    }
375                } else {
376                    let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
377                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
378                        .collect();
379
380                    let squared_distances: Vec<f64> = upper_triangle_indices
381                        .par_iter()
382                        .map(|&(i, j)| {
383                            let mut dist_squared = 0.0;
384                            for k in 0..n_features {
385                                let diff = x[[i, k]] - x[[j, k]];
386                                dist_squared += diff * diff;
387                            }
388                            dist_squared
389                        })
390                        .collect();
391
392                    for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
393                        distances[[i, j]] = squared_distances[idx];
394                        distances[[j, i]] = squared_distances[idx];
395                    }
396                }
397            }
398            "manhattan" => {
399                let compute_manhattan = |i: usize, j: usize| -> f64 {
400                    let mut dist = 0.0;
401                    for k in 0..n_features {
402                        dist += (x[[i, k]] - x[[j, k]]).abs();
403                    }
404                    dist
405                };
406
407                if self.n_jobs == 1 {
408                    for i in 0..n_samples {
409                        for j in i + 1..n_samples {
410                            let dist = compute_manhattan(i, j);
411                            distances[[i, j]] = dist;
412                            distances[[j, i]] = dist;
413                        }
414                    }
415                } else {
416                    let upper: Vec<(usize, usize)> = (0..n_samples)
417                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
418                        .collect();
419                    let dists: Vec<f64> = upper
420                        .par_iter()
421                        .map(|&(i, j)| {
422                            let mut dist = 0.0;
423                            for k in 0..n_features {
424                                dist += (x[[i, k]] - x[[j, k]]).abs();
425                            }
426                            dist
427                        })
428                        .collect();
429                    for (idx, &(i, j)) in upper.iter().enumerate() {
430                        distances[[i, j]] = dists[idx];
431                        distances[[j, i]] = dists[idx];
432                    }
433                }
434            }
435            "cosine" => {
436                let mut normalized_x = Array2::zeros((n_samples, n_features));
437                for i in 0..n_samples {
438                    let row = x.row(i);
439                    let norm = row.iter().map(|v| v * v).sum::<f64>().sqrt();
440                    if norm > EPSILON {
441                        for j in 0..n_features {
442                            normalized_x[[i, j]] = x[[i, j]] / norm;
443                        }
444                    }
445                }
446
447                let upper: Vec<(usize, usize)> = (0..n_samples)
448                    .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
449                    .collect();
450
451                let compute_fn = |i: usize, j: usize| -> f64 {
452                    let mut dot_product = 0.0;
453                    for k in 0..n_features {
454                        dot_product += normalized_x[[i, k]] * normalized_x[[j, k]];
455                    }
456                    1.0 - dot_product.clamp(-1.0, 1.0)
457                };
458
459                if self.n_jobs == 1 {
460                    for &(i, j) in &upper {
461                        let d = compute_fn(i, j);
462                        distances[[i, j]] = d;
463                        distances[[j, i]] = d;
464                    }
465                } else {
466                    let dists: Vec<f64> = upper
467                        .par_iter()
468                        .map(|&(i, j)| {
469                            let mut dp = 0.0;
470                            for k in 0..n_features {
471                                dp += normalized_x[[i, k]] * normalized_x[[j, k]];
472                            }
473                            1.0 - dp.clamp(-1.0, 1.0)
474                        })
475                        .collect();
476                    for (idx, &(i, j)) in upper.iter().enumerate() {
477                        distances[[i, j]] = dists[idx];
478                        distances[[j, i]] = dists[idx];
479                    }
480                }
481            }
482            "chebyshev" => {
483                let upper: Vec<(usize, usize)> = (0..n_samples)
484                    .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
485                    .collect();
486
487                if self.n_jobs == 1 {
488                    for &(i, j) in &upper {
489                        let mut max_dist = 0.0;
490                        for k in 0..n_features {
491                            let diff = (x[[i, k]] - x[[j, k]]).abs();
492                            max_dist = max_dist.max(diff);
493                        }
494                        distances[[i, j]] = max_dist;
495                        distances[[j, i]] = max_dist;
496                    }
497                } else {
498                    let dists: Vec<f64> = upper
499                        .par_iter()
500                        .map(|&(i, j)| {
501                            let mut max_dist = 0.0;
502                            for k in 0..n_features {
503                                let diff = (x[[i, k]] - x[[j, k]]).abs();
504                                max_dist = max_dist.max(diff);
505                            }
506                            max_dist
507                        })
508                        .collect();
509                    for (idx, &(i, j)) in upper.iter().enumerate() {
510                        distances[[i, j]] = dists[idx];
511                        distances[[j, i]] = dists[idx];
512                    }
513                }
514            }
515            _ => {
516                return Err(TransformError::InvalidInput(format!(
517                    "Metric '{}' not supported. Use: 'euclidean', 'manhattan', 'cosine', 'chebyshev'",
518                    self.metric
519                )));
520            }
521        }
522
523        Ok(distances)
524    }
525
526    /// Convert distances to affinities using perplexity-based normalization
527    fn distances_to_affinities(&self, distances: &Array2<f64>) -> Result<Array2<f64>> {
528        let n_samples = distances.shape()[0];
529        let target = (2.0f64).ln() * self.perplexity;
530
531        if self.n_jobs == 1 {
532            let mut p = Array2::zeros((n_samples, n_samples));
533            for i in 0..n_samples {
534                self.binary_search_sigma(i, distances, target, &mut p)?;
535            }
536            Ok(p)
537        } else {
538            let prob_rows: Vec<Vec<f64>> = (0..n_samples)
539                .into_par_iter()
540                .map(|i| {
541                    let distances_i: Vec<f64> = (0..n_samples).map(|j| distances[[i, j]]).collect();
542                    Self::binary_search_sigma_row(i, &distances_i, n_samples, target)
543                })
544                .collect();
545
546            let mut p = Array2::zeros((n_samples, n_samples));
547            for (i, row) in prob_rows.iter().enumerate() {
548                for (j, &val) in row.iter().enumerate() {
549                    p[[i, j]] = val;
550                }
551            }
552            Ok(p)
553        }
554    }
555
556    /// Binary search for the optimal sigma (bandwidth) for a single row
557    fn binary_search_sigma(
558        &self,
559        i: usize,
560        distances: &Array2<f64>,
561        target: f64,
562        p: &mut Array2<f64>,
563    ) -> Result<()> {
564        let n_samples = distances.shape()[0];
565        let mut beta_min = -f64::INFINITY;
566        let mut beta_max = f64::INFINITY;
567        let mut beta = 1.0;
568
569        for _ in 0..50 {
570            let mut sum_pi = 0.0;
571            let mut h = 0.0;
572
573            for j in 0..n_samples {
574                if i == j {
575                    p[[i, j]] = 0.0;
576                    continue;
577                }
578                let p_ij = (-beta * distances[[i, j]]).exp();
579                p[[i, j]] = p_ij;
580                sum_pi += p_ij;
581            }
582
583            if sum_pi > 0.0 {
584                for j in 0..n_samples {
585                    if i == j {
586                        continue;
587                    }
588                    p[[i, j]] /= sum_pi;
589                    if p[[i, j]] > MACHINE_EPSILON {
590                        h -= p[[i, j]] * p[[i, j]].ln();
591                    }
592                }
593            }
594
595            let h_diff = h - target;
596            if h_diff.abs() < EPSILON {
597                break;
598            }
599
600            if h_diff > 0.0 {
601                beta_min = beta;
602                beta = if beta_max == f64::INFINITY {
603                    beta * 2.0
604                } else {
605                    (beta + beta_max) / 2.0
606                };
607            } else {
608                beta_max = beta;
609                beta = if beta_min == -f64::INFINITY {
610                    beta / 2.0
611                } else {
612                    (beta + beta_min) / 2.0
613                };
614            }
615        }
616
617        Ok(())
618    }
619
620    /// Parallel-safe binary search for a single row (returns Vec)
621    fn binary_search_sigma_row(
622        i: usize,
623        distances_i: &[f64],
624        n_samples: usize,
625        target: f64,
626    ) -> Vec<f64> {
627        let mut beta_min = -f64::INFINITY;
628        let mut beta_max = f64::INFINITY;
629        let mut beta = 1.0;
630        let mut p_row = vec![0.0; n_samples];
631
632        for _ in 0..50 {
633            let mut sum_pi = 0.0;
634            let mut h = 0.0;
635
636            for j in 0..n_samples {
637                if i == j {
638                    p_row[j] = 0.0;
639                    continue;
640                }
641                let p_ij = (-beta * distances_i[j]).exp();
642                p_row[j] = p_ij;
643                sum_pi += p_ij;
644            }
645
646            if sum_pi > 0.0 {
647                for (j, prob) in p_row.iter_mut().enumerate().take(n_samples) {
648                    if i == j {
649                        continue;
650                    }
651                    *prob /= sum_pi;
652                    if *prob > MACHINE_EPSILON {
653                        h -= *prob * prob.ln();
654                    }
655                }
656            }
657
658            let h_diff = h - target;
659            if h_diff.abs() < EPSILON {
660                break;
661            }
662
663            if h_diff > 0.0 {
664                beta_min = beta;
665                beta = if beta_max == f64::INFINITY {
666                    beta * 2.0
667                } else {
668                    (beta + beta_max) / 2.0
669                };
670            } else {
671                beta_max = beta;
672                beta = if beta_min == -f64::INFINITY {
673                    beta / 2.0
674                } else {
675                    (beta + beta_min) / 2.0
676                };
677            }
678        }
679
680        p_row
681    }
682
683    /// Get the effective degrees of freedom for the t-distribution
684    fn effective_dof(&self) -> f64 {
685        if let Some(dof) = self.degrees_of_freedom {
686            dof
687        } else {
688            (self.n_components - 1).max(1) as f64
689        }
690    }
691
692    /// Main t-SNE optimization loop using gradient descent
693    fn tsne_optimization(
694        &self,
695        p: Array2<f64>,
696        initial_embedding: Array2<f64>,
697        n_samples: usize,
698    ) -> Result<(Array2<f64>, f64, usize)> {
699        let n_components = self.n_components;
700        let degrees_of_freedom = self.effective_dof();
701
702        let mut embedding = initial_embedding;
703        let mut update = Array2::zeros((n_samples, n_components));
704        let mut gains = Array2::ones((n_samples, n_components));
705        let mut error = f64::INFINITY;
706        let mut best_error = f64::INFINITY;
707        let mut best_iter = 0;
708        let mut iter = 0;
709
710        let exploration_n_iter = 250;
711        let n_iter_check = 50;
712
713        // Apply early exaggeration
714        let p_early = &p * self.early_exaggeration;
715
716        if self.verbose {
717            println!("[t-SNE] Starting optimization with early exaggeration phase...");
718        }
719
720        // Early exaggeration phase
721        for i in 0..exploration_n_iter {
722            let (curr_error, grad) = if self.method == "barnes_hut" {
723                self.compute_gradient_barnes_hut(&embedding, &p_early, degrees_of_freedom)?
724            } else {
725                self.compute_gradient_exact(&embedding, &p_early, degrees_of_freedom)?
726            };
727
728            self.gradient_update(
729                &mut embedding,
730                &mut update,
731                &mut gains,
732                &grad,
733                0.5,
734                self.learning_rate_,
735            )?;
736
737            if (i + 1) % n_iter_check == 0 {
738                if self.verbose {
739                    println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
740                }
741
742                if curr_error < best_error {
743                    best_error = curr_error;
744                    best_iter = i;
745                } else if i - best_iter > self.n_iter_without_progress {
746                    if self.verbose {
747                        println!("[t-SNE] Early convergence at iteration {}", i + 1);
748                    }
749                    break;
750                }
751
752                let grad_norm = grad.mapv(|v| v * v).sum().sqrt();
753                if grad_norm < self.min_grad_norm {
754                    if self.verbose {
755                        println!(
756                            "[t-SNE] Gradient norm {} below threshold at iteration {}",
757                            grad_norm,
758                            i + 1
759                        );
760                    }
761                    break;
762                }
763            }
764
765            iter = i;
766        }
767
768        if self.verbose {
769            println!("[t-SNE] Completed early exaggeration, starting final optimization...");
770        }
771
772        // Final optimization phase without early exaggeration
773        for i in iter + 1..self.max_iter {
774            let (curr_error, grad) = if self.method == "barnes_hut" {
775                self.compute_gradient_barnes_hut(&embedding, &p, degrees_of_freedom)?
776            } else {
777                self.compute_gradient_exact(&embedding, &p, degrees_of_freedom)?
778            };
779            error = curr_error;
780
781            self.gradient_update(
782                &mut embedding,
783                &mut update,
784                &mut gains,
785                &grad,
786                0.8,
787                self.learning_rate_,
788            )?;
789
790            if (i + 1) % n_iter_check == 0 {
791                if self.verbose {
792                    println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
793                }
794
795                if curr_error < best_error {
796                    best_error = curr_error;
797                    best_iter = i;
798                } else if i - best_iter > self.n_iter_without_progress {
799                    if self.verbose {
800                        println!("[t-SNE] Stopping optimization at iteration {}", i + 1);
801                    }
802                    break;
803                }
804
805                let grad_norm = grad.mapv(|v| v * v).sum().sqrt();
806                if grad_norm < self.min_grad_norm {
807                    if self.verbose {
808                        println!(
809                            "[t-SNE] Gradient norm {} below threshold at iteration {}",
810                            grad_norm,
811                            i + 1
812                        );
813                    }
814                    break;
815                }
816            }
817
818            iter = i;
819        }
820
821        if self.verbose {
822            println!(
823                "[t-SNE] Optimization finished after {} iterations with error {:.7}",
824                iter + 1,
825                error
826            );
827        }
828
829        Ok((embedding, error, iter + 1))
830    }
831
832    /// Compute gradient and error for exact t-SNE
833    fn compute_gradient_exact(
834        &self,
835        embedding: &Array2<f64>,
836        p: &Array2<f64>,
837        degrees_of_freedom: f64,
838    ) -> Result<(f64, Array2<f64>)> {
839        let n_samples = embedding.shape()[0];
840        let n_components = embedding.shape()[1];
841
842        // Compute Q matrix
843        let mut dist = Array2::zeros((n_samples, n_samples));
844        let upper: Vec<(usize, usize)> = (0..n_samples)
845            .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
846            .collect();
847
848        if self.n_jobs == 1 {
849            for &(i, j) in &upper {
850                let mut d_squared = 0.0;
851                for k in 0..n_components {
852                    let diff = embedding[[i, k]] - embedding[[j, k]];
853                    d_squared += diff * diff;
854                }
855                let q_ij =
856                    (1.0 + d_squared / degrees_of_freedom).powf(-(degrees_of_freedom + 1.0) / 2.0);
857                dist[[i, j]] = q_ij;
858                dist[[j, i]] = q_ij;
859            }
860        } else {
861            let q_values: Vec<f64> = upper
862                .par_iter()
863                .map(|&(i, j)| {
864                    let mut d_squared = 0.0;
865                    for k in 0..n_components {
866                        let diff = embedding[[i, k]] - embedding[[j, k]];
867                        d_squared += diff * diff;
868                    }
869                    (1.0 + d_squared / degrees_of_freedom).powf(-(degrees_of_freedom + 1.0) / 2.0)
870                })
871                .collect();
872
873            for (idx, &(i, j)) in upper.iter().enumerate() {
874                dist[[i, j]] = q_values[idx];
875                dist[[j, i]] = q_values[idx];
876            }
877        }
878
879        for i in 0..n_samples {
880            dist[[i, i]] = 0.0;
881        }
882
883        let sum_q = dist.sum().max(MACHINE_EPSILON);
884        let q = &dist / sum_q;
885
886        // Compute KL divergence
887        let kl_divergence: f64 = if self.n_jobs == 1 {
888            let mut kl = 0.0;
889            for i in 0..n_samples {
890                for j in 0..n_samples {
891                    if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
892                        kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
893                    }
894                }
895            }
896            kl
897        } else {
898            (0..n_samples)
899                .into_par_iter()
900                .map(|i| {
901                    let mut local_kl = 0.0;
902                    for j in 0..n_samples {
903                        if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
904                            local_kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
905                        }
906                    }
907                    local_kl
908                })
909                .sum()
910        };
911
912        // Compute gradient
913        let factor = 4.0 * (degrees_of_freedom + 1.0) / (degrees_of_freedom * sum_q * sum_q);
914
915        let grad = if self.n_jobs == 1 {
916            let mut g = Array2::zeros((n_samples, n_components));
917            for i in 0..n_samples {
918                for j in 0..n_samples {
919                    if i != j {
920                        let p_q_diff = p[[i, j]] - q[[i, j]];
921                        for k in 0..n_components {
922                            g[[i, k]] += factor
923                                * p_q_diff
924                                * dist[[i, j]]
925                                * (embedding[[i, k]] - embedding[[j, k]]);
926                        }
927                    }
928                }
929            }
930            g
931        } else {
932            let grad_rows: Vec<Vec<f64>> = (0..n_samples)
933                .into_par_iter()
934                .map(|i| {
935                    let mut grad_row = vec![0.0; n_components];
936                    for j in 0..n_samples {
937                        if i != j {
938                            let p_q_diff = p[[i, j]] - q[[i, j]];
939                            for k in 0..n_components {
940                                grad_row[k] += factor
941                                    * p_q_diff
942                                    * dist[[i, j]]
943                                    * (embedding[[i, k]] - embedding[[j, k]]);
944                            }
945                        }
946                    }
947                    grad_row
948                })
949                .collect();
950
951            let mut g = Array2::zeros((n_samples, n_components));
952            for (i, row) in grad_rows.iter().enumerate() {
953                for (k, &val) in row.iter().enumerate() {
954                    g[[i, k]] = val;
955                }
956            }
957            g
958        };
959
960        Ok((kl_divergence, grad))
961    }
962
    /// Compute gradient and error using Barnes-Hut approximation
    ///
    /// Builds a quadtree (2-D embeddings) or an octree (3-D embeddings) over
    /// `embedding`. For every point, the tree's `compute_forces` (with opening angle
    /// `self.angle`) returns a repulsive-force vector and a partial kernel sum; the
    /// latter values are accumulated into `sum_q`. Attractive forces are then added
    /// exactly for every pair with `p[[i, j]] > MACHINE_EPSILON`. The KL divergence
    /// is evaluated against a Student-t kernel matrix computed exactly for all
    /// pairs and normalized by `sum_q`.
    ///
    /// # Arguments
    /// * `embedding` - Current embedding; `n_components` must be 2 or 3
    /// * `p` - Joint probabilities of the input space (assumed
    ///   n_samples × n_samples — TODO confirm against the caller)
    /// * `degrees_of_freedom` - Degrees of freedom of the Student-t kernel
    ///
    /// # Returns
    /// `(kl_divergence, grad)` with `grad` shaped like `embedding`.
    ///
    /// # Errors
    /// `TransformError::InvalidInput` when the embedding is neither 2-D nor 3-D,
    /// plus any error propagated from building the tree or from `compute_forces`.
    ///
    /// NOTE(review): the attractive term is *subtracted* from `grad` here, while the
    /// exact-gradient path adds `(p − q) · w · (y_i − y_j)` with a positive factor.
    /// Unless `SpatialTree::compute_forces` returns repulsion with a matching
    /// (opposite) sign convention, the two paths disagree in sign — verify in
    /// `spatial_tree`.
    /// NOTE(review): the repulsive force is not divided by the kernel sum `sum_q` in
    /// this function, and the attractive prefactor `4.0` does not depend on
    /// `degrees_of_freedom` (it matches only α = 1); confirm both against
    /// `compute_forces`.
    /// NOTE(review): the exact O(N²) kernel and KL loops below cancel the
    /// O(N log N) benefit of Barnes-Hut; consider computing the KL only on
    /// iterations where it is reported.
    fn compute_gradient_barnes_hut(
        &self,
        embedding: &Array2<f64>,
        p: &Array2<f64>,
        degrees_of_freedom: f64,
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];

        // Build spatial tree: quadtree for 2-D, octree for 3-D, error otherwise
        let tree = if n_components == 2 {
            SpatialTree::new_quadtree(embedding)?
        } else if n_components == 3 {
            SpatialTree::new_octree(embedding)?
        } else {
            return Err(TransformError::InvalidInput(
                "Barnes-Hut only supports 2D and 3D embeddings".to_string(),
            ));
        };

        let mut q = Array2::zeros((n_samples, n_samples));
        let mut grad = Array2::zeros((n_samples, n_components));
        let mut sum_q = 0.0;

        // Compute repulsive forces using Barnes-Hut
        for i in 0..n_samples {
            let point = embedding.row(i).to_owned();
            let (repulsive_force, q_sum) =
                tree.compute_forces(&point, i, self.angle, degrees_of_freedom)?;

            // `q_sum` is presumably the tree's approximation of Σ_j w_ij for point i
            // (confirm in spatial_tree); the total is used to normalize `q` below.
            sum_q += q_sum;

            for j in 0..n_components {
                grad[[i, j]] += repulsive_force[j];
            }

            // Compute Q matrix entries for KL divergence: the exact unnormalized
            // Student-t kernel for every pair (O(N) per point, O(N²) overall).
            for j in 0..n_samples {
                if i != j {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }
                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    q[[i, j]] = q_ij;
                }
            }
        }

        // Normalize by the tree-derived kernel sum (not by `q.sum()`), guarding
        // against division by zero.
        sum_q = sum_q.max(MACHINE_EPSILON);
        q.mapv_inplace(|v| v / sum_q);

        // Add attractive forces: 4 · p_ij · w_ij · (y_i − y_j), subtracted from
        // `grad` (see the sign NOTE(review) in this function's doc comment).
        for i in 0..n_samples {
            for j in 0..n_samples {
                if i != j && p[[i, j]] > MACHINE_EPSILON {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }

                    // Unnormalized kernel w_ij (recomputed; same formula as above).
                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    let attraction = 4.0 * p[[i, j]] * q_ij;

                    for k in 0..n_components {
                        grad[[i, k]] -= attraction * (embedding[[i, k]] - embedding[[j, k]]);
                    }
                }
            }
        }

        // Compute KL divergence over all pairs, skipping negligible entries
        let mut kl_divergence = 0.0;
        for i in 0..n_samples {
            for j in 0..n_samples {
                if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                    kl_divergence += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                }
            }
        }

        Ok((kl_divergence, grad))
    }
1051
1052    /// Update embedding using gradient descent with momentum and adaptive gains
1053    fn gradient_update(
1054        &self,
1055        embedding: &mut Array2<f64>,
1056        update: &mut Array2<f64>,
1057        gains: &mut Array2<f64>,
1058        grad: &Array2<f64>,
1059        momentum: f64,
1060        learning_rate: Option<f64>,
1061    ) -> Result<()> {
1062        let n_samples = embedding.shape()[0];
1063        let n_components = embedding.shape()[1];
1064        let eta = learning_rate.unwrap_or(self.learning_rate);
1065
1066        for i in 0..n_samples {
1067            for j in 0..n_components {
1068                let same_sign = update[[i, j]] * grad[[i, j]] > 0.0;
1069
1070                if same_sign {
1071                    gains[[i, j]] *= 0.8;
1072                } else {
1073                    gains[[i, j]] += 0.2;
1074                }
1075
1076                gains[[i, j]] = gains[[i, j]].max(0.01);
1077                update[[i, j]] = momentum * update[[i, j]] - eta * gains[[i, j]] * grad[[i, j]];
1078                embedding[[i, j]] += update[[i, j]];
1079            }
1080        }
1081
1082        Ok(())
1083    }
1084
1085    /// Returns the embedding after fitting
1086    pub fn embedding(&self) -> Option<&Array2<f64>> {
1087        self.embedding_.as_ref()
1088    }
1089
1090    /// Returns the KL divergence after optimization
1091    pub fn kl_divergence(&self) -> Option<f64> {
1092        self.kl_divergence_
1093    }
1094
1095    /// Returns the number of iterations run
1096    pub fn n_iter(&self) -> Option<usize> {
1097        self.n_iter_
1098    }
1099}
1100
1101/// Calculate trustworthiness score for a dimensionality reduction
1102///
1103/// Trustworthiness measures to what extent the local structure is retained when
1104/// projecting data from the original space to the embedding space.
1105///
1106/// A trustworthiness of 1.0 means all local neighborhoods are perfectly preserved.
1107///
1108/// # Arguments
1109/// * `x` - Original data, shape (n_samples, n_features)
1110/// * `x_embedded` - Embedded data, shape (n_samples, n_components)
1111/// * `n_neighbors` - Number of neighbors to consider
1112/// * `metric` - Metric to use (currently only 'euclidean' is supported)
1113///
1114/// # Returns
1115/// * `Result<f64>` - Trustworthiness score between 0.0 and 1.0
1116pub fn trustworthiness<S1, S2>(
1117    x: &ArrayBase<S1, Ix2>,
1118    x_embedded: &ArrayBase<S2, Ix2>,
1119    n_neighbors: usize,
1120    metric: &str,
1121) -> Result<f64>
1122where
1123    S1: Data,
1124    S2: Data,
1125    S1::Elem: Float + NumCast,
1126    S2::Elem: Float + NumCast,
1127{
1128    let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
1129    let x_embedded_f64 = x_embedded.mapv(|v| NumCast::from(v).unwrap_or(0.0));
1130
1131    let n_samples = x_f64.shape()[0];
1132
1133    if n_neighbors >= n_samples / 2 {
1134        return Err(TransformError::InvalidInput(format!(
1135            "n_neighbors ({}) should be less than n_samples / 2 ({})",
1136            n_neighbors,
1137            n_samples / 2
1138        )));
1139    }
1140
1141    if metric != "euclidean" {
1142        return Err(TransformError::InvalidInput(format!(
1143            "Metric '{metric}' not supported. Currently only 'euclidean' is implemented."
1144        )));
1145    }
1146
1147    // Compute pairwise distances in original space
1148    let mut dist_x = Array2::zeros((n_samples, n_samples));
1149    for i in 0..n_samples {
1150        for j in 0..n_samples {
1151            if i == j {
1152                dist_x[[i, j]] = f64::INFINITY;
1153                continue;
1154            }
1155            let mut d_squared = 0.0;
1156            for k in 0..x_f64.shape()[1] {
1157                let diff = x_f64[[i, k]] - x_f64[[j, k]];
1158                d_squared += diff * diff;
1159            }
1160            dist_x[[i, j]] = d_squared.sqrt();
1161        }
1162    }
1163
1164    // Compute pairwise distances in embedded space
1165    let mut dist_embedded = Array2::zeros((n_samples, n_samples));
1166    for i in 0..n_samples {
1167        for j in 0..n_samples {
1168            if i == j {
1169                dist_embedded[[i, j]] = f64::INFINITY;
1170                continue;
1171            }
1172            let mut d_squared = 0.0;
1173            for k in 0..x_embedded_f64.shape()[1] {
1174                let diff = x_embedded_f64[[i, k]] - x_embedded_f64[[j, k]];
1175                d_squared += diff * diff;
1176            }
1177            dist_embedded[[i, j]] = d_squared.sqrt();
1178        }
1179    }
1180
1181    // For each point, find the n_neighbors nearest neighbors in the original space
1182    let mut nn_orig = Array2::<usize>::zeros((n_samples, n_neighbors));
1183    for i in 0..n_samples {
1184        let row = dist_x.row(i).to_owned();
1185        let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
1186        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1187
1188        for (j, &(idx, _)) in pairs.iter().enumerate().take(n_neighbors) {
1189            nn_orig[[i, j]] = idx;
1190        }
1191    }
1192
1193    // For each point, find n_neighbors nearest neighbors in embedded space
1194    let mut nn_embedded = Array2::<usize>::zeros((n_samples, n_neighbors));
1195    for i in 0..n_samples {
1196        let row = dist_embedded.row(i).to_owned();
1197        let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
1198        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1199
1200        for (j, &(idx, _)) in pairs.iter().skip(1).take(n_neighbors).enumerate() {
1201            nn_embedded[[i, j]] = idx;
1202        }
1203    }
1204
1205    // Calculate the trustworthiness score
1206    let mut t = 0.0;
1207    for i in 0..n_samples {
1208        for &j in nn_embedded.row(i).iter() {
1209            let is_not_neighbor = !nn_orig.row(i).iter().any(|&nn| nn == j);
1210
1211            if is_not_neighbor {
1212                let row = dist_x.row(i).to_owned();
1213                let mut pairs: Vec<(usize, f64)> =
1214                    row.iter().enumerate().map(|(idx, &d)| (idx, d)).collect();
1215                pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1216
1217                let rank = pairs
1218                    .iter()
1219                    .position(|&(idx, _)| idx == j)
1220                    .unwrap_or(n_neighbors)
1221                    .saturating_sub(n_neighbors);
1222
1223                t += rank as f64;
1224            }
1225        }
1226    }
1227
1228    // Normalize
1229    let n = n_samples as f64;
1230    let k = n_neighbors as f64;
1231    let normalizer = 2.0 / (n * k * (2.0 * n - 3.0 * k - 1.0));
1232    let trustworthiness_val = 1.0 - normalizer * t;
1233
1234    Ok(trustworthiness_val)
1235}
1236
1237#[cfg(test)]
1238#[path = "../tsne_tests.rs"]
1239mod tests;