Skip to main content

sphereql_embed/
umap.rs

1//! UMAP-on-sphere via Adam in the tangent bundle of S².
2//!
3//! Standard UMAP optimizes 2D embeddings in Euclidean space. Here every
4//! fitted point lives on the unit 2-sphere, so each Adam step happens in
5//! the local tangent space `T_x S² = { v : x·v = 0 }` and the iterate is
6//! retracted back to the sphere via normalization. PCA provides the warm
7//! start; the kNN graph supplies the attractive term and uniformly
8//! sampled negatives supply the repulsive term. Optional per-point
9//! categories add a third term that pulls same-category points together
10//! and pushes different-category points apart.
11//!
12//! Attractive edges carry the canonical fuzzy simplicial set weights:
13//! per-point local distance scaling (`rho_i` = nearest-neighbor distance,
14//! `sigma_i` solved so membership strengths sum to `log2 k`) followed by
15//! fuzzy-union symmetrization, so a dense cluster's 15th neighbor pulls
16//! hard while a sparse region's 15th neighbor barely pulls at all.
17//! The `epochs_per_sample` machinery is intentionally not implemented:
18//! negatives stay uniformly drawn at `negative_sample_rate` per
19//! attractive edge, and both the attractive gradient and the edge's
20//! negative-draw gradients are scaled by the edge weight — equivalent in
21//! expectation to canonical UMAP, where an edge fires (and draws its
22//! negatives) proportionally to its weight.
23//!
//! `project()` on a fitted training embedding returns its exact
//! optimized position (the Adam output, not an interpolation). For
//! genuinely unseen embeddings it takes a softmax-weighted average of
//! the nearest fitted positions — a linear blend in ℝ³ that approximates
//! slerp when the neighbors lie close together on the sphere. UMAP itself
//! is non-parametric, so transforms hand new points to their nearest
//! fitted neighbors and interpolate on the sphere.
30
31use std::collections::HashMap;
32use std::sync::Arc;
33
34use sphereql_core::SphericalPoint;
35
36use crate::ann::{AnnConfig, AnnIndex};
37use crate::projection::{
38    Projection, ProjectionError, SplitMix64, dot, normalize_vec, project_xyz_to_spherical,
39};
40use crate::types::{Embedding, ProjectedPoint, RadialContext, RadialStrategy};
41
/// Knobs for [`UmapSphereProjection::fit`]. All defaults match the
/// canonical UMAP paper unless noted.
#[derive(Debug, Clone)]
pub struct UmapConfig {
    /// Neighbors per point in the kNN graph that supplies the
    /// attractive term. Higher = preserve global structure, lower =
    /// preserve local clusters. Clamped to `[1, n - 1]` when the graph
    /// is built. UMAP default 15.
    pub n_neighbors: usize,
    /// Optimizer iterations. ~200 is enough for n<=2000; scale up
    /// roughly logarithmically. The learning rate decays linearly
    /// across these epochs.
    pub n_epochs: usize,
    /// Adam base learning rate, annealed linearly over the epochs.
    /// Raw gradients are not bounded — the repulsive gradient grows
    /// like 1/d as a close pair's distance d shrinks — but Adam divides
    /// each coordinate's step by its running RMS gradient, so steps stay
    /// on the order of this value and 0.05 is safe even for tiny corpora.
    pub learning_rate: f64,
    /// Negative samples drawn per attractive (directed kNN) edge per
    /// epoch.
    pub negative_sample_rate: usize,
    /// Weight on the supervised category term (0.0 = disabled). When
    /// active, every epoch samples for each point one same-category
    /// partner (cohesion, attractive) and one different-category
    /// partner (separation, repulsive) — stratified so the cohesion
    /// half fires regardless of how many categories the corpus has.
    /// Only meaningful when `categories` is supplied to `fit`.
    pub category_weight: f64,
    /// How tightly neighbors may pack in the layout — the canonical
    /// UMAP knob. At fit time the kernel parameters `(a, b)` of
    /// `Phi(d) = 1/(1 + a·d^(2b))` are least-squares fitted from it
    /// (spread fixed at 1.0). On S², whose total area is fixed, this
    /// directly sets how much territory a cluster occupies: smaller
    /// values pack clusters more tightly, 0.0 the tightest.
    /// UMAP default 0.1.
    pub min_dist: f64,
    /// Weight on an attractive pull toward each point's PCA
    /// warm-start position (0.0 = disabled). Use small values
    /// (~0.01–0.1) on sparse corpora whose kNN graphs fragment into
    /// disconnected components: the warm start places the components
    /// sensibly relative to each other, but with zero attraction
    /// between them, uniform negative sampling drifts them into an
    /// arbitrary arrangement over many epochs. The anchor keeps that
    /// global arrangement from drifting under unopposed repulsion.
    pub warm_start_anchor: f64,
    /// PRNG seed for the optimizer's random draws: per-edge negative
    /// sampling and the supervised term's partner sampling. Graph
    /// construction ([`UmapGraph::build`]) does not take a seed, so this
    /// value has no effect on the kNN graph or the PCA warm start.
    pub seed: u64,
}
86
87impl Default for UmapConfig {
88    fn default() -> Self {
89        Self {
90            n_neighbors: 15,
91            n_epochs: 200,
92            learning_rate: 0.05,
93            negative_sample_rate: 5,
94            category_weight: 0.0,
95            min_dist: 0.1,
96            warm_start_anchor: 0.0,
97            seed: 0xA1B2_C3D4,
98        }
99    }
100}
101
102/// Precomputed kNN graph for UMAP. Cacheable across configs that
103/// share the same `n_neighbors` but differ in `n_epochs`,
104/// `category_weight`, or `min_dist` — everything here is computed
105/// before the optimizer, which is where those knobs act.
106#[derive(Clone)]
107pub struct UmapGraph {
108    /// kNN adjacency list: `knn[i]` = indices of k nearest neighbors of item i.
109    pub(crate) knn: Vec<Vec<usize>>,
110    /// Fuzzy simplicial set edge weights aligned with `knn`:
111    /// `weights[i][idx]` is the symmetrized (fuzzy-union) membership
112    /// strength of edge `i → knn[i][idx]`, in `[0, 1]`. Like `knn`, this
113    /// depends only on the data and `n_neighbors`, so the tuner's
114    /// per-`n_neighbors` graph cache contract is unchanged.
115    pub(crate) weights: Vec<Vec<f64>>,
116    /// L2-normalized embeddings used for graph construction.
117    /// Retained for the Adam optimizer's similarity lookups.
118    pub(crate) normalized: Vec<Vec<f64>>,
119    /// PCA warm-start positions on S² (unit vectors in ℝ³).
120    pub(crate) warm_start: Vec<[f64; 3]>,
121    /// Embedding dimensionality.
122    pub(crate) dim: usize,
123    /// Number of neighbors.
124    pub(crate) k: usize,
125    /// ANN index retained from kNN-graph construction (only built when
126    /// `n >= ANN_BRUTE_FORCE_THRESHOLD`). Depends solely on `normalized`
127    /// and the default [`AnnConfig`] — never on `n_neighbors` — so the
128    /// tuner's per-`n_neighbors` graph cache stays sound. Carried
129    /// through to the projection for transform-time neighbor queries.
130    pub(crate) ann: Option<Arc<AnnIndex>>,
131}
132
133impl UmapGraph {
134    /// Build the kNN graph and PCA warm-start from embeddings.
135    ///
136    /// This is the expensive part of UMAP fit — O(N·log N·d) for the
137    /// ANN-backed graph + O(N·d) for PCA warm-start. The result is
138    /// reusable across all UMAP configs that share `n_neighbors`.
139    pub fn build(embeddings: &[Embedding], n_neighbors: usize) -> Result<Self, ProjectionError> {
140        if embeddings.is_empty() {
141            return Err(ProjectionError::EmptyCorpus);
142        }
143        let dim = embeddings[0].dimension();
144        if dim < 3 {
145            return Err(ProjectionError::DimensionTooLow {
146                got: dim,
147                required: 3,
148            });
149        }
150        for (i, e) in embeddings.iter().enumerate() {
151            if e.dimension() != dim {
152                return Err(ProjectionError::InconsistentDimension {
153                    index: i,
154                    expected: dim,
155                    got: e.dimension(),
156                });
157            }
158        }
159        let n = embeddings.len();
160        if n < 4 {
161            return Err(ProjectionError::TooFewEmbeddings {
162                got: n,
163                required: 4,
164            });
165        }
166
167        let normalized: Vec<Vec<f64>> = embeddings.iter().map(|e| e.normalized()).collect();
168        let k = n_neighbors.min(n - 1).max(1);
169        let (knn, dists, ann) = build_knn_graph(&normalized, k);
170        let weights = fuzzy_simplicial_weights(&knn, &dists);
171        let warm_start = pca_warm_start(embeddings, &normalized)?;
172
173        Ok(Self {
174            knn,
175            weights,
176            normalized,
177            warm_start,
178            dim,
179            k,
180            ann,
181        })
182    }
183}
184
185/// UMAP-style projection that lives on S² and transforms new points by
186/// kNN-weighted averaging over the fitted positions.
187#[derive(Clone)]
188pub struct UmapSphereProjection {
189    /// Unit vectors in ℝ³ — one per fitted embedding.
190    fitted_points: Vec<[f64; 3]>,
191    /// L2-normalized copies of the original embeddings, kept for
192    /// kNN lookup at transform time (UMAP is non-parametric).
193    fitted_normalized: Vec<Vec<f64>>,
194    /// Exact-match lookup: hash of the normalized embedding's bit
195    /// pattern → fitted indices with that hash (verified by full vector
196    /// comparison on hit, so hash collisions are safe). Projecting a
197    /// training embedding returns its exact fitted position with
198    /// certainty 1.0 — the optimizer computed that position, so it is
199    /// known, not interpolated. Exact duplicates map to the first
200    /// fitted index: their post-optimization positions may differ
201    /// slightly, and first-index keeps the choice deterministic.
202    exact_lookup: HashMap<u64, Vec<usize>>,
203    /// ANN index over `fitted_normalized` for transform-time neighbor
204    /// queries when the corpus is at or above
205    /// `ANN_BRUTE_FORCE_THRESHOLD`; `None` means brute force.
206    ann: Option<Arc<AnnIndex>>,
207    dim: usize,
208    radial: RadialStrategy,
209    n_neighbors: usize,
210    /// Post-fit quality in `[0, 1]`: trustworthiness-style kNN recall —
211    /// mean overlap between each point's neighborhood among the fitted
212    /// 3D positions and its original-space kNN set. 1.0 means the
213    /// sphere preserves every original neighborhood.
214    quality: f64,
215}
216
217impl UmapSphereProjection {
218    /// Fit with default config and no categories.
219    pub fn fit_default(embeddings: &[Embedding]) -> Result<Self, ProjectionError> {
220        Self::fit(
221            embeddings,
222            None,
223            RadialStrategy::default(),
224            UmapConfig::default(),
225        )
226    }
227
228    /// Optimize from a prebuilt kNN graph. This is the cheap part of UMAP
229    /// fit — O(N·k·epochs) for the Adam optimizer. The graph is not rebuilt.
230    ///
231    /// Use this when the tuner has already built the graph via
232    /// [`UmapGraph::build`] and is sweeping `n_epochs` / `category_weight`.
233    pub fn fit_from_graph(
234        graph: &UmapGraph,
235        categories: Option<&[u32]>,
236        radial: RadialStrategy,
237        config: UmapConfig,
238    ) -> Result<Self, ProjectionError> {
239        let n = graph.normalized.len();
240
241        if let Some(cats) = categories
242            && cats.len() != n
243        {
244            return Err(ProjectionError::SliceLengthMismatch {
245                expected: n,
246                got: cats.len(),
247            });
248        }
249
250        let mut points = graph.warm_start.clone();
251        let mut rng = SplitMix64::new(config.seed);
252        // Kernel parameters from min_dist — fitted once per fit, then
253        // closed over by every gradient call below.
254        let (ka, kb) = find_ab_params(config.min_dist);
255        // NaN category_weight compares false, exactly like the old
256        // partial_cmp().map(is_gt).unwrap_or(false) chain.
257        let cat_active = config.category_weight > 0.0 && categories.is_some();
258        // Same NaN-safe comparison for the warm-start anchor.
259        let anchor_active = config.warm_start_anchor > 0.0;
260
261        // Per-category index buckets for stratified sampling in the
262        // supervised term. Category ids may be sparse, so each id is
263        // compacted to a dense bucket index up front; `bucket_of[i]`
264        // is point i's bucket.
265        let cat_buckets: Option<(Vec<Vec<usize>>, Vec<usize>)> = cat_active.then(|| {
266            // cat_active is only true when categories.is_some().
267            let cats = categories.unwrap();
268            let mut id_to_bucket: HashMap<u32, usize> = HashMap::new();
269            let mut buckets: Vec<Vec<usize>> = Vec::new();
270            let mut bucket_of = Vec::with_capacity(n);
271            for (i, &c) in cats.iter().enumerate() {
272                let b = *id_to_bucket.entry(c).or_insert_with(|| {
273                    buckets.push(Vec::new());
274                    buckets.len() - 1
275                });
276                buckets[b].push(i);
277                bucket_of.push(b);
278            }
279            (buckets, bucket_of)
280        });
281
282        // Effective weight per directed edge. The loop below walks
283        // directed edges, so a mutual kNN pair — present in both
284        // adjacency rows with the same symmetrized weight — would fire
285        // twice; halving those edges keeps every pair's total
286        // contribution at w_sym.
287        let attract: Vec<Vec<f64>> = graph
288            .knn
289            .iter()
290            .enumerate()
291            .map(|(i, neighbors)| {
292                neighbors
293                    .iter()
294                    .zip(&graph.weights[i])
295                    .map(|(&j, &w)| {
296                        if graph.knn[j].contains(&i) {
297                            0.5 * w
298                        } else {
299                            w
300                        }
301                    })
302                    .collect()
303            })
304            .collect();
305
306        // Adam state, three components per point.
307        let mut m = vec![[0.0f64; 3]; n];
308        let mut v = vec![[0.0f64; 3]; n];
309        let beta1 = 0.9;
310        let beta2 = 0.999;
311        let eps = 1e-8;
312
313        for epoch in 1..=config.n_epochs {
314            // Anneal the learning rate UMAP-style: linear decay to ~0.
315            let lr = config.learning_rate * (1.0 - (epoch as f64 / config.n_epochs as f64));
316            let mut grads = vec![[0.0f64; 3]; n];
317
318            // Attractive + repulsive in one pass: each kNN edge
319            // contributes its own attractive force AND draws
320            // `negative_sample_rate` repulsive samples for the source
321            // endpoint (UMAP's standard per-edge negative sampling, not
322            // per-point). Both gradients carry the edge's fuzzy weight:
323            // canonical UMAP fires an edge — and that edge's negative
324            // draws — proportionally to its weight, so scaling only the
325            // attraction would tilt the global balance toward repulsion
326            // by ~1/mean(w) (measured: kNN recall halved on the
327            // 775-concept benchmark corpus).
328            for (i, neighbors) in graph.knn.iter().enumerate() {
329                for (idx, &j) in neighbors.iter().enumerate() {
330                    let w = attract[i][idx];
331                    let (gi, gj) = attractive_grad(&points[i], &points[j], ka, kb);
332                    add3_scaled(&mut grads[i], &gi, w);
333                    add3_scaled(&mut grads[j], &gj, w);
334
335                    for _ in 0..config.negative_sample_rate {
336                        let nj = (rng.next_u64() as usize) % n;
337                        if nj == i {
338                            continue;
339                        }
340                        let (gi_r, gj_r) = repulsive_grad(&points[i], &points[nj], ka, kb);
341                        add3_scaled(&mut grads[i], &gi_r, w);
342                        add3_scaled(&mut grads[nj], &gj_r, w);
343                    }
344                }
345            }
346
347            // Optional category term, stratified per point: one
348            // same-category partner (cohesion) and one
349            // different-category partner (separation) per epoch. A
350            // uniform partner draw would be ~(C-1)/C repulsion at C
351            // categories, starving the cohesion half exactly where
352            // territorial scores need it.
353            if let Some((buckets, bucket_of)) = &cat_buckets {
354                let w = config.category_weight;
355                for i in 0..n {
356                    let bucket = &buckets[bucket_of[i]];
357                    if bucket.len() > 1 {
358                        // Uniform over the bucket minus i: draw from the
359                        // first len-1 slots and remap a self-draw to the
360                        // last slot.
361                        let idx = (rng.next_u64() as usize) % (bucket.len() - 1);
362                        let j = if bucket[idx] == i {
363                            bucket[bucket.len() - 1]
364                        } else {
365                            bucket[idx]
366                        };
367                        let (gi, gj) = attractive_grad(&points[i], &points[j], ka, kb);
368                        add3_scaled(&mut grads[i], &gi, w);
369                        add3_scaled(&mut grads[j], &gj, w);
370                    }
371                    // Rejection-sample the cross-category partner with
372                    // a bounded retry count so a (near-)single-category
373                    // corpus can't spin forever — on exhaustion, skip.
374                    for _ in 0..MAX_CROSS_CATEGORY_DRAWS {
375                        let j = (rng.next_u64() as usize) % n;
376                        if bucket_of[j] != bucket_of[i] {
377                            let (gi, gj) = repulsive_grad(&points[i], &points[j], ka, kb);
378                            add3_scaled(&mut grads[i], &gi, w);
379                            add3_scaled(&mut grads[j], &gj, w);
380                            break;
381                        }
382                    }
383                }
384            }
385
386            // Optional warm-start anchor: a weak pull from each point
387            // toward its PCA warm-start position. The anchor end is
388            // fixed — `points` starts as a clone of `warm_start` and
389            // then moves, so the pull reads `graph.warm_start` and
390            // only `grads[i]` receives a gradient (no equal-and-
391            // opposite term). Consumes no RNG, so toggling it never
392            // shifts the negative-sampling stream.
393            if anchor_active {
394                let w = config.warm_start_anchor;
395                for i in 0..n {
396                    let (gi, _) = attractive_grad(&points[i], &graph.warm_start[i], ka, kb);
397                    add3_scaled(&mut grads[i], &gi, w);
398                }
399            }
400
401            // Adam step in tangent space, retract to S².
402            for i in 0..n {
403                let g_tan = project_to_tangent(&points[i], &grads[i]);
404                for d in 0..3 {
405                    m[i][d] = beta1 * m[i][d] + (1.0 - beta1) * g_tan[d];
406                    v[i][d] = beta2 * v[i][d] + (1.0 - beta2) * g_tan[d] * g_tan[d];
407                }
408                let t = epoch as f64;
409                let bc1 = 1.0 - beta1.powf(t);
410                let bc2 = 1.0 - beta2.powf(t);
411                let mut step = [0.0f64; 3];
412                for d in 0..3 {
413                    let m_hat = m[i][d] / bc1;
414                    let v_hat = v[i][d] / bc2;
415                    step[d] = lr * m_hat / (v_hat.sqrt() + eps);
416                }
417                // Retraction: x_new = (x - step) / |x - step|.
418                // Sign: gradient descent ⇒ subtract.
419                let mut next = [
420                    points[i][0] - step[0],
421                    points[i][1] - step[1],
422                    points[i][2] - step[2],
423                ];
424                let mag = (next[0] * next[0] + next[1] * next[1] + next[2] * next[2]).sqrt();
425                if mag > f64::EPSILON {
426                    next[0] /= mag;
427                    next[1] /= mag;
428                    next[2] /= mag;
429                    points[i] = next;
430                }
431            }
432        }
433
434        let quality = knn_recall_score(&points, &graph.knn);
435
436        let mut exact_lookup: HashMap<u64, Vec<usize>> = HashMap::new();
437        for (i, vec) in graph.normalized.iter().enumerate() {
438            let bucket = exact_lookup.entry(hash_normalized(vec)).or_default();
439            if !bucket.iter().any(|&j| graph.normalized[j] == *vec) {
440                bucket.push(i);
441            }
442        }
443
444        Ok(Self {
445            fitted_points: points,
446            fitted_normalized: graph.normalized.clone(),
447            exact_lookup,
448            ann: graph.ann.clone(),
449            dim: graph.dim,
450            radial,
451            n_neighbors: graph.k,
452            quality,
453        })
454    }
455
456    /// Fit with custom config. `categories` is parallel to `embeddings`
457    /// when supplied; pass `None` to disable the supervised term even
458    /// if `config.category_weight > 0`.
459    ///
460    /// Equivalent to [`UmapGraph::build`] followed by
461    /// [`Self::fit_from_graph`] — the tuner calls those two halves
462    /// directly so it can reuse graphs across configs that share
463    /// `n_neighbors`; this entry point serves every other caller.
464    pub fn fit(
465        embeddings: &[Embedding],
466        categories: Option<&[u32]>,
467        radial: RadialStrategy,
468        config: UmapConfig,
469    ) -> Result<Self, ProjectionError> {
470        let graph = UmapGraph::build(embeddings, config.n_neighbors)?;
471        Self::fit_from_graph(&graph, categories, radial, config)
472    }
473
474    /// Post-fit quality: trustworthiness-style kNN recall. For each
475    /// point, its `k` nearest neighbors among the fitted 3D positions
476    /// (`k` = the graph's `n_neighbors`) are intersected with its
477    /// original-space kNN set; the score is the mean overlap fraction.
478    /// Bounded `[0, 1]`, where 1.0 means every original neighborhood
479    /// survives the projection.
480    ///
481    /// Intentionally exposed under the EVR name so the auto-tuner's
482    /// `MetaModel` consumers can compare projection kinds on one
483    /// scalar. Note the semantics changed: this used to be the fraction
484    /// of kNN edges shorter than the median random pairwise distance, a
485    /// bar that random spherical pairs (≈90° apart) made trivially
486    /// clearable, so scores saturated high and barely discriminated.
487    /// Recall is rank-meaningful across corpora and more honest when
488    /// compared against other projection kinds.
489    pub fn explained_variance_ratio(&self) -> f64 {
490        self.quality
491    }
492
493    /// Locate the `n_neighbors` fitted points closest to `embedding`
494    /// (cosine similarity in the original space) and return their
495    /// indices with similarity weights.
496    fn nearest_fitted(&self, normalized: &[f64]) -> Vec<(usize, f64)> {
497        if let Some(ann) = &self.ann {
498            return ann.query(normalized, self.n_neighbors);
499        }
500        let mut sims: Vec<(usize, f64)> = self
501            .fitted_normalized
502            .iter()
503            .enumerate()
504            .map(|(i, v)| (i, dot(normalized, v)))
505            .collect();
506        sims.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
507        sims.truncate(self.n_neighbors);
508        sims
509    }
510
511    fn exact_fitted(&self, normalized: &[f64]) -> Option<usize> {
512        self.exact_lookup
513            .get(&hash_normalized(normalized))?
514            .iter()
515            .copied()
516            .find(|&i| self.fitted_normalized[i] == normalized)
517    }
518
519    fn project_xyz(&self, embedding: &Embedding) -> ([f64; 3], f64) {
520        let normalized = embedding.normalized();
521        if let Some(idx) = self.exact_fitted(&normalized) {
522            return (self.fitted_points[idx], 1.0);
523        }
524        let neighbors = self.nearest_fitted(&normalized);
525
526        // Softmax over similarities to get a stable weighted average.
527        let max_sim = neighbors
528            .iter()
529            .map(|(_, s)| *s)
530            .fold(f64::NEG_INFINITY, f64::max);
531        let mut weights: Vec<f64> = neighbors
532            .iter()
533            .map(|(_, s)| ((s - max_sim) * 8.0).exp())
534            .collect();
535        let total: f64 = weights.iter().sum();
536        if total > f64::EPSILON {
537            for w in &mut weights {
538                *w /= total;
539            }
540        } else {
541            let n = weights.len() as f64;
542            for w in &mut weights {
543                *w = 1.0 / n;
544            }
545        }
546
547        let mut acc = [0.0f64; 3];
548        for ((idx, _), w) in neighbors.iter().zip(weights.iter()) {
549            let p = self.fitted_points[*idx];
550            acc[0] += w * p[0];
551            acc[1] += w * p[1];
552            acc[2] += w * p[2];
553        }
554        let mag = (acc[0] * acc[0] + acc[1] * acc[1] + acc[2] * acc[2]).sqrt();
555        let certainty = mag.clamp(0.0, 1.0);
556        (acc, certainty)
557    }
558}
559
560impl Projection for UmapSphereProjection {
561    fn project(&self, embedding: &Embedding) -> SphericalPoint {
562        // Caller contract: dimension must match the fitted projection.
563        assert_eq!(
564            embedding.dimension(),
565            self.dim,
566            "expected dimension {}, got {}",
567            self.dim,
568            embedding.dimension()
569        );
570        let (xyz, certainty) = self.project_xyz(embedding);
571        let projection_magnitude = (xyz[0] * xyz[0] + xyz[1] * xyz[1] + xyz[2] * xyz[2]).sqrt();
572        let intensity = embedding.magnitude();
573        let r = self.radial.compute_rich(&RadialContext::full(
574            intensity,
575            projection_magnitude,
576            certainty,
577        ));
578        project_xyz_to_spherical(xyz[0], xyz[1], xyz[2], r)
579    }
580
581    fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
582        // Caller contract: dimension must match the fitted projection.
583        assert_eq!(
584            embedding.dimension(),
585            self.dim,
586            "expected dimension {}, got {}",
587            self.dim,
588            embedding.dimension()
589        );
590        let (xyz, certainty) = self.project_xyz(embedding);
591        let projection_magnitude = (xyz[0] * xyz[0] + xyz[1] * xyz[1] + xyz[2] * xyz[2]).sqrt();
592        let intensity = embedding.magnitude();
593        let r = self.radial.compute_rich(&RadialContext::full(
594            intensity,
595            projection_magnitude,
596            certainty,
597        ));
598        let position = project_xyz_to_spherical(xyz[0], xyz[1], xyz[2], r);
599        ProjectedPoint::new(position, certainty, intensity, projection_magnitude)
600    }
601
602    fn dimensionality(&self) -> usize {
603        self.dim
604    }
605}
606
// ── Gradients ───────────────────────────────────────────────────────
//
// Loss decomposition mirrors the standard UMAP form with the rational
// kernel Phi(d) = 1/(1 + a·d^(2b)), where d is the Euclidean (chord)
// distance between the embedded points in ℝ³, not the geodesic distance
// on S². The closed-form gradients below are those Euclidean gradients;
// the caller projects them onto the tangent space before stepping.
// (a, b) come from `find_ab_params(min_dist)`; at a = b = 1 both reduce
// exactly to the historical hardcoded forms.
616
/// Euclidean gradient of the attractive loss `L = log(1 + a·d^(2b))`,
/// with `d = |xi - xj|`, returned as `(∂L/∂xi, ∂L/∂xj)`. The two are
/// equal and opposite.
fn attractive_grad(xi: &[f64; 3], xj: &[f64; 3], a: f64, b: f64) -> ([f64; 3], [f64; 3]) {
    // ∂L/∂xi = 2ab·d^(2b-2)·(xi - xj) / (1 + a·d^(2b)).
    // At a = b = 1 this is 2(xi - xj) / (1 + d²).
    let diff: [f64; 3] = std::array::from_fn(|k| xi[k] - xj[k]);
    let dist_sq = diff[0] * diff[0] + diff[1] * diff[1] + diff[2] * diff[2];
    // d^(2b-2) = (d²)^(b-1); the floor keeps b < 1 finite as d → 0 and
    // is a no-op at b = 1 (exponent 0).
    let numerator = 2.0 * a * b * dist_sq.max(1e-6).powf(b - 1.0);
    let denominator = 1.0 + a * dist_sq.powf(b);
    let coef = numerator / denominator;
    let on_i = diff.map(|c| coef * c);
    let on_j = on_i.map(|c| -c);
    (on_i, on_j)
}
629
/// Euclidean gradient of the repulsive loss `L = -log(1 - Phi(d))`,
/// where `Phi(d) = 1/(1 + a·d^(2b))`, returned as `(∂L/∂xi, ∂L/∂xj)`.
/// Descending it pushes the pair apart.
fn repulsive_grad(xi: &[f64; 3], xj: &[f64; 3], a: f64, b: f64) -> ([f64; 3], [f64; 3]) {
    // ∂L/∂xi = -2b·(xi - xj) / (d²·(1 + a·d^(2b))).
    // At a = b = 1 this is -2(xi - xj) / (d²(1 + d²)).
    let diff: [f64; 3] = std::array::from_fn(|k| xi[k] - xj[k]);
    // Floor d² so coincident points yield a finite coefficient.
    let dist_sq = (diff[0] * diff[0] + diff[1] * diff[1] + diff[2] * diff[2]).max(1e-6);
    let coef = -2.0 * b / (dist_sq * (1.0 + a * dist_sq.powf(b)));
    let on_i = diff.map(|c| coef * c);
    (on_i, on_i.map(|c| -c))
}
640
/// Least-squares fit of the layout kernel `Phi(d) = 1/(1 + a·d^(2b))`
/// to UMAP's target membership curve: `1` for `d <= min_dist`, and
/// `exp(-(d - min_dist))` beyond it (spread fixed at 1.0). The curve is
/// sampled at 300 evenly spaced points of `d ∈ [0, 3]`, as in canonical
/// UMAP's `find_ab_params`, but the minimization is a deterministic grid
/// search instead of curve fitting. A 25×25 grid over
/// `log₁₀ a ∈ [-3, 1] × b ∈ [0.1, 2.5]` is followed by five 9×9 rounds,
/// each on a window half as wide, recentered on the best point so far
/// and clamped to those bounds. At `min_dist = 0.1` this lands near the
/// canonical `(a, b) ≈ (1.577, 0.895)`.
///
/// If every candidate's error is NaN (e.g. a NaN `min_dist`), the
/// search never improves on its starting point and returns `(1.0, 1.0)`.
fn find_ab_params(min_dist: f64) -> (f64, f64) {
    const SAMPLES: usize = 300;
    const D_MAX: f64 = 3.0;
    const LOG_A_RANGE: (f64, f64) = (-3.0, 1.0);
    const B_RANGE: (f64, f64) = (0.1, 2.5);
    const ROUNDS: usize = 6;
    const FIRST_ROUND_STEPS: usize = 24;
    const REFINE_STEPS: usize = 8;

    // (distance, target membership) pairs, sampled once.
    let samples: Vec<(f64, f64)> = (0..SAMPLES)
        .map(|i| D_MAX * i as f64 / (SAMPLES - 1) as f64)
        .map(|d| {
            let target = if d <= min_dist {
                1.0
            } else {
                (-(d - min_dist)).exp()
            };
            (d, target)
        })
        .collect();

    // Sum of squared residuals of the kernel against the samples.
    let sse = |a: f64, b: f64| -> f64 {
        samples
            .iter()
            .map(|&(d, target)| {
                let residual = 1.0 / (1.0 + a * d.powf(2.0 * b)) - target;
                residual * residual
            })
            .sum()
    };
    // The i-th of `steps + 1` evenly spaced points spanning [lo, hi].
    let grid = |lo: f64, hi: f64, i: usize, steps: usize| lo + (hi - lo) * i as f64 / steps as f64;
    // Halve a window's width and recenter it on `center`, clamped to `bounds`.
    let shrink = |(lo, hi): (f64, f64), center: f64, (min, max): (f64, f64)| {
        let quarter = (hi - lo) / 4.0;
        ((center - quarter).max(min), (center + quarter).min(max))
    };

    let mut la_window = LOG_A_RANGE;
    let mut b_window = B_RANGE;
    let mut best = (1.0f64, 1.0f64);
    let mut best_err = f64::INFINITY;
    for round in 0..ROUNDS {
        let steps = if round == 0 {
            FIRST_ROUND_STEPS
        } else {
            REFINE_STEPS
        };
        for i in 0..=steps {
            let a = 10f64.powf(grid(la_window.0, la_window.1, i, steps));
            for j in 0..=steps {
                let b = grid(b_window.0, b_window.1, j, steps);
                let err = sse(a, b);
                // Strict `<`: the first candidate reaching a given error wins.
                if err < best_err {
                    best_err = err;
                    best = (a, b);
                }
            }
        }
        la_window = shrink(la_window, best.0.log10(), LOG_A_RANGE);
        b_window = shrink(b_window, best.1, B_RANGE);
    }
    best
}
709
/// Orthogonal projection of `g` onto the tangent plane `T_x S² = { v : x·v = 0 }`.
///
/// Subtracts the component of `g` along `x`. The result is exactly tangent
/// only when `x` is unit length; fitted iterates are kept on the unit sphere
/// by normalization.
fn project_to_tangent(x: &[f64; 3], g: &[f64; 3]) -> [f64; 3] {
    let [x0, x1, x2] = *x;
    let [g0, g1, g2] = *g;
    // Component of g along the surface normal, which on S² is x itself.
    let normal_part = x0 * g0 + x1 * g1 + x2 * g2;
    [
        g0 - normal_part * x0,
        g1 - normal_part * x1,
        g2 - normal_part * x2,
    ]
}
719
/// In-place scaled accumulation on 3-vectors: `a += s * b`, component-wise.
fn add3_scaled(a: &mut [f64; 3], b: &[f64; 3], s: f64) {
    for (dst, &src) in a.iter_mut().zip(b.iter()) {
        *dst += s * src;
    }
}
725
726// ── Helpers ────────────────────────────────────────────────────────
727
/// Maximum rejection-sampling attempts when the supervised term looks for a
/// partner from a *different* category. With C >= 2 roughly balanced
/// categories, one uniform draw lands in the anchor's own category with
/// probability about 1/C <= 1/2. All eight draws therefore miss with
/// probability at most (1/2)^8 = 1/256, and only a corpus dominated by a
/// single category routinely exhausts the budget.
const MAX_CROSS_CATEGORY_DRAWS: usize = 8;

