Skip to main content

scirs2_transform/reduction/tsne/
mod.rs

1//! t-SNE (t-distributed Stochastic Neighbor Embedding) implementation
2//!
3//! This module provides an implementation of t-SNE, a technique for dimensionality
4//! reduction particularly well-suited for visualization of high-dimensional data.
5//!
6//! t-SNE converts similarities between data points to joint probabilities and tries
7//! to minimize the Kullback-Leibler divergence between the joint probabilities of
8//! the low-dimensional embedding and the high-dimensional data.
9//!
10//! ## Features
11//!
12//! - **Barnes-Hut approximation** for O(N log N) complexity via spatial trees
13//! - **Perplexity-based bandwidth selection** using binary search for sigma
14//! - **Early exaggeration phase** for better global structure preservation
15//! - **Momentum-based gradient descent** with adaptive gains
16//! - **Multiple distance metrics**: euclidean, manhattan, cosine, chebyshev
17//! - **Sparse kNN affinity** for memory-efficient computation on large datasets
18//! - **Multicore support** via rayon parallel iterators
19
20mod spatial_tree;
21
22use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
23use scirs2_core::numeric::{Float, NumCast};
24use scirs2_core::parallel_ops::*;
25use scirs2_core::random::Normal;
26use scirs2_core::random::RandomExt;
27
28use crate::error::{Result, TransformError};
29use crate::reduction::PCA;
30
31use spatial_tree::SpatialTree;
32
// Constants for numerical stability
/// Floor applied to probabilities before dividing or taking logarithms,
/// so `ln()` and division never see an exact zero.
const MACHINE_EPSILON: f64 = 1e-14;
/// Convergence tolerance for the perplexity binary search, and the
/// zero-norm guard when normalizing rows for the cosine metric.
const EPSILON: f64 = 1e-7;
36
/// t-SNE (t-distributed Stochastic Neighbor Embedding) for dimensionality reduction
///
/// t-SNE is a nonlinear dimensionality reduction technique well-suited for
/// embedding high-dimensional data for visualization in a low-dimensional space
/// (typically 2D or 3D). It models each high-dimensional object by a two- or
/// three-dimensional point in such a way that similar objects are modeled by
/// nearby points and dissimilar objects are modeled by distant points with
/// high probability.
///
/// Configure an instance with the `with_*` builder methods, then call
/// `fit_transform`; fitted results are stored in the trailing-underscore
/// fields.
///
/// # Example
///
/// ```rust,no_run
/// use scirs2_transform::TSNE;
/// use scirs2_core::ndarray::arr2;
///
/// let data = arr2(&[
///     [0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0],
///     [5.0, 5.0], [6.0, 5.0], [5.0, 6.0], [6.0, 6.0],
/// ]);
///
/// let mut tsne = TSNE::new()
///     .with_n_components(2)
///     .with_perplexity(2.0)
///     .with_max_iter(500);
///
/// let embedding = tsne.fit_transform(&data).expect("should succeed");
/// assert_eq!(embedding.shape(), &[8, 2]);
/// ```
pub struct TSNE {
    /// Number of components in the embedded space
    n_components: usize,
    /// Perplexity parameter that balances attention between local and global structure
    perplexity: f64,
    /// Weight of early exaggeration phase
    early_exaggeration: f64,
    /// Learning rate for optimization
    learning_rate: f64,
    /// Maximum number of iterations
    max_iter: usize,
    /// Maximum iterations without progress before early stopping
    n_iter_without_progress: usize,
    /// Minimum gradient norm for convergence
    min_grad_norm: f64,
    /// Method to compute pairwise distances
    metric: String,
    /// Method to perform dimensionality reduction ("exact" or "barnes_hut")
    method: String,
    /// Initialization method ("pca" or "random")
    init: String,
    /// Angle for Barnes-Hut approximation (trade-off between speed and accuracy)
    angle: f64,
    /// Whether to use multicore processing (-1 = all cores, 1 = single)
    n_jobs: i32,
    /// Verbosity level
    verbose: bool,
    /// Random state for reproducibility
    random_state: Option<u64>,
    /// Degrees of freedom for the Student-t kernel; `None` means use
    /// `(n_components - 1).max(1)` (see `effective_dof`). A value of 1.0
    /// corresponds to the standard Cauchy kernel.
    degrees_of_freedom: Option<f64>,
    /// The embedding vectors (`None` until `fit_transform` succeeds)
    embedding_: Option<Array2<f64>>,
    /// KL divergence after optimization (`None` until fitted)
    kl_divergence_: Option<f64>,
    /// Total number of iterations run (`None` until fitted)
    n_iter_: Option<usize>,
    /// Effective learning rate used (`None` until fitted)
    learning_rate_: Option<f64>,
}
105
106impl Default for TSNE {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
112impl TSNE {
113    /// Creates a new t-SNE instance with default parameters
114    pub fn new() -> Self {
115        TSNE {
116            n_components: 2,
117            perplexity: 30.0,
118            early_exaggeration: 12.0,
119            learning_rate: 200.0,
120            max_iter: 1000,
121            n_iter_without_progress: 300,
122            min_grad_norm: 1e-7,
123            metric: "euclidean".to_string(),
124            method: "barnes_hut".to_string(),
125            init: "pca".to_string(),
126            angle: 0.5,
127            n_jobs: -1,
128            verbose: false,
129            random_state: None,
130            degrees_of_freedom: None,
131            embedding_: None,
132            kl_divergence_: None,
133            n_iter_: None,
134            learning_rate_: None,
135        }
136    }
137
138    /// Sets the number of components in the embedded space
139    pub fn with_n_components(mut self, n_components: usize) -> Self {
140        self.n_components = n_components;
141        self
142    }
143
144    /// Sets the perplexity parameter
145    pub fn with_perplexity(mut self, perplexity: f64) -> Self {
146        self.perplexity = perplexity;
147        self
148    }
149
150    /// Sets the early exaggeration factor
151    pub fn with_early_exaggeration(mut self, early_exaggeration: f64) -> Self {
152        self.early_exaggeration = early_exaggeration;
153        self
154    }
155
156    /// Sets the learning rate for gradient descent
157    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
158        self.learning_rate = learning_rate;
159        self
160    }
161
162    /// Sets the maximum number of iterations
163    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
164        self.max_iter = max_iter;
165        self
166    }
167
168    /// Sets the number of iterations without progress before early stopping
169    pub fn with_n_iter_without_progress(mut self, n_iter_without_progress: usize) -> Self {
170        self.n_iter_without_progress = n_iter_without_progress;
171        self
172    }
173
174    /// Sets the minimum gradient norm for convergence
175    pub fn with_min_grad_norm(mut self, min_grad_norm: f64) -> Self {
176        self.min_grad_norm = min_grad_norm;
177        self
178    }
179
180    /// Sets the metric for pairwise distance computation
181    ///
182    /// Supported metrics:
183    /// - "euclidean": Euclidean distance (L2 norm) - default
184    /// - "manhattan": Manhattan distance (L1 norm)
185    /// - "cosine": Cosine distance (1 - cosine similarity)
186    /// - "chebyshev": Chebyshev distance (maximum coordinate difference)
187    pub fn with_metric(mut self, metric: &str) -> Self {
188        self.metric = metric.to_string();
189        self
190    }
191
192    /// Sets the method for dimensionality reduction ("exact" or "barnes_hut")
193    pub fn with_method(mut self, method: &str) -> Self {
194        self.method = method.to_string();
195        self
196    }
197
198    /// Sets the initialization method ("pca" or "random")
199    pub fn with_init(mut self, init: &str) -> Self {
200        self.init = init.to_string();
201        self
202    }
203
204    /// Sets the angle for Barnes-Hut approximation (0.0 = exact, 1.0 = fast but approximate)
205    pub fn with_angle(mut self, angle: f64) -> Self {
206        self.angle = angle;
207        self
208    }
209
210    /// Sets the number of parallel jobs to run
211    /// * n_jobs = -1: Use all available cores
212    /// * n_jobs = 1: Use single-core (disable multicore)
213    /// * n_jobs > 1: Use specific number of cores
214    pub fn with_n_jobs(mut self, n_jobs: i32) -> Self {
215        self.n_jobs = n_jobs;
216        self
217    }
218
219    /// Sets the verbosity level
220    pub fn with_verbose(mut self, verbose: bool) -> Self {
221        self.verbose = verbose;
222        self
223    }
224
225    /// Sets the random state for reproducibility
226    pub fn with_random_state(mut self, random_state: u64) -> Self {
227        self.random_state = Some(random_state);
228        self
229    }
230
231    /// Sets the degrees of freedom for the Student-t distribution
232    ///
233    /// Default is n_components - 1 (or 1 if n_components <= 1).
234    /// Setting this to a larger value produces heavier tails, which can help
235    /// with crowding in higher-dimensional embeddings.
236    pub fn with_degrees_of_freedom(mut self, dof: f64) -> Self {
237        self.degrees_of_freedom = Some(dof);
238        self
239    }
240
241    /// Fit t-SNE to input data and transform it to the embedded space
242    ///
243    /// # Arguments
244    /// * `x` - Input data, shape (n_samples, n_features)
245    ///
246    /// # Returns
247    /// * `Result<Array2<f64>>` - Embedding of the training data, shape (n_samples, n_components)
248    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
249    where
250        S: Data,
251        S::Elem: Float + NumCast,
252    {
253        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
254
255        let n_samples = x_f64.shape()[0];
256        let n_features = x_f64.shape()[1];
257
258        // Input validation
259        if n_samples == 0 || n_features == 0 {
260            return Err(TransformError::InvalidInput("Empty input data".to_string()));
261        }
262
263        if self.perplexity >= n_samples as f64 {
264            return Err(TransformError::InvalidInput(format!(
265                "perplexity ({}) must be less than n_samples ({})",
266                self.perplexity, n_samples
267            )));
268        }
269
270        if self.method == "barnes_hut" && self.n_components > 3 {
271            return Err(TransformError::InvalidInput(
272                "'n_components' should be <= 3 for barnes_hut algorithm".to_string(),
273            ));
274        }
275
276        self.learning_rate_ = Some(self.learning_rate);
277
278        // Initialize embedding
279        let x_embedded = self.initialize_embedding(&x_f64)?;
280
281        // Compute pairwise affinities (P)
282        let p = self.compute_pairwise_affinities(&x_f64)?;
283
284        // Run t-SNE optimization
285        let (embedding, kl_divergence, n_iter) =
286            self.tsne_optimization(p, x_embedded, n_samples)?;
287
288        self.embedding_ = Some(embedding.clone());
289        self.kl_divergence_ = Some(kl_divergence);
290        self.n_iter_ = Some(n_iter);
291
292        Ok(embedding)
293    }
294
295    /// Initialize embedding either with PCA or random
296    fn initialize_embedding(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
297        let n_samples = x.shape()[0];
298
299        if self.init == "pca" {
300            let n_components = self.n_components.min(x.shape()[1]);
301            let mut pca = PCA::new(n_components, true, false);
302            let mut x_embedded = pca.fit_transform(x)?;
303
304            // Scale PCA initialization
305            let col_var = x_embedded.column(0).map(|&v| v * v).sum() / (n_samples as f64);
306            let std_dev = col_var.sqrt();
307            if std_dev > 0.0 {
308                x_embedded.mapv_inplace(|v| v / std_dev * 1e-4);
309            }
310
311            Ok(x_embedded)
312        } else if self.init == "random" {
313            use scirs2_core::random::{thread_rng, Distribution};
314            let normal = Normal::new(0.0, 1e-4).map_err(|e| {
315                TransformError::ComputationError(format!(
316                    "Failed to create normal distribution: {e}"
317                ))
318            })?;
319            let mut rng = thread_rng();
320
321            let data: Vec<f64> = (0..(n_samples * self.n_components))
322                .map(|_| normal.sample(&mut rng))
323                .collect();
324            Array2::from_shape_vec((n_samples, self.n_components), data).map_err(|e| {
325                TransformError::ComputationError(format!("Failed to create embedding array: {e}"))
326            })
327        } else {
328            Err(TransformError::InvalidInput(format!(
329                "Initialization method '{}' not recognized. Use 'pca' or 'random'.",
330                self.init
331            )))
332        }
333    }
334
335    /// Compute pairwise affinities with perplexity-based normalization
336    fn compute_pairwise_affinities(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
337        // Compute pairwise distances
338        let distances = self.compute_pairwise_distances(x)?;
339
340        // Convert distances to affinities using binary search for sigma
341        let p = self.distances_to_affinities(&distances)?;
342
343        // Symmetrize and normalize the affinity matrix
344        let mut p_symmetric = &p + &p.t();
345
346        let p_sum = p_symmetric.sum();
347        if p_sum > 0.0 {
348            p_symmetric.mapv_inplace(|v| v.max(MACHINE_EPSILON) / p_sum);
349        }
350
351        Ok(p_symmetric)
352    }
353
354    /// Compute pairwise distances with optional multicore support
355    fn compute_pairwise_distances(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
356        let n_samples = x.shape()[0];
357        let n_features = x.shape()[1];
358        let mut distances = Array2::zeros((n_samples, n_samples));
359
360        match self.metric.as_str() {
361            "euclidean" => {
362                if self.n_jobs == 1 {
363                    for i in 0..n_samples {
364                        for j in i + 1..n_samples {
365                            let mut dist_squared = 0.0;
366                            for k in 0..n_features {
367                                let diff = x[[i, k]] - x[[j, k]];
368                                dist_squared += diff * diff;
369                            }
370                            distances[[i, j]] = dist_squared;
371                            distances[[j, i]] = dist_squared;
372                        }
373                    }
374                } else {
375                    let upper_triangle_indices: Vec<(usize, usize)> = (0..n_samples)
376                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
377                        .collect();
378
379                    let squared_distances: Vec<f64> = upper_triangle_indices
380                        .par_iter()
381                        .map(|&(i, j)| {
382                            let mut dist_squared = 0.0;
383                            for k in 0..n_features {
384                                let diff = x[[i, k]] - x[[j, k]];
385                                dist_squared += diff * diff;
386                            }
387                            dist_squared
388                        })
389                        .collect();
390
391                    for (idx, &(i, j)) in upper_triangle_indices.iter().enumerate() {
392                        distances[[i, j]] = squared_distances[idx];
393                        distances[[j, i]] = squared_distances[idx];
394                    }
395                }
396            }
397            "manhattan" => {
398                let compute_manhattan = |i: usize, j: usize| -> f64 {
399                    let mut dist = 0.0;
400                    for k in 0..n_features {
401                        dist += (x[[i, k]] - x[[j, k]]).abs();
402                    }
403                    dist
404                };
405
406                if self.n_jobs == 1 {
407                    for i in 0..n_samples {
408                        for j in i + 1..n_samples {
409                            let dist = compute_manhattan(i, j);
410                            distances[[i, j]] = dist;
411                            distances[[j, i]] = dist;
412                        }
413                    }
414                } else {
415                    let upper: Vec<(usize, usize)> = (0..n_samples)
416                        .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
417                        .collect();
418                    let dists: Vec<f64> = upper
419                        .par_iter()
420                        .map(|&(i, j)| {
421                            let mut dist = 0.0;
422                            for k in 0..n_features {
423                                dist += (x[[i, k]] - x[[j, k]]).abs();
424                            }
425                            dist
426                        })
427                        .collect();
428                    for (idx, &(i, j)) in upper.iter().enumerate() {
429                        distances[[i, j]] = dists[idx];
430                        distances[[j, i]] = dists[idx];
431                    }
432                }
433            }
434            "cosine" => {
435                let mut normalized_x = Array2::zeros((n_samples, n_features));
436                for i in 0..n_samples {
437                    let row = x.row(i);
438                    let norm = row.iter().map(|v| v * v).sum::<f64>().sqrt();
439                    if norm > EPSILON {
440                        for j in 0..n_features {
441                            normalized_x[[i, j]] = x[[i, j]] / norm;
442                        }
443                    }
444                }
445
446                let upper: Vec<(usize, usize)> = (0..n_samples)
447                    .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
448                    .collect();
449
450                let compute_fn = |i: usize, j: usize| -> f64 {
451                    let mut dot_product = 0.0;
452                    for k in 0..n_features {
453                        dot_product += normalized_x[[i, k]] * normalized_x[[j, k]];
454                    }
455                    1.0 - dot_product.clamp(-1.0, 1.0)
456                };
457
458                if self.n_jobs == 1 {
459                    for &(i, j) in &upper {
460                        let d = compute_fn(i, j);
461                        distances[[i, j]] = d;
462                        distances[[j, i]] = d;
463                    }
464                } else {
465                    let dists: Vec<f64> = upper
466                        .par_iter()
467                        .map(|&(i, j)| {
468                            let mut dp = 0.0;
469                            for k in 0..n_features {
470                                dp += normalized_x[[i, k]] * normalized_x[[j, k]];
471                            }
472                            1.0 - dp.clamp(-1.0, 1.0)
473                        })
474                        .collect();
475                    for (idx, &(i, j)) in upper.iter().enumerate() {
476                        distances[[i, j]] = dists[idx];
477                        distances[[j, i]] = dists[idx];
478                    }
479                }
480            }
481            "chebyshev" => {
482                let upper: Vec<(usize, usize)> = (0..n_samples)
483                    .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
484                    .collect();
485
486                if self.n_jobs == 1 {
487                    for &(i, j) in &upper {
488                        let mut max_dist = 0.0;
489                        for k in 0..n_features {
490                            let diff = (x[[i, k]] - x[[j, k]]).abs();
491                            max_dist = max_dist.max(diff);
492                        }
493                        distances[[i, j]] = max_dist;
494                        distances[[j, i]] = max_dist;
495                    }
496                } else {
497                    let dists: Vec<f64> = upper
498                        .par_iter()
499                        .map(|&(i, j)| {
500                            let mut max_dist = 0.0;
501                            for k in 0..n_features {
502                                let diff = (x[[i, k]] - x[[j, k]]).abs();
503                                max_dist = max_dist.max(diff);
504                            }
505                            max_dist
506                        })
507                        .collect();
508                    for (idx, &(i, j)) in upper.iter().enumerate() {
509                        distances[[i, j]] = dists[idx];
510                        distances[[j, i]] = dists[idx];
511                    }
512                }
513            }
514            _ => {
515                return Err(TransformError::InvalidInput(format!(
516                    "Metric '{}' not supported. Use: 'euclidean', 'manhattan', 'cosine', 'chebyshev'",
517                    self.metric
518                )));
519            }
520        }
521
522        Ok(distances)
523    }
524
525    /// Convert distances to affinities using perplexity-based normalization
526    fn distances_to_affinities(&self, distances: &Array2<f64>) -> Result<Array2<f64>> {
527        let n_samples = distances.shape()[0];
528        let target = (2.0f64).ln() * self.perplexity;
529
530        if self.n_jobs == 1 {
531            let mut p = Array2::zeros((n_samples, n_samples));
532            for i in 0..n_samples {
533                self.binary_search_sigma(i, distances, target, &mut p)?;
534            }
535            Ok(p)
536        } else {
537            let prob_rows: Vec<Vec<f64>> = (0..n_samples)
538                .into_par_iter()
539                .map(|i| {
540                    let distances_i: Vec<f64> = (0..n_samples).map(|j| distances[[i, j]]).collect();
541                    Self::binary_search_sigma_row(i, &distances_i, n_samples, target)
542                })
543                .collect();
544
545            let mut p = Array2::zeros((n_samples, n_samples));
546            for (i, row) in prob_rows.iter().enumerate() {
547                for (j, &val) in row.iter().enumerate() {
548                    p[[i, j]] = val;
549                }
550            }
551            Ok(p)
552        }
553    }
554
555    /// Binary search for the optimal sigma (bandwidth) for a single row
556    fn binary_search_sigma(
557        &self,
558        i: usize,
559        distances: &Array2<f64>,
560        target: f64,
561        p: &mut Array2<f64>,
562    ) -> Result<()> {
563        let n_samples = distances.shape()[0];
564        let mut beta_min = -f64::INFINITY;
565        let mut beta_max = f64::INFINITY;
566        let mut beta = 1.0;
567
568        for _ in 0..50 {
569            let mut sum_pi = 0.0;
570            let mut h = 0.0;
571
572            for j in 0..n_samples {
573                if i == j {
574                    p[[i, j]] = 0.0;
575                    continue;
576                }
577                let p_ij = (-beta * distances[[i, j]]).exp();
578                p[[i, j]] = p_ij;
579                sum_pi += p_ij;
580            }
581
582            if sum_pi > 0.0 {
583                for j in 0..n_samples {
584                    if i == j {
585                        continue;
586                    }
587                    p[[i, j]] /= sum_pi;
588                    if p[[i, j]] > MACHINE_EPSILON {
589                        h -= p[[i, j]] * p[[i, j]].ln();
590                    }
591                }
592            }
593
594            let h_diff = h - target;
595            if h_diff.abs() < EPSILON {
596                break;
597            }
598
599            if h_diff > 0.0 {
600                beta_min = beta;
601                beta = if beta_max == f64::INFINITY {
602                    beta * 2.0
603                } else {
604                    (beta + beta_max) / 2.0
605                };
606            } else {
607                beta_max = beta;
608                beta = if beta_min == -f64::INFINITY {
609                    beta / 2.0
610                } else {
611                    (beta + beta_min) / 2.0
612                };
613            }
614        }
615
616        Ok(())
617    }
618
    /// Parallel-safe binary search for a single row (returns Vec)
    ///
    /// Associated function (no `&self`, no shared output matrix) so it can run
    /// on rayon worker threads. For row `i` it searches for the precision
    /// `beta` (inverse bandwidth in `exp(-beta * d)`) such that the Shannon
    /// entropy of the normalized row matches `target`, and returns the
    /// normalized conditional probabilities P(j|i), with P(i|i) = 0.
    fn binary_search_sigma_row(
        i: usize,
        distances_i: &[f64],
        n_samples: usize,
        target: f64,
    ) -> Vec<f64> {
        // Bisection bounds on beta; infinities mean "no bound found yet".
        let mut beta_min = -f64::INFINITY;
        let mut beta_max = f64::INFINITY;
        let mut beta = 1.0;
        let mut p_row = vec![0.0; n_samples];

        // At most 50 bisection steps (same cap as the serial path).
        for _ in 0..50 {
            let mut sum_pi = 0.0;
            let mut h = 0.0; // entropy of the current candidate distribution

            // Unnormalized affinities exp(-beta * d_ij); self-affinity is 0.
            for j in 0..n_samples {
                if i == j {
                    p_row[j] = 0.0;
                    continue;
                }
                let p_ij = (-beta * distances_i[j]).exp();
                p_row[j] = p_ij;
                sum_pi += p_ij;
            }

            // Normalize in place and accumulate entropy H = -Σ p ln p.
            if sum_pi > 0.0 {
                for (j, prob) in p_row.iter_mut().enumerate().take(n_samples) {
                    if i == j {
                        continue;
                    }
                    *prob /= sum_pi;
                    if *prob > MACHINE_EPSILON {
                        h -= *prob * prob.ln();
                    }
                }
            }

            let h_diff = h - target;
            if h_diff.abs() < EPSILON {
                break;
            }

            // Entropy too high -> distribution too flat -> raise beta;
            // too low -> lower beta. Double/halve until both bounds exist,
            // then bisect between them.
            if h_diff > 0.0 {
                beta_min = beta;
                beta = if beta_max == f64::INFINITY {
                    beta * 2.0
                } else {
                    (beta + beta_max) / 2.0
                };
            } else {
                beta_max = beta;
                beta = if beta_min == -f64::INFINITY {
                    beta / 2.0
                } else {
                    (beta + beta_min) / 2.0
                };
            }
        }

        p_row
    }
681
682    /// Get the effective degrees of freedom for the t-distribution
683    fn effective_dof(&self) -> f64 {
684        if let Some(dof) = self.degrees_of_freedom {
685            dof
686        } else {
687            (self.n_components - 1).max(1) as f64
688        }
689    }
690
691    /// Main t-SNE optimization loop using gradient descent
692    fn tsne_optimization(
693        &self,
694        p: Array2<f64>,
695        initial_embedding: Array2<f64>,
696        n_samples: usize,
697    ) -> Result<(Array2<f64>, f64, usize)> {
698        let n_components = self.n_components;
699        let degrees_of_freedom = self.effective_dof();
700
701        let mut embedding = initial_embedding;
702        let mut update = Array2::zeros((n_samples, n_components));
703        let mut gains = Array2::ones((n_samples, n_components));
704        let mut error = f64::INFINITY;
705        let mut best_error = f64::INFINITY;
706        let mut best_iter = 0;
707        let mut iter = 0;
708
709        let exploration_n_iter = 250;
710        let n_iter_check = 50;
711
712        // Apply early exaggeration
713        let p_early = &p * self.early_exaggeration;
714
715        if self.verbose {
716            println!("[t-SNE] Starting optimization with early exaggeration phase...");
717        }
718
719        // Early exaggeration phase
720        for i in 0..exploration_n_iter {
721            let (curr_error, grad) = if self.method == "barnes_hut" {
722                self.compute_gradient_barnes_hut(&embedding, &p_early, degrees_of_freedom)?
723            } else {
724                self.compute_gradient_exact(&embedding, &p_early, degrees_of_freedom)?
725            };
726
727            self.gradient_update(
728                &mut embedding,
729                &mut update,
730                &mut gains,
731                &grad,
732                0.5,
733                self.learning_rate_,
734            )?;
735
736            if (i + 1) % n_iter_check == 0 {
737                if self.verbose {
738                    println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
739                }
740
741                if curr_error < best_error {
742                    best_error = curr_error;
743                    best_iter = i;
744                } else if i - best_iter > self.n_iter_without_progress {
745                    if self.verbose {
746                        println!("[t-SNE] Early convergence at iteration {}", i + 1);
747                    }
748                    break;
749                }
750
751                let grad_norm = grad.mapv(|v| v * v).sum().sqrt();
752                if grad_norm < self.min_grad_norm {
753                    if self.verbose {
754                        println!(
755                            "[t-SNE] Gradient norm {} below threshold at iteration {}",
756                            grad_norm,
757                            i + 1
758                        );
759                    }
760                    break;
761                }
762            }
763
764            iter = i;
765        }
766
767        if self.verbose {
768            println!("[t-SNE] Completed early exaggeration, starting final optimization...");
769        }
770
771        // Final optimization phase without early exaggeration
772        for i in iter + 1..self.max_iter {
773            let (curr_error, grad) = if self.method == "barnes_hut" {
774                self.compute_gradient_barnes_hut(&embedding, &p, degrees_of_freedom)?
775            } else {
776                self.compute_gradient_exact(&embedding, &p, degrees_of_freedom)?
777            };
778            error = curr_error;
779
780            self.gradient_update(
781                &mut embedding,
782                &mut update,
783                &mut gains,
784                &grad,
785                0.8,
786                self.learning_rate_,
787            )?;
788
789            if (i + 1) % n_iter_check == 0 {
790                if self.verbose {
791                    println!("[t-SNE] Iteration {}: error = {:.7}", i + 1, curr_error);
792                }
793
794                if curr_error < best_error {
795                    best_error = curr_error;
796                    best_iter = i;
797                } else if i - best_iter > self.n_iter_without_progress {
798                    if self.verbose {
799                        println!("[t-SNE] Stopping optimization at iteration {}", i + 1);
800                    }
801                    break;
802                }
803
804                let grad_norm = grad.mapv(|v| v * v).sum().sqrt();
805                if grad_norm < self.min_grad_norm {
806                    if self.verbose {
807                        println!(
808                            "[t-SNE] Gradient norm {} below threshold at iteration {}",
809                            grad_norm,
810                            i + 1
811                        );
812                    }
813                    break;
814                }
815            }
816
817            iter = i;
818        }
819
820        if self.verbose {
821            println!(
822                "[t-SNE] Optimization finished after {} iterations with error {:.7}",
823                iter + 1,
824                error
825            );
826        }
827
828        Ok((embedding, error, iter + 1))
829    }
830
    /// Compute gradient and error for exact t-SNE
    ///
    /// Builds the full N x N Student-t weight matrix over the current
    /// embedding, normalizes it into Q, and returns the KL divergence
    /// KL(P || Q) together with the gradient of that divergence with respect
    /// to the embedding coordinates. Each of the three passes (weights, KL,
    /// gradient) runs serially when `self.n_jobs == 1` and via rayon
    /// parallel iterators otherwise.
    ///
    /// # Arguments
    /// * `embedding` - Current low-dimensional embedding, shape (n_samples, n_components)
    /// * `p` - Joint probabilities P from the high-dimensional space, shape (n_samples, n_samples)
    /// * `degrees_of_freedom` - Degrees of freedom of the Student-t kernel
    ///
    /// # Returns
    /// * `(kl_divergence, gradient)` where the gradient has the embedding's shape
    fn compute_gradient_exact(
        &self,
        embedding: &Array2<f64>,
        p: &Array2<f64>,
        degrees_of_freedom: f64,
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];

        // Compute Q matrix
        // `dist` holds the UNNORMALIZED Student-t weights
        // w_ij = (1 + d_ij^2 / df)^(-(df+1)/2); only the upper triangle is
        // computed and mirrored, since the weights are symmetric.
        let mut dist = Array2::zeros((n_samples, n_samples));
        let upper: Vec<(usize, usize)> = (0..n_samples)
            .flat_map(|i| ((i + 1)..n_samples).map(move |j| (i, j)))
            .collect();

        if self.n_jobs == 1 {
            for &(i, j) in &upper {
                let mut d_squared = 0.0;
                for k in 0..n_components {
                    let diff = embedding[[i, k]] - embedding[[j, k]];
                    d_squared += diff * diff;
                }
                let q_ij =
                    (1.0 + d_squared / degrees_of_freedom).powf(-(degrees_of_freedom + 1.0) / 2.0);
                dist[[i, j]] = q_ij;
                dist[[j, i]] = q_ij;
            }
        } else {
            // Parallel path: compute all upper-triangle weights in parallel,
            // then scatter them into the matrix serially (Array2 indexing is
            // not shared-mutable across threads).
            let q_values: Vec<f64> = upper
                .par_iter()
                .map(|&(i, j)| {
                    let mut d_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        d_squared += diff * diff;
                    }
                    (1.0 + d_squared / degrees_of_freedom).powf(-(degrees_of_freedom + 1.0) / 2.0)
                })
                .collect();

            for (idx, &(i, j)) in upper.iter().enumerate() {
                dist[[i, j]] = q_values[idx];
                dist[[j, i]] = q_values[idx];
            }
        }

        // A point has no affinity with itself.
        for i in 0..n_samples {
            dist[[i, i]] = 0.0;
        }

        // Z = sum of all weights (the partition function); clamped so the
        // normalization below cannot divide by zero if the embedding collapses.
        let sum_q = dist.sum().max(MACHINE_EPSILON);
        let q = &dist / sum_q;

        // Compute KL divergence: sum_ij p_ij * ln(p_ij / q_ij). Entries where
        // either probability is effectively zero are skipped — their
        // contribution is numerically meaningless (0 * ln 0).
        let kl_divergence: f64 = if self.n_jobs == 1 {
            let mut kl = 0.0;
            for i in 0..n_samples {
                for j in 0..n_samples {
                    if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                        kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                    }
                }
            }
            kl
        } else {
            (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut local_kl = 0.0;
                    for j in 0..n_samples {
                        if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                            local_kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                        }
                    }
                    local_kl
                })
                .sum()
        };

        // Compute gradient
        // NOTE(review): reference implementations (e.g. scikit-learn's
        // `_kl_divergence`) scale the exact gradient by 2*(df+1)/df with no
        // partition-function term; the extra 1/(sum_q^2) here rescales the
        // step by Z^2, which is data-dependent. Verify this is intentional
        // and matched to the optimizer's learning-rate tuning.
        let factor = 4.0 * (degrees_of_freedom + 1.0) / (degrees_of_freedom * sum_q * sum_q);

        // grad_i = factor * sum_j (p_ij - q_ij) * w_ij * (y_i - y_j),
        // where w_ij is the unnormalized weight stored in `dist`.
        let grad = if self.n_jobs == 1 {
            let mut g = Array2::zeros((n_samples, n_components));
            for i in 0..n_samples {
                for j in 0..n_samples {
                    if i != j {
                        let p_q_diff = p[[i, j]] - q[[i, j]];
                        for k in 0..n_components {
                            g[[i, k]] += factor
                                * p_q_diff
                                * dist[[i, j]]
                                * (embedding[[i, k]] - embedding[[j, k]]);
                        }
                    }
                }
            }
            g
        } else {
            // Parallel path: each row of the gradient depends only on row i,
            // so rows are computed independently and copied back.
            let grad_rows: Vec<Vec<f64>> = (0..n_samples)
                .into_par_iter()
                .map(|i| {
                    let mut grad_row = vec![0.0; n_components];
                    for j in 0..n_samples {
                        if i != j {
                            let p_q_diff = p[[i, j]] - q[[i, j]];
                            for k in 0..n_components {
                                grad_row[k] += factor
                                    * p_q_diff
                                    * dist[[i, j]]
                                    * (embedding[[i, k]] - embedding[[j, k]]);
                            }
                        }
                    }
                    grad_row
                })
                .collect();

            let mut g = Array2::zeros((n_samples, n_components));
            for (i, row) in grad_rows.iter().enumerate() {
                for (k, &val) in row.iter().enumerate() {
                    g[[i, k]] = val;
                }
            }
            g
        };

        Ok((kl_divergence, grad))
    }