/// Corpus size at which kNN search switches from exact all-pairs comparison
/// to the approximate [`AnnIndex`]. Below it, brute force is both exact and
/// cheaper than building the index. At and above it, the O(N²) all-pairs
/// cost dominates.
const ANN_BRUTE_FORCE_THRESHOLD: usize = 2_000;

/// Iteration budget for the per-point sigma search in `smooth_knn_calibrate`.
/// The budget is shared by the doubling phase (while no finite upper bound
/// has been found yet) and the bisection phase. It is therefore an upper
/// bound on the number of halvings, not an exact count.
const SIGMA_SEARCH_ITERS: usize = 64;

/// The sigma search stops early once `|sum - log2(k)|` drops below this
/// value (the counterpart of umap-learn's SMOOTH_K_TOLERANCE).
const SMOOTH_K_TOLERANCE: f64 = 1.0e-5;

/// Lower bound on sigma as a fraction of the point's mean neighbor distance
/// (the counterpart of umap-learn's MIN_K_DIST_SCALE).
const MIN_K_DIST_SCALE: f64 = 1.0e-3;
/// Absolute lower bound on sigma. It takes over when the relative floor
/// vanishes, e.g. for a point whose neighbors are all exact duplicates
/// (mean distance 0).
const SIGMA_ABS_FLOOR: f64 = 1.0e-8;
752
/// Hash the exact bit patterns of a normalized vector.
///
/// This is FNV-1a applied to whole 64-bit words rather than bytes. Each
/// component's `to_bits()` is XORed into the state, which is then multiplied
/// (wrapping) by the 64-bit FNV prime. Bit-exact keys are intentional:
/// training embeddings re-projected through the pipeline go through the same
/// deterministic `Embedding::normalized`, so they reproduce identical bits.
/// As a consequence, `0.0` and `-0.0` (and distinct NaN payloads) hash
/// differently.
fn hash_normalized(v: &[f64]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    v.iter().fold(FNV_OFFSET_BASIS, |state, &component| {
        (state ^ component.to_bits()).wrapping_mul(FNV_PRIME)
    })
}
765
/// Convert a cosine similarity between L2-normalized vectors into the
/// distance `1 - sim`, floored at zero.
///
/// The floor absorbs floating-point overshoot past `sim = 1` for
/// (near-)duplicate vectors. A NaN similarity also maps to `0.0`.
fn cosine_distance(sim: f64) -> f64 {
    let raw = 1.0 - sim;
    if raw > 0.0 {
        raw
    } else {
        0.0
    }
}
771
772/// Split `(index, similarity)` rows into the adjacency list plus a
773/// parallel cosine-distance list for the fuzzy weight calibration.
774fn split_knn_rows(rows: Vec<Vec<(usize, f64)>>) -> (Vec<Vec<usize>>, Vec<Vec<f64>>) {
775    let mut knn = Vec::with_capacity(rows.len());
776    let mut dists = Vec::with_capacity(rows.len());
777    for row in rows {
778        knn.push(row.iter().map(|&(j, _)| j).collect());
779        dists.push(row.iter().map(|&(_, s)| cosine_distance(s)).collect());
780    }
781    (knn, dists)
782}
783
784#[allow(clippy::type_complexity)]
785fn build_knn_graph(
786    normalized: &[Vec<f64>],
787    k: usize,
788) -> (Vec<Vec<usize>>, Vec<Vec<f64>>, Option<Arc<AnnIndex>>) {
789    let n = normalized.len();
790    if n < ANN_BRUTE_FORCE_THRESHOLD {
791        let rows: Vec<Vec<(usize, f64)>> = (0..n)
792            .map(|i| {
793                let mut sims: Vec<(usize, f64)> = (0..n)
794                    .filter(|&j| j != i)
795                    .map(|j| (j, dot(&normalized[i], &normalized[j])))
796                    .collect();
797                sims.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
798                sims.truncate(k);
799                sims
800            })
801            .collect();
802        let (knn, dists) = split_knn_rows(rows);
803        return (knn, dists, None);
804    }
805
806    // AnnConfig defaults (n_trees=8, max_leaf_size=40) give >95% recall
807    // at N=500k for cosine kNN — the regime that drives this branch.
808    let index = Arc::new(AnnIndex::build_normalized(
809        normalized.to_vec(),
810        &AnnConfig::default(),
811    ));
812    let (knn, dists) = split_knn_rows(index.knn_graph_with_sims(k));
813    (knn, dists, Some(index))
814}
815
816/// Smooth-kNN calibration for one point's neighbor distances. Returns
817/// `(rho, sigma)`: `rho` is the nearest-neighbor distance
818/// (local_connectivity = 1, so each point's closest edge always gets
819/// weight 1.0 — the connectivity floor) and `sigma` is the bandwidth
820/// found by binary search so that
821/// `sum_j exp(-max(0, d_j - rho) / sigma) = log2(k)`.
822///
823/// On (near-)tie distances — duplicate-heavy corpora — the sum sticks at
824/// k for every sigma and the search collapses toward zero; the
825/// MIN_K_DIST_SCALE-style floor keeps sigma positive so the weights stay
826/// finite (they all come out ≈ 1.0, the right answer for duplicates).
827fn smooth_knn_calibrate(dists: &[f64]) -> (f64, f64) {
828    let rho = dists.iter().copied().fold(f64::INFINITY, f64::min);
829    let target = (dists.len() as f64).log2();
830    let mut lo = 0.0f64;
831    let mut hi = f64::INFINITY;
832    let mut sigma = 1.0f64;
833    for _ in 0..SIGMA_SEARCH_ITERS {
834        let sum: f64 = dists
835            .iter()
836            .map(|&d| (-((d - rho).max(0.0)) / sigma).exp())
837            .sum();
838        if (sum - target).abs() < SMOOTH_K_TOLERANCE {
839            break;
840        }
841        if sum > target {
842            hi = sigma;
843        } else {
844            lo = sigma;
845        }
846        sigma = if hi.is_finite() {
847            (lo + hi) / 2.0
848        } else {
849            sigma * 2.0
850        };
851    }
852    let mean = dists.iter().sum::<f64>() / dists.len() as f64;
853    (
854        rho,
855        sigma.max((MIN_K_DIST_SCALE * mean).max(SIGMA_ABS_FLOOR)),
856    )
857}
858
859/// Directed membership strengths per point, then fuzzy-union
860/// symmetrization `w = a + b - a·b` (reverse weight 0 when `j` does not
861/// list `i`). Each mutual pair appears in both adjacency rows carrying
862/// the same symmetrized weight; the optimizer halves those edges so a
863/// pair's total contribution is `w` regardless of mutuality.
864fn fuzzy_simplicial_weights(knn: &[Vec<usize>], dists: &[Vec<f64>]) -> Vec<Vec<f64>> {
865    let directed: Vec<Vec<f64>> = dists
866        .iter()
867        .map(|d| {
868            if d.is_empty() {
869                return Vec::new();
870            }
871            let (rho, sigma) = smooth_knn_calibrate(d);
872            d.iter()
873                .map(|&x| (-((x - rho).max(0.0)) / sigma).exp())
874                .collect()
875        })
876        .collect();
877
878    knn.iter()
879        .enumerate()
880        .map(|(i, neighbors)| {
881            neighbors
882                .iter()
883                .enumerate()
884                .map(|(idx, &j)| {
885                    let a = directed[i][idx];
886                    let b = knn[j]
887                        .iter()
888                        .position(|&x| x == i)
889                        .map_or(0.0, |p| directed[j][p]);
890                    // min() guards fp overshoot past 1 in the union.
891                    (a + b - a * b).min(1.0)
892                })
893                .collect()
894        })
895        .collect()
896}
897
898fn pca_warm_start(
899    embeddings: &[Embedding],
900    normalized: &[Vec<f64>],
901) -> Result<Vec<[f64; 3]>, ProjectionError> {
902    use crate::projection::PcaProjection;
903    use sphereql_core::spherical_to_cartesian;
904
905    let pca = PcaProjection::fit(embeddings, RadialStrategy::Fixed(1.0))?;
906    let mut out: Vec<[f64; 3]> = Vec::with_capacity(embeddings.len());
907    for (i, e) in embeddings.iter().enumerate() {
908        // `project_rich` exposes `projection_magnitude` — the raw 3D
909        // magnitude before the radial-strategy override. With
910        // `Fixed(1.0)` the SphericalPoint always has r=1, so checking
911        // the *spherical* point's cartesian magnitude is meaningless
912        // (it's always 1). The pre-radial magnitude is the real signal
913        // for "input near corpus mean → degenerate placement."
914        let pp = pca.project_rich(e);
915        if pp.projection_magnitude > f64::EPSILON {
916            let cart = spherical_to_cartesian(&pp.position);
917            out.push([cart.x, cart.y, cart.z]);
918            continue;
919        }
920        // Degenerate PCA position (input near corpus mean). Fall back
921        // to the first three normalized coords as a direction —
922        // stable, deterministic, and independent of any noise that
923        // pushed the PCA coordinate to zero.
924        let row = &normalized[i];
925        let mut v = [row[0], row[1], row[2]];
926        normalize_vec(&mut v);
927        if v[0] == 0.0 && v[1] == 0.0 && v[2] == 0.0 {
928            v = [1.0, 0.0, 0.0];
929        }
930        out.push(v);
931    }
932    Ok(out)
933}
934
935/// Trustworthiness-style kNN recall: for each point, compute its kNN
936/// set among the fitted 3D positions and intersect it with the
937/// original-space kNN set from the graph; the score is the mean overlap
938/// fraction. Random placements score near k/n; 1.0 means every original
939/// neighborhood survives the projection.
940///
941/// The fitted points are unit vectors, so cosine order equals angular
942/// order — the same kNN machinery used in the original space applies.
943/// Above `ANN_BRUTE_FORCE_THRESHOLD` a fresh ANN index is built over
944/// the 3D positions (deterministic via the default seed); this is
945/// distinct from the retained high-dimensional index in `UmapGraph`.
946fn knn_recall_score(points: &[[f64; 3]], knn: &[Vec<usize>]) -> f64 {
947    let n = points.len();
948    if n < 2 {
949        return 1.0;
950    }
951
952    let ann = (n >= ANN_BRUTE_FORCE_THRESHOLD).then(|| {
953        let coords: Vec<Vec<f64>> = points.iter().map(|p| p.to_vec()).collect();
954        AnnIndex::build_normalized(coords, &AnnConfig::default())
955    });
956
957    let mut total = 0.0;
958    let mut counted = 0usize;
959    for (i, original) in knn.iter().enumerate() {
960        let k = original.len();
961        if k == 0 {
962            continue;
963        }
964        let spherical: Vec<usize> = match &ann {
965            Some(index) => index
966                .query_by_index(i, k)
967                .into_iter()
968                .map(|(j, _)| j)
969                .collect(),
970            None => {
971                let mut sims: Vec<(usize, f64)> = (0..n)
972                    .filter(|&j| j != i)
973                    .map(|j| (j, dot(&points[i], &points[j])))
974                    .collect();
975                sims.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
976                sims.into_iter().take(k).map(|(j, _)| j).collect()
977            }
978        };
979        let hits = spherical.iter().filter(|j| original.contains(j)).count();
980        total += hits as f64 / k as f64;
981        counted += 1;
982    }
983    if counted == 0 {
984        1.0
985    } else {
986        total / counted as f64
987    }
988}
989
990#[cfg(test)]
991mod tests {
992    use super::*;
993    use sphereql_core::angular_distance;
994
995    fn emb(vals: &[f64]) -> Embedding {
996        Embedding::new(vals.to_vec())
997    }
998
999    fn cluster_corpus() -> Vec<Embedding> {
1000        // Two clear clusters in 6D so neighbor preservation is testable.
1001        let mut out = Vec::new();
1002        for i in 0..8 {
1003            let t = i as f64 * 0.01;
1004            out.push(emb(&[1.0 + t, 0.5 + t, 0.0, 0.0, 0.0, 0.0]));
1005        }
1006        for i in 0..8 {
1007            let t = i as f64 * 0.01;
1008            out.push(emb(&[0.0, 0.0, 0.0, 1.0 + t, 0.5 + t, 0.0]));
1009        }
1010        out
1011    }
1012
1013    #[test]
1014    fn fit_default_runs_and_produces_valid_points() {
1015        let corpus = cluster_corpus();
1016        let proj = UmapSphereProjection::fit_default(&corpus).unwrap();
1017        for e in &corpus {
1018            let sp = proj.project(e);
1019            assert!(sp.r >= 0.0);
1020            assert!(sp.theta.is_finite());
1021            assert!(sp.phi.is_finite());
1022        }
1023    }
1024
1025    #[test]
1026    fn quality_score_in_unit_interval() {
1027        let corpus = cluster_corpus();
1028        let proj = UmapSphereProjection::fit_default(&corpus).unwrap();
1029        let q = proj.explained_variance_ratio();
1030        assert!((0.0..=1.0).contains(&q), "got {q}");
1031    }
1032
1033    #[test]
1034    fn well_separated_clusters_score_high_recall() {
1035        let corpus = cluster_corpus();
1036        let proj = UmapSphereProjection::fit(
1037            &corpus,
1038            None,
1039            RadialStrategy::Fixed(1.0),
1040            UmapConfig {
1041                n_neighbors: 5,
1042                ..UmapConfig::default()
1043            },
1044        )
1045        .unwrap();
1046        let q = proj.explained_variance_ratio();
1047        assert!(
1048            q > 0.5,
1049            "expected high recall for separated clusters, got {q}"
1050        );
1051    }
1052
1053    #[test]
1054    fn shuffled_positions_score_lower_recall() {
1055        let corpus = cluster_corpus();
1056        let config = UmapConfig {
1057            n_neighbors: 5,
1058            ..UmapConfig::default()
1059        };
1060        let graph = UmapGraph::build(&corpus, config.n_neighbors).unwrap();
1061        let proj =
1062            UmapSphereProjection::fit_from_graph(&graph, None, RadialStrategy::Fixed(1.0), config)
1063                .unwrap();
1064        let fitted = proj.explained_variance_ratio();
1065
1066        // Permuting the fitted positions breaks every neighborhood the
1067        // optimizer built, so the same scorer must rank them below the
1068        // real layout.
1069        let mut shuffled = proj.fitted_points.clone();
1070        let mut rng = SplitMix64::new(0xD15C);
1071        for i in (1..shuffled.len()).rev() {
1072            let j = (rng.next_u64() as usize) % (i + 1);
1073            shuffled.swap(i, j);
1074        }
1075        let broken = knn_recall_score(&shuffled, &graph.knn);
1076        assert!(broken < fitted, "shuffled={broken}, fitted={fitted}");
1077    }
1078
1079    #[test]
1080    fn empty_corpus_errors() {
1081        assert!(matches!(
1082            UmapSphereProjection::fit_default(&[]),
1083            Err(ProjectionError::EmptyCorpus)
1084        ));
1085    }
1086
1087    #[test]
1088    fn dimension_too_low_errors() {
1089        let bad = vec![emb(&[1.0, 2.0]); 8];
1090        assert!(matches!(
1091            UmapSphereProjection::fit_default(&bad),
1092            Err(ProjectionError::DimensionTooLow { .. })
1093        ));
1094    }
1095
1096    #[test]
1097    fn too_few_embeddings_errors() {
1098        let small = vec![emb(&[1.0, 2.0, 3.0, 4.0]); 3];
1099        assert!(matches!(
1100            UmapSphereProjection::fit_default(&small),
1101            Err(ProjectionError::TooFewEmbeddings {
1102                got: 3,
1103                required: 4
1104            })
1105        ));
1106    }
1107
1108    #[test]
1109    fn ann_backed_knn_routes_to_correct_cluster() {
1110        // The 2000-item threshold in build_knn_graph routes this corpus
1111        // through brute force; here we directly exercise the ANN module
1112        // to confirm its k-NN graph respects cluster structure.
1113        //
1114        // cluster_corpus() builds two orthogonal clusters of 8 items.
1115        // The leaf size must be >= cluster size so every query routes
1116        // into a leaf containing all of its true neighbors — at this
1117        // tiny N, smaller leaves cause sub-leaf fragmentation that
1118        // would surface as recall holes only at this scale.
1119        use crate::ann::{AnnConfig, AnnIndex};
1120
1121        let corpus = cluster_corpus();
1122        let normalized: Vec<Vec<f64>> = corpus.iter().map(|e| e.normalized()).collect();
1123
1124        let config = AnnConfig {
1125            n_trees: 8,
1126            max_leaf_size: 8,
1127            seed: 42,
1128        };
1129        let index = AnnIndex::build_normalized(normalized.clone(), &config);
1130        let ann: Vec<Vec<usize>> = index.knn_graph(5);
1131
1132        for (i, neighbors) in ann.iter().enumerate() {
1133            let own_cluster = if i < 8 { 0..8 } else { 8..16 };
1134            for &n in neighbors {
1135                assert!(
1136                    own_cluster.contains(&n),
1137                    "item {i} got neighbor {n} from the wrong cluster"
1138                );
1139            }
1140        }
1141    }
1142
1143    #[test]
1144    fn category_term_pulls_same_class_together() {
1145        let corpus = cluster_corpus();
1146        let cats: Vec<u32> = (0..corpus.len())
1147            .map(|i| if i < 8 { 0 } else { 1 })
1148            .collect();
1149
1150        let unsupervised = UmapSphereProjection::fit(
1151            &corpus,
1152            None,
1153            RadialStrategy::Fixed(1.0),
1154            UmapConfig {
1155                n_epochs: 100,
1156                category_weight: 0.0,
1157                ..UmapConfig::default()
1158            },
1159        )
1160        .unwrap();
1161
1162        let supervised = UmapSphereProjection::fit(
1163            &corpus,
1164            Some(&cats),
1165            RadialStrategy::Fixed(1.0),
1166            UmapConfig {
1167                n_epochs: 100,
1168                category_weight: 2.0,
1169                ..UmapConfig::default()
1170            },
1171        )
1172        .unwrap();
1173
1174        // Mean within-class distance, supervised vs unsupervised.
1175        let within_unsup = mean_within_class(&unsupervised.fitted_points, &cats);
1176        let within_sup = mean_within_class(&supervised.fitted_points, &cats);
1177        assert!(
1178            within_sup <= within_unsup + 1e-6,
1179            "supervised within-class={within_sup}, unsupervised={within_unsup}"
1180        );
1181    }
1182
1183    #[test]
1184    fn category_term_tightens_classes_at_many_categories() {
1185        // 8 categories x 4 points, each category in its own basis
1186        // direction of an 8D space. At C=8 a uniform partner draw is
1187        // same-category only ~3/31 of the time, so the old sampling
1188        // was almost pure repulsion here — the cohesion half of the
1189        // term never fired. Stratified sampling must beat the
1190        // unsupervised baseline on within-class spread.
1191        let mut corpus = Vec::new();
1192        let mut cats: Vec<u32> = Vec::new();
1193        for c in 0..8u32 {
1194            for i in 0..4 {
1195                let mut v = vec![0.0; 8];
1196                v[c as usize] = 1.0 + i as f64 * 0.05;
1197                v[(c as usize + 1) % 8] = 0.1 + i as f64 * 0.02;
1198                corpus.push(emb(&v));
1199                cats.push(c);
1200            }
1201        }
1202
1203        let config = |category_weight: f64| UmapConfig {
1204            n_neighbors: 3,
1205            n_epochs: 100,
1206            category_weight,
1207            ..UmapConfig::default()
1208        };
1209
1210        let unsupervised =
1211            UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config(0.0))
1212                .unwrap();
1213        let supervised = UmapSphereProjection::fit(
1214            &corpus,
1215            Some(&cats),
1216            RadialStrategy::Fixed(1.0),
1217            config(2.0),
1218        )
1219        .unwrap();
1220
1221        let within_unsup = mean_within_class(&unsupervised.fitted_points, &cats);
1222        let within_sup = mean_within_class(&supervised.fitted_points, &cats);
1223        assert!(
1224            within_sup < within_unsup,
1225            "supervised within-class={within_sup}, unsupervised={within_unsup}"
1226        );
1227    }
1228
1229    fn mean_within_class(points: &[[f64; 3]], cats: &[u32]) -> f64 {
1230        let mut total = 0.0;
1231        let mut count = 0;
1232        for i in 0..points.len() {
1233            for j in (i + 1)..points.len() {
1234                if cats[i] == cats[j] {
1235                    let pi = SphericalPoint::new_unchecked(
1236                        1.0,
1237                        points[i][1]
1238                            .atan2(points[i][0])
1239                            .rem_euclid(std::f64::consts::TAU),
1240                        points[i][2].clamp(-1.0, 1.0).acos(),
1241                    );
1242                    let pj = SphericalPoint::new_unchecked(
1243                        1.0,
1244                        points[j][1]
1245                            .atan2(points[j][0])
1246                            .rem_euclid(std::f64::consts::TAU),
1247                        points[j][2].clamp(-1.0, 1.0).acos(),
1248                    );
1249                    total += angular_distance(&pi, &pj);
1250                    count += 1;
1251                }
1252            }
1253        }
1254        if count == 0 {
1255            0.0
1256        } else {
1257            total / count as f64
1258        }
1259    }
1260
1261    #[test]
1262    fn sigma_calibration_hits_log2_k() {
1263        let dists = [0.1, 0.2, 0.3, 0.4, 0.5];
1264        let (rho, sigma) = smooth_knn_calibrate(&dists);
1265        assert_eq!(rho, 0.1);
1266        let sum: f64 = dists
1267            .iter()
1268            .map(|&d| (-((d - rho).max(0.0)) / sigma).exp())
1269            .sum();
1270        let target = 5.0f64.log2();
1271        assert!((sum - target).abs() < 1e-4, "sum={sum}, target={target}");
1272    }
1273
1274    #[test]
1275    fn nearest_neighbor_weight_is_one() {
1276        let corpus = cluster_corpus();
1277        let graph = UmapGraph::build(&corpus, 5).unwrap();
1278        for (i, neighbors) in graph.knn.iter().enumerate() {
1279            assert!(!neighbors.is_empty());
1280            // Rows are sorted by descending similarity, so index 0 is
1281            // the nearest neighbor — its distance equals rho, and the
1282            // fuzzy union preserves a directed weight of 1.
1283            let w = graph.weights[i][0];
1284            assert!((w - 1.0).abs() < 1e-9, "point {i}: nearest weight {w}");
1285        }
1286    }
1287
1288    #[test]
1289    fn duplicate_heavy_corpus_fits_with_finite_weights() {
1290        // Six exact copies of the first cluster point on top of the
1291        // normal corpus — rho = 0 and all-tie neighbor distances push
1292        // the sigma search to its floor.
1293        let mut corpus = vec![emb(&[1.0, 0.5, 0.0, 0.0, 0.0, 0.0]); 6];
1294        corpus.extend(cluster_corpus());
1295
1296        let graph = UmapGraph::build(&corpus, 5).unwrap();
1297        for row in &graph.weights {
1298            for &w in row {
1299                assert!(w.is_finite() && (0.0..=1.0).contains(&w), "weight {w}");
1300            }
1301        }
1302
1303        let proj = UmapSphereProjection::fit_from_graph(
1304            &graph,
1305            None,
1306            RadialStrategy::Fixed(1.0),
1307            UmapConfig {
1308                n_neighbors: 5,
1309                n_epochs: 30,
1310                ..UmapConfig::default()
1311            },
1312        )
1313        .unwrap();
1314        for p in &proj.fitted_points {
1315            assert!(p.iter().all(|c| c.is_finite()));
1316        }
1317        assert!(proj.explained_variance_ratio().is_finite());
1318    }
1319
1320    #[test]
1321    fn dense_cluster_edges_outweigh_diffuse_cluster_edges() {
1322        // A duplicate-tight cluster is the density extreme: every
1323        // neighbor sits at rho, so calibration leaves every intra edge
1324        // at weight ~1. The diffuse cluster's spread distances calibrate
1325        // sigma instead, so its directed weights sum to log2(k) and the
1326        // non-nearest edges decay below 1 — the local scaling an
1327        // unweighted graph lacked.
1328        let mut corpus = vec![emb(&[1.0, 0.2, 0.0, 0.0, 0.0, 0.0]); 8];
1329        for i in 0..8 {
1330            let mut v = vec![0.0; 6];
1331            v[3] = 1.0;
1332            v[4] = 0.15 * i as f64;
1333            corpus.push(emb(&v));
1334        }
1335        let graph = UmapGraph::build(&corpus, 5).unwrap();
1336
1337        let mean_intra = |range: std::ops::Range<usize>| {
1338            let mut total = 0.0;
1339            let mut count = 0usize;
1340            for i in range.clone() {
1341                for (idx, &j) in graph.knn[i].iter().enumerate() {
1342                    if range.contains(&j) {
1343                        total += graph.weights[i][idx];
1344                        count += 1;
1345                    }
1346                }
1347            }
1348            assert!(count > 0, "no intra-cluster edges in {range:?}");
1349            total / count as f64
1350        };
1351
1352        let tight = mean_intra(0..8);
1353        let diffuse = mean_intra(8..16);
1354        assert!(
1355            tight >= diffuse,
1356            "tight mean weight {tight} < diffuse mean weight {diffuse}"
1357        );
1358        assert!(
1359            tight > 0.99,
1360            "duplicate-cluster edges should be ~1, got {tight}"
1361        );
1362        assert!(diffuse < 0.95, "diffuse edges should decay, got {diffuse}");
1363    }
1364
1365    #[test]
1366    fn dimensionality_reports_input_dim() {
1367        let corpus = cluster_corpus();
1368        let proj = UmapSphereProjection::fit_default(&corpus).unwrap();
1369        assert_eq!(proj.dimensionality(), 6);
1370    }
1371
1372    #[test]
1373    fn fit_from_graph_matches_full_fit() {
1374        let corpus = cluster_corpus();
1375        let config = UmapConfig {
1376            n_epochs: 50,
1377            category_weight: 0.0,
1378            seed: 42,
1379            ..UmapConfig::default()
1380        };
1381
1382        let full =
1383            UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config.clone())
1384                .unwrap();
1385
1386        let graph = UmapGraph::build(&corpus, config.n_neighbors).unwrap();
1387        let split =
1388            UmapSphereProjection::fit_from_graph(&graph, None, RadialStrategy::Fixed(1.0), config)
1389                .unwrap();
1390
1391        assert!(
1392            (full.explained_variance_ratio() - split.explained_variance_ratio()).abs() < 1e-6,
1393            "full={}, split={}",
1394            full.explained_variance_ratio(),
1395            split.explained_variance_ratio()
1396        );
1397    }
1398
1399    #[test]
1400    fn graph_reusable_across_configs() {
1401        let corpus = cluster_corpus();
1402        let graph = UmapGraph::build(&corpus, 5).unwrap();
1403
1404        let config1 = UmapConfig {
1405            n_epochs: 30,
1406            category_weight: 0.0,
1407            seed: 1,
1408            ..UmapConfig::default()
1409        };
1410        let config2 = UmapConfig {
1411            n_epochs: 60,
1412            category_weight: 1.0,
1413            seed: 2,
1414            ..UmapConfig::default()
1415        };
1416
1417        let p1 =
1418            UmapSphereProjection::fit_from_graph(&graph, None, RadialStrategy::Fixed(1.0), config1)
1419                .unwrap();
1420        let p2 =
1421            UmapSphereProjection::fit_from_graph(&graph, None, RadialStrategy::Fixed(1.0), config2)
1422                .unwrap();
1423
1424        assert!((0.0..=1.0).contains(&p1.explained_variance_ratio()));
1425        assert!((0.0..=1.0).contains(&p2.explained_variance_ratio()));
1426    }
1427
1428    fn assert_projects_to_fitted(proj: &UmapSphereProjection, e: &Embedding, idx: usize) {
1429        use sphereql_core::spherical_to_cartesian;
1430
1431        let pp = proj.project_rich(e);
1432        assert_eq!(pp.certainty, 1.0, "exact match must report certainty 1.0");
1433        let cart = spherical_to_cartesian(&pp.position);
1434        let expected = proj.fitted_points[idx];
1435        assert!((cart.x - expected[0]).abs() < 1e-12);
1436        assert!((cart.y - expected[1]).abs() < 1e-12);
1437        assert!((cart.z - expected[2]).abs() < 1e-12);
1438    }
1439
1440    #[test]
1441    fn projecting_training_embedding_returns_exact_fitted_position() {
1442        let corpus = cluster_corpus();
1443        let proj = UmapSphereProjection::fit(
1444            &corpus,
1445            None,
1446            RadialStrategy::Fixed(1.0),
1447            UmapConfig::default(),
1448        )
1449        .unwrap();
1450
1451        // Default n_neighbors=15 on a 16-point corpus means the old
1452        // interpolation path averaged nearly every fitted point, so a
1453        // smeared result would be visibly off the per-point positions
1454        // checked here.
1455        for (i, e) in corpus.iter().enumerate() {
1456            assert_projects_to_fitted(&proj, e, i);
1457        }
1458    }
1459
1460    #[test]
1461    fn duplicate_training_embedding_maps_to_first_fitted_index() {
1462        let mut corpus = cluster_corpus();
1463        corpus.push(corpus[0].clone());
1464        let proj = UmapSphereProjection::fit(
1465            &corpus,
1466            None,
1467            RadialStrategy::Fixed(1.0),
1468            UmapConfig::default(),
1469        )
1470        .unwrap();
1471
1472        assert_projects_to_fitted(&proj, &corpus[0], 0);
1473        assert_projects_to_fitted(&proj, &corpus[16], 0);
1474
1475        let a = proj.project_rich(&corpus[0]);
1476        let b = proj.project_rich(&corpus[16]);
1477        assert_eq!(a.position.theta, b.position.theta);
1478        assert_eq!(a.position.phi, b.position.phi);
1479        assert_eq!(a.position.r, b.position.r);
1480    }
1481
1482    #[test]
1483    fn unseen_embedding_interpolates_on_sphere() {
1484        let corpus = cluster_corpus();
1485        let proj = UmapSphereProjection::fit(
1486            &corpus,
1487            None,
1488            RadialStrategy::Fixed(1.0),
1489            UmapConfig::default(),
1490        )
1491        .unwrap();
1492
1493        let unseen = emb(&[1.0, 0.55, 0.02, 0.0, 0.0, 0.0]);
1494        let pp = proj.project_rich(&unseen);
1495        assert!(pp.certainty > 0.0 && pp.certainty < 1.0);
1496        assert!(pp.position.theta.is_finite());
1497        assert!(pp.position.phi.is_finite());
1498        assert!((pp.position.r - 1.0).abs() < 1e-12);
1499        assert!(pp.projection_magnitude > 0.0 && pp.projection_magnitude <= 1.0 + 1e-12);
1500    }
1501
1502    #[test]
1503    fn ann_backed_transform_above_threshold() {
1504        let mut rng = SplitMix64::new(0x5EED);
1505        let mut random_emb = |dim: usize| {
1506            let vals: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
1507            emb(&vals)
1508        };
1509        let corpus: Vec<Embedding> = (0..ANN_BRUTE_FORCE_THRESHOLD)
1510            .map(|_| random_emb(8))
1511            .collect();
1512        let config = UmapConfig {
1513            n_neighbors: 5,
1514            n_epochs: 2,
1515            negative_sample_rate: 1,
1516            ..UmapConfig::default()
1517        };
1518        let proj =
1519            UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config).unwrap();
1520        assert!(proj.ann.is_some(), "expected ANN index above threshold");
1521
1522        assert_projects_to_fitted(&proj, &corpus[1234], 1234);
1523
1524        let unseen = random_emb(8);
1525        let pp = proj.project_rich(&unseen);
1526        assert!(pp.certainty > 0.0 && pp.certainty <= 1.0);
1527        assert!(pp.position.theta.is_finite());
1528        assert!(pp.position.phi.is_finite());
1529        assert!(pp.position.r.is_finite());
1530    }
1531
1532    #[test]
1533    fn gradients_at_unit_ab_match_old_hardcoded_forms() {
1534        // The generalized kernel must reduce exactly to the pre-min_dist
1535        // forms at a = b = 1: attractive 2(xi-xj)/(1+d²), repulsive
1536        // -2(xi-xj)/(d²(1+d²)) with the same 1e-6 floor.
1537        let pairs: [([f64; 3], [f64; 3]); 4] = [
1538            ([1.0, 0.0, 0.0], [0.6, 0.8, 0.0]),
1539            ([0.0, 1.0, 0.0], [0.0, 0.0, 1.0]),
1540            ([1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]),
1541            ([0.36, 0.48, 0.8], [0.48, 0.36, 0.8]),
1542        ];
1543        for (xi, xj) in pairs {
1544            let dx = [xi[0] - xj[0], xi[1] - xj[1], xi[2] - xj[2]];
1545            let d2 = dx[0] * dx[0] + dx[1] * dx[1] + dx[2] * dx[2];
1546
1547            let (gi, gj) = attractive_grad(&xi, &xj, 1.0, 1.0);
1548            let coef = 2.0 / (1.0 + d2);
1549            for d in 0..3 {
1550                assert!((gi[d] - coef * dx[d]).abs() < 1e-12, "attractive gi[{d}]");
1551                assert!((gj[d] + coef * dx[d]).abs() < 1e-12, "attractive gj[{d}]");
1552            }
1553
1554            let (ri, rj) = repulsive_grad(&xi, &xj, 1.0, 1.0);
1555            let d2f = d2.max(1e-6);
1556            let rcoef = -2.0 / (d2f * (1.0 + d2f));
1557            for d in 0..3 {
1558                assert!((ri[d] - rcoef * dx[d]).abs() < 1e-12, "repulsive ri[{d}]");
1559                assert!((rj[d] + rcoef * dx[d]).abs() < 1e-12, "repulsive rj[{d}]");
1560            }
1561        }
1562    }
1563
1564    #[test]
1565    fn ab_fit_matches_canonical_anchor_at_default_min_dist() {
1566        // umap-learn's scipy curve_fit gives (1.577, 0.895) for
1567        // min_dist = 0.1, spread = 1.0. The grid fit should land in the
1568        // same neighborhood.
1569        let (a, b) = find_ab_params(0.1);
1570        assert!((1.3..=1.9).contains(&a), "a = {a}");
1571        assert!((0.78..=1.0).contains(&b), "b = {b}");
1572    }
1573
1574    #[test]
1575    fn larger_min_dist_flattens_kernel_near_origin() {
1576        // A bigger min_dist must fit a kernel that holds Phi higher
1577        // (weaker attraction decay) at short range — that flatter
1578        // plateau is what buys territory on the sphere.
1579        let (a0, b0) = find_ab_params(0.0);
1580        let (a5, b5) = find_ab_params(0.5);
1581        let phi = |a: f64, b: f64, d: f64| 1.0 / (1.0 + a * d.powf(2.0 * b));
1582        for d in [0.1, 0.25, 0.5, 0.75, 1.0] {
1583            assert!(
1584                phi(a5, b5, d) > phi(a0, b0, d),
1585                "Phi at d={d}: min_dist=0.5 gives {}, min_dist=0.0 gives {}",
1586                phi(a5, b5, d),
1587                phi(a0, b0, d)
1588            );
1589        }
1590    }
1591
1592    #[test]
1593    fn larger_min_dist_spreads_fitted_points() {
1594        // Same corpus, same seed: the only difference is the kernel, so
1595        // intra-cluster packing must loosen with min_dist. Fits are
1596        // deterministic, so this is a fixed outcome, not a flaky one.
1597        let corpus = cluster_corpus();
1598        let cats: Vec<u32> = (0..corpus.len())
1599            .map(|i| if i < 8 { 0 } else { 1 })
1600            .collect();
1601        let fit_at = |min_dist: f64| {
1602            UmapSphereProjection::fit(
1603                &corpus,
1604                None,
1605                RadialStrategy::Fixed(1.0),
1606                UmapConfig {
1607                    n_neighbors: 5,
1608                    min_dist,
1609                    ..UmapConfig::default()
1610                },
1611            )
1612            .unwrap()
1613        };
1614        let tight = mean_within_class(&fit_at(0.0).fitted_points, &cats);
1615        let spread = mean_within_class(&fit_at(0.5).fitted_points, &cats);
1616        assert!(
1617            spread > tight,
1618            "min_dist=0.5 within-class spread {spread} should exceed min_dist=0.0's {tight}"
1619        );
1620    }
1621
1622    fn mean_warm_start_displacement(graph: &UmapGraph, points: &[[f64; 3]]) -> f64 {
1623        points
1624            .iter()
1625            .zip(&graph.warm_start)
1626            .map(|(p, w)| {
1627                let cos = (p[0] * w[0] + p[1] * w[1] + p[2] * w[2]).clamp(-1.0, 1.0);
1628                cos.acos()
1629            })
1630            .sum::<f64>()
1631            / points.len() as f64
1632    }
1633
1634    #[test]
1635    fn zero_anchor_is_bit_identical_and_rng_neutral() {
1636        // warm_start_anchor = 0.0 must be a strict no-op: the anchor
1637        // block consumes no RNG and adds no gradient, so two split
1638        // fits — and the split fit vs the full fit — agree to the bit,
1639        // exactly as before the knob existed.
1640        let corpus = cluster_corpus();
1641        let config = UmapConfig {
1642            n_neighbors: 5,
1643            n_epochs: 30,
1644            warm_start_anchor: 0.0,
1645            seed: 42,
1646            ..UmapConfig::default()
1647        };
1648
1649        let graph = UmapGraph::build(&corpus, config.n_neighbors).unwrap();
1650        let a = UmapSphereProjection::fit_from_graph(
1651            &graph,
1652            None,
1653            RadialStrategy::Fixed(1.0),
1654            config.clone(),
1655        )
1656        .unwrap();
1657        let b = UmapSphereProjection::fit_from_graph(
1658            &graph,
1659            None,
1660            RadialStrategy::Fixed(1.0),
1661            config.clone(),
1662        )
1663        .unwrap();
1664        let full =
1665            UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config).unwrap();
1666
1667        assert_eq!(a.fitted_points, b.fitted_points);
1668        assert_eq!(a.fitted_points, full.fitted_points);
1669    }
1670
1671    #[test]
1672    fn warm_start_anchor_limits_component_drift() {
1673        // cluster_corpus()'s two clusters are orthogonal, so at
1674        // n_neighbors = 5 every kNN edge stays inside its own cluster
1675        // and the graph splits into two disconnected components — the
1676        // regime the anchor exists for. With zero cross-component
1677        // attraction, repulsion alone drags the layout away from the
1678        // warm-start arrangement; the anchored fit must end closer to
1679        // it. Both fits share one RNG stream (the anchor draws
1680        // nothing), so this is a paired comparison, not a flaky one.
1681        let corpus = cluster_corpus();
1682        let graph = UmapGraph::build(&corpus, 5).unwrap();
1683        for (i, neighbors) in graph.knn.iter().enumerate() {
1684            let own_cluster = if i < 8 { 0..8 } else { 8..16 };
1685            for &j in neighbors {
1686                assert!(
1687                    own_cluster.contains(&j),
1688                    "expected disconnected components, but {i} links to {j}"
1689                );
1690            }
1691        }
1692
1693        let fit_at = |anchor: f64| {
1694            UmapSphereProjection::fit_from_graph(
1695                &graph,
1696                None,
1697                RadialStrategy::Fixed(1.0),
1698                UmapConfig {
1699                    n_neighbors: 5,
1700                    n_epochs: 100,
1701                    warm_start_anchor: anchor,
1702                    seed: 42,
1703                    ..UmapConfig::default()
1704                },
1705            )
1706            .unwrap()
1707        };
1708
1709        let free = mean_warm_start_displacement(&graph, &fit_at(0.0).fitted_points);
1710        let anchored = mean_warm_start_displacement(&graph, &fit_at(0.05).fitted_points);
1711        assert!(
1712            anchored < free,
1713            "anchored displacement {anchored} should be below unanchored {free}"
1714        );
1715    }
1716
1717    #[test]
1718    fn anchored_fit_stays_on_sphere_with_valid_quality() {
1719        let corpus = cluster_corpus();
1720        let proj = UmapSphereProjection::fit(
1721            &corpus,
1722            None,
1723            RadialStrategy::Fixed(1.0),
1724            UmapConfig {
1725                n_neighbors: 5,
1726                n_epochs: 50,
1727                warm_start_anchor: 0.05,
1728                ..UmapConfig::default()
1729            },
1730        )
1731        .unwrap();
1732        for p in &proj.fitted_points {
1733            assert!(p.iter().all(|c| c.is_finite()));
1734            let mag = (p[0] * p[0] + p[1] * p[1] + p[2] * p[2]).sqrt();
1735            assert!((mag - 1.0).abs() < 1e-9, "off-sphere magnitude {mag}");
1736        }
1737        let q = proj.explained_variance_ratio();
1738        assert!((0.0..=1.0).contains(&q), "quality {q}");
1739    }
1740}