961
    /// Compute gradient and error using Barnes-Hut approximation
    ///
    /// Repulsive forces are approximated via a quadtree (2D) or octree (3D);
    /// attractive forces and the KL divergence are still computed exactly
    /// over all pairs.
    ///
    /// NOTE(review): despite the tree-based repulsion, this method allocates
    /// a dense `n_samples x n_samples` Q matrix and runs three O(N^2) pair
    /// loops, so memory and time remain quadratic — the module-level claim of
    /// O(N log N) does not hold for this path as written. Consider estimating
    /// the KL divergence from the sparse P entries only.
    ///
    /// # Arguments
    /// * `embedding` - Current embedding, shape (n_samples, n_components); must be 2D or 3D
    /// * `p` - Joint probabilities P, shape (n_samples, n_samples)
    /// * `degrees_of_freedom` - Degrees of freedom of the Student-t kernel
    ///
    /// # Returns
    /// * `(kl_divergence, gradient)`
    ///
    /// # Errors
    /// * `TransformError::InvalidInput` if `n_components` is not 2 or 3
    fn compute_gradient_barnes_hut(
        &self,
        embedding: &Array2<f64>,
        p: &Array2<f64>,
        degrees_of_freedom: f64,
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = embedding.shape()[0];
        let n_components = embedding.shape()[1];

        // Build spatial tree
        let tree = if n_components == 2 {
            SpatialTree::new_quadtree(embedding)?
        } else if n_components == 3 {
            SpatialTree::new_octree(embedding)?
        } else {
            return Err(TransformError::InvalidInput(
                "Barnes-Hut only supports 2D and 3D embeddings".to_string(),
            ));
        };

        let mut q = Array2::zeros((n_samples, n_samples));
        let mut grad = Array2::zeros((n_samples, n_components));
        let mut sum_q = 0.0;

        // Compute repulsive forces using Barnes-Hut
        // `self.angle` is the accuracy/speed trade-off parameter (theta).
        for i in 0..n_samples {
            let point = embedding.row(i).to_owned();
            let (repulsive_force, q_sum) =
                tree.compute_forces(&point, i, self.angle, degrees_of_freedom)?;

            // Accumulate the partition-function estimate returned by the tree.
            sum_q += q_sum;

            // Repulsion is ADDED here and attraction SUBTRACTED below;
            // assumes `compute_forces` returns forces with the matching sign
            // convention — TODO confirm against spatial_tree.rs.
            for j in 0..n_components {
                grad[[i, j]] += repulsive_force[j];
            }

            // Compute Q matrix entries for KL divergence
            // (exact, unnormalized Student-t weights over all pairs).
            for j in 0..n_samples {
                if i != j {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }
                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    q[[i, j]] = q_ij;
                }
            }
        }

        // Normalize Q by the tree's partition-function estimate; clamped to
        // avoid division by zero.
        sum_q = sum_q.max(MACHINE_EPSILON);
        q.mapv_inplace(|v| v / sum_q);

        // Add attractive forces
        // Only pairs with non-negligible p_ij contribute; the 4*p_ij*w_ij
        // coefficient is the standard t-SNE attractive term.
        for i in 0..n_samples {
            for j in 0..n_samples {
                if i != j && p[[i, j]] > MACHINE_EPSILON {
                    let mut dist_squared = 0.0;
                    for k in 0..n_components {
                        let diff = embedding[[i, k]] - embedding[[j, k]];
                        dist_squared += diff * diff;
                    }

                    let q_ij = (1.0 + dist_squared / degrees_of_freedom)
                        .powf(-(degrees_of_freedom + 1.0) / 2.0);
                    let attraction = 4.0 * p[[i, j]] * q_ij;

                    for k in 0..n_components {
                        grad[[i, k]] -= attraction * (embedding[[i, k]] - embedding[[j, k]]);
                    }
                }
            }
        }

        // Compute KL divergence: sum_ij p_ij * ln(p_ij / q_ij), skipping
        // near-zero entries.
        let mut kl_divergence = 0.0;
        for i in 0..n_samples {
            for j in 0..n_samples {
                if p[[i, j]] > MACHINE_EPSILON && q[[i, j]] > MACHINE_EPSILON {
                    kl_divergence += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                }
            }
        }

        Ok((kl_divergence, grad))
    }
1050
1051    /// Update embedding using gradient descent with momentum and adaptive gains
1052    fn gradient_update(
1053        &self,
1054        embedding: &mut Array2<f64>,
1055        update: &mut Array2<f64>,
1056        gains: &mut Array2<f64>,
1057        grad: &Array2<f64>,
1058        momentum: f64,
1059        learning_rate: Option<f64>,
1060    ) -> Result<()> {
1061        let n_samples = embedding.shape()[0];
1062        let n_components = embedding.shape()[1];
1063        let eta = learning_rate.unwrap_or(self.learning_rate);
1064
1065        for i in 0..n_samples {
1066            for j in 0..n_components {
1067                let same_sign = update[[i, j]] * grad[[i, j]] > 0.0;
1068
1069                if same_sign {
1070                    gains[[i, j]] *= 0.8;
1071                } else {
1072                    gains[[i, j]] += 0.2;
1073                }
1074
1075                gains[[i, j]] = gains[[i, j]].max(0.01);
1076                update[[i, j]] = momentum * update[[i, j]] - eta * gains[[i, j]] * grad[[i, j]];
1077                embedding[[i, j]] += update[[i, j]];
1078            }
1079        }
1080
1081        Ok(())
1082    }
1083
    /// Returns the embedding after fitting
    ///
    /// Returns `None` when no embedding has been computed yet.
    pub fn embedding(&self) -> Option<&Array2<f64>> {
        self.embedding_.as_ref()
    }
1088
    /// Returns the KL divergence after optimization
    ///
    /// Returns `None` when no KL divergence has been recorded yet.
    pub fn kl_divergence(&self) -> Option<f64> {
        self.kl_divergence_
    }
1093
    /// Returns the number of iterations run
    ///
    /// Returns `None` when no optimization has been run yet.
    pub fn n_iter(&self) -> Option<usize> {
        self.n_iter_
    }
1098}
1099
1100/// Calculate trustworthiness score for a dimensionality reduction
1101///
1102/// Trustworthiness measures to what extent the local structure is retained when
1103/// projecting data from the original space to the embedding space.
1104///
1105/// A trustworthiness of 1.0 means all local neighborhoods are perfectly preserved.
1106///
1107/// # Arguments
1108/// * `x` - Original data, shape (n_samples, n_features)
1109/// * `x_embedded` - Embedded data, shape (n_samples, n_components)
1110/// * `n_neighbors` - Number of neighbors to consider
1111/// * `metric` - Metric to use (currently only 'euclidean' is supported)
1112///
1113/// # Returns
1114/// * `Result<f64>` - Trustworthiness score between 0.0 and 1.0
1115pub fn trustworthiness<S1, S2>(
1116    x: &ArrayBase<S1, Ix2>,
1117    x_embedded: &ArrayBase<S2, Ix2>,
1118    n_neighbors: usize,
1119    metric: &str,
1120) -> Result<f64>
1121where
1122    S1: Data,
1123    S2: Data,
1124    S1::Elem: Float + NumCast,
1125    S2::Elem: Float + NumCast,
1126{
1127    let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
1128    let x_embedded_f64 = x_embedded.mapv(|v| NumCast::from(v).unwrap_or(0.0));
1129
1130    let n_samples = x_f64.shape()[0];
1131
1132    if n_neighbors >= n_samples / 2 {
1133        return Err(TransformError::InvalidInput(format!(
1134            "n_neighbors ({}) should be less than n_samples / 2 ({})",
1135            n_neighbors,
1136            n_samples / 2
1137        )));
1138    }
1139
1140    if metric != "euclidean" {
1141        return Err(TransformError::InvalidInput(format!(
1142            "Metric '{metric}' not supported. Currently only 'euclidean' is implemented."
1143        )));
1144    }
1145
1146    // Compute pairwise distances in original space
1147    let mut dist_x = Array2::zeros((n_samples, n_samples));
1148    for i in 0..n_samples {
1149        for j in 0..n_samples {
1150            if i == j {
1151                dist_x[[i, j]] = f64::INFINITY;
1152                continue;
1153            }
1154            let mut d_squared = 0.0;
1155            for k in 0..x_f64.shape()[1] {
1156                let diff = x_f64[[i, k]] - x_f64[[j, k]];
1157                d_squared += diff * diff;
1158            }
1159            dist_x[[i, j]] = d_squared.sqrt();
1160        }
1161    }
1162
1163    // Compute pairwise distances in embedded space
1164    let mut dist_embedded = Array2::zeros((n_samples, n_samples));
1165    for i in 0..n_samples {
1166        for j in 0..n_samples {
1167            if i == j {
1168                dist_embedded[[i, j]] = f64::INFINITY;
1169                continue;
1170            }
1171            let mut d_squared = 0.0;
1172            for k in 0..x_embedded_f64.shape()[1] {
1173                let diff = x_embedded_f64[[i, k]] - x_embedded_f64[[j, k]];
1174                d_squared += diff * diff;
1175            }
1176            dist_embedded[[i, j]] = d_squared.sqrt();
1177        }
1178    }
1179
1180    // For each point, find the n_neighbors nearest neighbors in the original space
1181    let mut nn_orig = Array2::<usize>::zeros((n_samples, n_neighbors));
1182    for i in 0..n_samples {
1183        let row = dist_x.row(i).to_owned();
1184        let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
1185        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1186
1187        for (j, &(idx, _)) in pairs.iter().enumerate().take(n_neighbors) {
1188            nn_orig[[i, j]] = idx;
1189        }
1190    }
1191
1192    // For each point, find n_neighbors nearest neighbors in embedded space
1193    let mut nn_embedded = Array2::<usize>::zeros((n_samples, n_neighbors));
1194    for i in 0..n_samples {
1195        let row = dist_embedded.row(i).to_owned();
1196        let mut pairs: Vec<(usize, f64)> = row.iter().enumerate().map(|(j, &d)| (j, d)).collect();
1197        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1198
1199        for (j, &(idx, _)) in pairs.iter().skip(1).take(n_neighbors).enumerate() {
1200            nn_embedded[[i, j]] = idx;
1201        }
1202    }
1203
1204    // Calculate the trustworthiness score
1205    let mut t = 0.0;
1206    for i in 0..n_samples {
1207        for &j in nn_embedded.row(i).iter() {
1208            let is_not_neighbor = !nn_orig.row(i).iter().any(|&nn| nn == j);
1209
1210            if is_not_neighbor {
1211                let row = dist_x.row(i).to_owned();
1212                let mut pairs: Vec<(usize, f64)> =
1213                    row.iter().enumerate().map(|(idx, &d)| (idx, d)).collect();
1214                pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1215
1216                let rank = pairs
1217                    .iter()
1218                    .position(|&(idx, _)| idx == j)
1219                    .unwrap_or(n_neighbors)
1220                    .saturating_sub(n_neighbors);
1221
1222                t += rank as f64;
1223            }
1224        }
1225    }
1226
1227    // Normalize
1228    let n = n_samples as f64;
1229    let k = n_neighbors as f64;
1230    let normalizer = 2.0 / (n * k * (2.0 * n - 3.0 * k - 1.0));
1231    let trustworthiness_val = 1.0 - normalizer * t;
1232
1233    Ok(trustworthiness_val)
1234}
1235
1236#[cfg(test)]
1237#[path = "../tsne_tests.rs"]
1238mod tests;