Skip to main content

sphereql_embed/
projection.rs

1use std::f64::consts::PI;
2use std::sync::Arc;
3
4use sphereql_core::{CartesianPoint, SphericalPoint, cartesian_to_spherical};
5
6use crate::types::{Embedding, ProjectedPoint, RadialStrategy};
7
8/// Reasons a projection fit can fail.
9///
10/// Every concrete projection's `fit` used to panic via `assert!` on
11/// invalid input, which turned typos in Python / WASM bindings into
12/// `PanicException`s. These variants classify the same preconditions
13/// so callers can surface typed errors instead.
14#[derive(Debug, Clone, PartialEq, thiserror::Error)]
15pub enum ProjectionError {
16    /// The input slice was empty. Fitting needs at least one embedding.
17    #[error("need at least one embedding to fit a projection")]
18    EmptyCorpus,
19
20    /// Embedding dimensionality below the projection's requirement.
21    /// PCA and kernel PCA need `dim >= 3`; Laplacian requires `dim > 0`.
22    #[error("embedding dimension {got} is below the minimum {required} for this projection")]
23    DimensionTooLow { got: usize, required: usize },
24
25    /// Embeddings disagreed on dimensionality. Every row must match the
26    /// first one; the mismatch is reported with the offending index.
27    #[error("embedding {index} has dimension {got}, expected {expected}")]
28    InconsistentDimension {
29        index: usize,
30        expected: usize,
31        got: usize,
32    },
33
34    /// Projection needs more embeddings than were provided. Laplacian
35    /// eigenmap's graph construction requires `n >= 4`.
36    #[error("need at least {required} embeddings, got {got}")]
37    TooFewEmbeddings { got: usize, required: usize },
38
39    /// `fit_with_sigma` was given a non-positive Gaussian bandwidth.
40    #[error("kernel bandwidth σ must be positive, got {got}")]
41    InvalidSigma { got: f64 },
42}
43
44/// Maps high-dimensional embeddings to spherical coordinates.
45///
46/// The angular coordinates (theta, phi) encode semantic direction via
47/// dimensionality reduction from S^{n-1} to S^2. The radial coordinate
48/// is controlled by the projection's [`RadialStrategy`].
49pub trait Projection: Send + Sync {
50    fn project(&self, embedding: &Embedding) -> SphericalPoint;
51
52    /// Project with rich metadata: certainty, intensity, projection magnitude.
53    fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
54        let position = self.project(embedding);
55        ProjectedPoint::from_position(position, embedding.magnitude())
56    }
57
58    fn dimensionality(&self) -> usize;
59}
60
61impl<P: Projection> Projection for Arc<P> {
62    fn project(&self, embedding: &Embedding) -> SphericalPoint {
63        (**self).project(embedding)
64    }
65    fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
66        (**self).project_rich(embedding)
67    }
68    fn dimensionality(&self) -> usize {
69        (**self).dimensionality()
70    }
71}
72
73/// Corpus-fitted projection via spherical PCA.
74///
75/// Finds the 3 principal directions of maximum angular variance in the
76/// embedding space, then projects new embeddings onto them. This preserves
77/// angular (cosine similarity) relationships as faithfully as possible
78/// in 3 dimensions.
79///
80/// Fitting: O(N·n·k·iters) where N=corpus size, n=dimension, k=3.
81/// Projection: O(n) per embedding.
82#[derive(Clone)]
83pub struct PcaProjection {
84    components: [Vec<f64>; 3],
85    mean: Vec<f64>,
86    dim: usize,
87    radial: RadialStrategy,
88    volumetric: bool,
89    /// Top-3 eigenvalues from PCA (descending). Used to compute per-point certainty.
90    eigenvalues: [f64; 3],
91    /// Total variance across all dimensions. eigenvalues[0..3].sum() / total_variance
92    /// gives the global explained variance ratio.
93    total_variance: f64,
94}
95
96impl PcaProjection {
97    /// Fit the top-3 principal components on `embeddings`.
98    ///
99    /// Returns [`ProjectionError::EmptyCorpus`] if the slice is empty,
100    /// [`ProjectionError::DimensionTooLow`] if `dim < 3`, and
101    /// [`ProjectionError::InconsistentDimension`] if any row's
102    /// dimensionality disagrees with the first. Previously these paths
103    /// panicked via `assert!`, which surfaced as a `PanicException` in
104    /// Python / WASM bindings.
105    pub fn fit(embeddings: &[Embedding], radial: RadialStrategy) -> Result<Self, ProjectionError> {
106        if embeddings.is_empty() {
107            return Err(ProjectionError::EmptyCorpus);
108        }
109        let dim = embeddings[0].dimension();
110        if dim < 3 {
111            return Err(ProjectionError::DimensionTooLow {
112                got: dim,
113                required: 3,
114            });
115        }
116        for (i, e) in embeddings.iter().enumerate() {
117            if e.dimension() != dim {
118                return Err(ProjectionError::InconsistentDimension {
119                    index: i,
120                    expected: dim,
121                    got: e.dimension(),
122                });
123            }
124        }
125
126        let normalized: Vec<Vec<f64>> = embeddings.iter().map(|e| e.normalized()).collect();
127        let n = normalized.len();
128
129        let mut mean = vec![0.0; dim];
130        for v in &normalized {
131            for (i, &val) in v.iter().enumerate() {
132                mean[i] += val;
133            }
134        }
135        for m in &mut mean {
136            *m /= n as f64;
137        }
138
139        let centered: Vec<Vec<f64>> = normalized
140            .iter()
141            .map(|v| {
142                v.iter()
143                    .zip(mean.iter())
144                    .map(|(&val, &m)| val - m)
145                    .collect()
146            })
147            .collect();
148
149        let (components, eigenvalues) = top_k_eigenvectors(&centered, 3, dim);
150
151        // Total variance = sum of all eigenvalues = trace of covariance = sum of squared norms
152        let total_variance: f64 = centered
153            .iter()
154            .map(|row| row.iter().map(|x| x * x).sum::<f64>())
155            .sum::<f64>()
156            / centered.len() as f64;
157
158        Ok(Self {
159            components: [
160                components[0].clone(),
161                components[1].clone(),
162                components[2].clone(),
163            ],
164            mean,
165            dim,
166            radial,
167            volumetric: false,
168            eigenvalues: [
169                eigenvalues.first().copied().unwrap_or(0.0),
170                eigenvalues.get(1).copied().unwrap_or(0.0),
171                eigenvalues.get(2).copied().unwrap_or(0.0),
172            ],
173            total_variance,
174        })
175    }
176
177    pub fn fit_default(embeddings: &[Embedding]) -> Result<Self, ProjectionError> {
178        Self::fit(embeddings, RadialStrategy::default())
179    }
180
181    /// Enable volumetric mode: r comes from the PCA projection magnitude
182    /// instead of the embedding magnitude. Points distribute through the
183    /// full 3D volume rather than clustering on the sphere surface.
184    pub fn with_volumetric(mut self, enabled: bool) -> Self {
185        self.volumetric = enabled;
186        self
187    }
188
189    /// The fraction of total variance captured by the top-3 PCA components.
190    /// A global quality metric for the projection — higher means less information lost.
191    pub fn explained_variance_ratio(&self) -> f64 {
192        if self.total_variance < f64::EPSILON {
193            return 1.0;
194        }
195        let explained: f64 = self.eigenvalues.iter().sum();
196        (explained / self.total_variance).clamp(0.0, 1.0)
197    }
198
199    /// Allocation-free projection kernel: folds
200    /// `normalize(embedding) − mean` into the per-axis dot product
201    /// without materializing the intermediate `Vec<f64>`s that the
202    /// previous implementation allocated per call.
203    ///
204    /// Matches the numerics of `project_centered(&centered)` exactly:
205    /// each axis sums `(v_i/|v| − mean_i) · component_j[i]` over i,
206    /// plus a total-squared accumulator for the residual.
207    ///
208    /// Called by [`Self::project`] and [`Self::project_rich`]; callers
209    /// that want `SphericalPoint` or `ProjectedPoint` should use those.
210    fn project_xyz_residual(&self, embedding: &Embedding) -> (f64, f64, f64, f64) {
211        let values = &embedding.values;
212        debug_assert_eq!(values.len(), self.dim);
213
214        let mag = embedding.magnitude();
215        let inv_mag = if mag < f64::EPSILON { 0.0 } else { 1.0 / mag };
216
217        let mut x = 0.0f64;
218        let mut y = 0.0f64;
219        let mut z = 0.0f64;
220        let mut total_sq = 0.0f64;
221        let c0 = &self.components[0];
222        let c1 = &self.components[1];
223        let c2 = &self.components[2];
224        for i in 0..self.dim {
225            let n = values[i] * inv_mag;
226            let c = n - self.mean[i];
227            x += c * c0[i];
228            y += c * c1[i];
229            z += c * c2[i];
230            total_sq += c * c;
231        }
232        let projected_sq = x * x + y * y + z * z;
233        let residual_sq = (total_sq - projected_sq).max(0.0);
234        (x, y, z, residual_sq)
235    }
236}
237
238impl Projection for PcaProjection {
239    fn project(&self, embedding: &Embedding) -> SphericalPoint {
240        assert_eq!(
241            embedding.dimension(),
242            self.dim,
243            "expected dimension {}, got {}",
244            self.dim,
245            embedding.dimension()
246        );
247
248        let (x, y, z, _) = self.project_xyz_residual(embedding);
249
250        if self.volumetric {
251            let sp = cartesian_to_spherical(&CartesianPoint::new(x, y, z));
252            if sp.r < f64::EPSILON {
253                return SphericalPoint::new_unchecked(0.0, 0.0, 0.0);
254            }
255            SphericalPoint::new_unchecked(sp.r, sp.theta, sp.phi)
256        } else {
257            let r = self.radial.compute(embedding.magnitude());
258            project_xyz_to_spherical(x, y, z, r)
259        }
260    }
261
262    fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
263        assert_eq!(
264            embedding.dimension(),
265            self.dim,
266            "expected dimension {}, got {}",
267            self.dim,
268            embedding.dimension()
269        );
270
271        let intensity = embedding.magnitude();
272        let (x, y, z, residual_sq) = self.project_xyz_residual(embedding);
273        let projection_magnitude = (x * x + y * y + z * z).sqrt();
274
275        // Per-point certainty: fraction of this point's variance captured
276        // by the 3 components. The fold below also drops the separate
277        // centered-vec allocation the old version materialized.
278        let inv_mag = if intensity < f64::EPSILON {
279            0.0
280        } else {
281            1.0 / intensity
282        };
283        let total_sq: f64 = (0..self.dim)
284            .map(|i| {
285                let c = embedding.values[i] * inv_mag - self.mean[i];
286                c * c
287            })
288            .sum();
289        let certainty = if total_sq < f64::EPSILON {
290            0.0
291        } else {
292            (1.0 - residual_sq / total_sq).clamp(0.0, 1.0)
293        };
294
295        let position = if self.volumetric {
296            let sp = cartesian_to_spherical(&CartesianPoint::new(x, y, z));
297            if sp.r < f64::EPSILON {
298                SphericalPoint::new_unchecked(0.0, 0.0, 0.0)
299            } else {
300                SphericalPoint::new_unchecked(sp.r, sp.theta, sp.phi)
301            }
302        } else {
303            let r = self.radial.compute(intensity);
304            project_xyz_to_spherical(x, y, z, r)
305        };
306
307        ProjectedPoint::new(position, certainty, intensity, projection_magnitude)
308    }
309
310    fn dimensionality(&self) -> usize {
311        self.dim
312    }
313}
314
315/// Fit-free projection via random matrix (Johnson-Lindenstrauss).
316///
317/// Generates a fixed 3×n random matrix at construction time. Preserves
318/// pairwise distances probabilistically without needing a training corpus.
319/// Less accurate than PCA for any specific dataset, but useful when
320/// no corpus is available or for quick prototyping.
321///
322/// Deterministic for a given seed.
323#[derive(Clone)]
324pub struct RandomProjection {
325    matrix: [Vec<f64>; 3],
326    dim: usize,
327    radial: RadialStrategy,
328}
329
330impl RandomProjection {
331    pub fn new(dim: usize, radial: RadialStrategy, seed: u64) -> Self {
332        assert!(dim >= 3, "embedding dimension must be >= 3");
333        let mut rng = SplitMix64::new(seed);
334        let matrix = std::array::from_fn(|_| (0..dim).map(|_| rng.normal()).collect());
335        Self {
336            matrix,
337            dim,
338            radial,
339        }
340    }
341
342    pub fn new_default(dim: usize) -> Self {
343        Self::new(dim, RadialStrategy::default(), 42)
344    }
345}
346
347impl Projection for RandomProjection {
348    fn project(&self, embedding: &Embedding) -> SphericalPoint {
349        assert_eq!(
350            embedding.dimension(),
351            self.dim,
352            "expected dimension {}, got {}",
353            self.dim,
354            embedding.dimension()
355        );
356
357        let magnitude = embedding.magnitude();
358        let r = self.radial.compute(magnitude);
359        let normalized = embedding.normalized();
360
361        let x = dot(&normalized, &self.matrix[0]);
362        let y = dot(&normalized, &self.matrix[1]);
363        let z = dot(&normalized, &self.matrix[2]);
364
365        project_xyz_to_spherical(x, y, z, r)
366    }
367
368    fn dimensionality(&self) -> usize {
369        self.dim
370    }
371}
372
373// --- Shared projection math (pub(crate) for reuse by kernel_pca) ---
374
375pub(crate) fn project_xyz_to_spherical(x: f64, y: f64, z: f64, r: f64) -> SphericalPoint {
376    let cart = CartesianPoint::new(x, y, z).normalize();
377    if cart.magnitude() < f64::EPSILON {
378        return SphericalPoint::new_unchecked(r, 0.0, 0.0);
379    }
380    let sp = cartesian_to_spherical(&cart);
381    SphericalPoint::new_unchecked(r, sp.theta, sp.phi)
382}
383
384// --- Linear algebra primitives (pub(crate) for reuse by kernel_pca) ---
385
/// Inner product of two slices; trailing elements of the longer one are ignored.
pub(crate) fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum()
}
389
/// Scale `v` to unit length in place and return its original Euclidean norm.
///
/// Vectors whose norm does not exceed `f64::EPSILON` are left untouched, so
/// callers can detect degeneracy from the returned value.
pub(crate) fn normalize_vec(v: &mut [f64]) -> f64 {
    let norm = v
        .iter()
        .map(|component| component * component)
        .sum::<f64>()
        .sqrt();
    if norm > f64::EPSILON {
        v.iter_mut().for_each(|component| *component /= norm);
    }
    norm
}
399
400/// Power iteration with deflation for the top-k eigenvectors of XᵀX.
401///
402/// Computes XᵀX·v as Xᵀ(Xv) to avoid forming the n×n matrix,
403/// keeping each iteration at O(N·n) instead of O(n²).
404///
405/// Returns (eigenvectors, eigenvalues) both sorted by decreasing eigenvalue.
406fn top_k_eigenvectors(data: &[Vec<f64>], k: usize, dim: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
407    let max_iters = 200;
408    let tol = 1e-10;
409    let mut vectors: Vec<Vec<f64>> = Vec::with_capacity(k);
410    let mut values: Vec<f64> = Vec::with_capacity(k);
411    let mut rng = SplitMix64::new(0xDEAD_BEEF);
412    let n = data.len() as f64;
413
414    for _ in 0..k {
415        let mut v: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
416        normalize_vec(&mut v);
417        let mut eigenvalue = 0.0;
418
419        for _ in 0..max_iters {
420            // w = Xv ∈ ℝᴺ
421            let w: Vec<f64> = data.iter().map(|row| dot(row, &v)).collect();
422
423            // u = Xᵀw ∈ ℝⁿ
424            let mut u = vec![0.0; dim];
425            for (row, &wi) in data.iter().zip(w.iter()) {
426                for (uj, &rj) in u.iter_mut().zip(row.iter()) {
427                    *uj += wi * rj;
428                }
429            }
430
431            // Deflate: remove components along previously found eigenvectors
432            for prev in &vectors {
433                let proj = dot(&u, prev);
434                for (uj, &pj) in u.iter_mut().zip(prev.iter()) {
435                    *uj -= proj * pj;
436                }
437            }
438
439            let mag = normalize_vec(&mut u);
440            if mag < f64::EPSILON {
441                break;
442            }
443
444            // The eigenvalue is vᵀ(XᵀX)v / N = mag / N (before normalization)
445            eigenvalue = mag / n;
446
447            // `.max(0.0)` clamps the FP noise that can briefly push
448            // `1 - |⟨u,v⟩|` slightly negative near convergence.
449            let change = (1.0 - dot(&u, &v).abs()).max(0.0);
450            v = u;
451
452            if change < tol {
453                break;
454            }
455        }
456
457        vectors.push(v);
458        values.push(eigenvalue);
459    }
460
461    // If some components had zero variance, fill with orthogonal random directions
462    while vectors.len() < k {
463        let mut v: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
464        for prev in &vectors {
465            let proj = dot(&v, prev);
466            for (vj, &pj) in v.iter_mut().zip(prev.iter()) {
467                *vj -= proj * pj;
468            }
469        }
470        normalize_vec(&mut v);
471        vectors.push(v);
472        values.push(0.0);
473    }
474
475    (vectors, values)
476}
477
478// --- Deterministic PRNG (SplitMix64 + Box-Muller) ---
479// pub(crate) for reuse by kernel_pca module.
480
/// Small deterministic 64-bit generator (SplitMix64) with a Box-Muller normal
/// sampler. The same seed always reproduces the same stream.
pub(crate) struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Golden-ratio increment added to the state on every step.
    const GAMMA: u64 = 0x9e3779b97f4a7c15;

    /// Create a generator whose output is fully determined by `seed`.
    pub(crate) fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advance the state by `GAMMA` and return its mixed 64-bit output.
    pub(crate) fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GAMMA);
        let mixed = self.state;
        let mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
        let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94d049bb133111eb);
        mixed ^ (mixed >> 31)
    }

    /// Uniform sample in `[0, 1)` built from the top 53 bits of `next_u64`.
    pub(crate) fn next_f64(&mut self) -> f64 {
        let mantissa_bits = self.next_u64() >> 11;
        mantissa_bits as f64 / (1u64 << 53) as f64
    }

    /// Standard-normal sample via the cosine branch of Box-Muller.
    /// Consumes exactly two uniforms: the radius draw, then the angle draw.
    pub(crate) fn normal(&mut self) -> f64 {
        // Clamp away from zero so `ln` stays finite.
        let radius_draw = self.next_f64().max(f64::MIN_POSITIVE);
        let angle_draw = self.next_f64();
        let radius = (-2.0 * radius_draw.ln()).sqrt();
        let angle = 2.0 * PI * angle_draw;
        radius * angle.cos()
    }
}
508
#[cfg(test)]
mod tests {
    use super::*;
    use sphereql_core::angular_distance;
    use std::f64::consts::TAU;

    fn emb(vals: &[f64]) -> Embedding {
        Embedding::new(vals.to_vec())
    }

    fn corpus_10d() -> Vec<Embedding> {
        vec![
            emb(&[1.0, 0.0, 0.0, 0.1, 0.05, -0.02, 0.03, -0.01, 0.04, 0.02]),
            emb(&[0.0, 1.0, 0.0, -0.05, 0.1, 0.03, -0.02, 0.01, -0.03, 0.04]),
            emb(&[0.0, 0.0, 1.0, 0.02, -0.03, 0.1, 0.05, 0.02, -0.01, -0.04]),
            emb(&[1.0, 1.0, 0.0, 0.05, 0.08, 0.01, 0.01, -0.02, 0.02, 0.03]),
            emb(&[0.0, 1.0, 1.0, -0.02, 0.07, 0.07, 0.01, 0.02, -0.02, 0.01]),
            emb(&[1.0, 0.0, 1.0, 0.06, 0.01, 0.05, -0.03, -0.01, 0.03, -0.02]),
            emb(&[-1.0, 0.0, 0.0, -0.08, 0.02, 0.01, 0.02, 0.03, -0.02, 0.01]),
            emb(&[0.0, -1.0, 0.0, 0.03, -0.09, -0.02, 0.01, -0.01, 0.02, -0.03]),
        ]
    }

    fn assert_valid_spherical(sp: &SphericalPoint) {
        assert!(sp.r >= 0.0, "r must be >= 0, got {}", sp.r);
        assert!(
            sp.theta >= 0.0 && sp.theta < TAU,
            "theta must be in [0, 2π), got {}",
            sp.theta
        );
        assert!(
            sp.phi >= 0.0 && sp.phi <= PI,
            "phi must be in [0, π], got {}",
            sp.phi
        );
    }

    // --- PCA tests ---

    #[test]
    fn pca_produces_valid_spherical_points() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        for e in &corpus {
            assert_valid_spherical(&pca.project(e));
        }
    }

    #[test]
    fn pca_preserves_angular_ordering() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();

        // a and b are both +x-ish, c is -x: a should be closer to b than to c
        let a = emb(&[1.0, 0.1, 0.0, 0.05, 0.02, -0.01, 0.01, 0.0, 0.02, 0.01]);
        let b = emb(&[0.9, 0.2, 0.1, 0.04, 0.03, 0.0, 0.02, -0.01, 0.01, 0.02]);
        let c = emb(&[-1.0, -0.1, 0.0, -0.04, 0.01, 0.02, 0.01, 0.02, -0.01, 0.01]);

        let pa = pca.project(&a);
        let pb = pca.project(&b);
        let pc = pca.project(&c);

        let d_ab = angular_distance(&pa, &pb);
        let d_ac = angular_distance(&pa, &pc);

        assert!(
            d_ab < d_ac,
            "similar items should be closer: d(a,b)={d_ab:.4} should be < d(a,c)={d_ac:.4}"
        );
    }

    #[test]
    fn pca_magnitude_radial() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Magnitude).unwrap();

        let short = emb(&[0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let long = emb(&[10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let ps = pca.project(&short);
        let pl = pca.project(&long);

        assert!(ps.r < pl.r, "longer vector should have larger radius");
        assert!((ps.r - 0.1).abs() < 1e-10);
        assert!((pl.r - 10.0).abs() < 1e-10);
    }

    #[test]
    fn pca_transform_radial() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(
            &corpus,
            RadialStrategy::MagnitudeTransform(Arc::new(|mag| mag.ln_1p())),
        )
        .unwrap();

        let e = emb(&[3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let sp = pca.project(&e);
        assert!((sp.r - 5.0_f64.ln_1p()).abs() < 1e-10);
    }

    #[test]
    fn pca_single_embedding() {
        let corpus = vec![emb(&[1.0, 0.0, 0.0, 0.0, 0.0])];
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let sp = pca.project(&corpus[0]);
        assert!((sp.r - 1.0).abs() < 1e-12);
        assert_valid_spherical(&sp);
    }

    #[test]
    fn pca_dimensionality() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        assert_eq!(pca.dimensionality(), 10);
    }

    #[test]
    fn pca_empty_corpus_returns_err() {
        assert!(matches!(
            PcaProjection::fit(&[], RadialStrategy::Fixed(1.0)),
            Err(ProjectionError::EmptyCorpus)
        ));
    }

    #[test]
    fn pca_too_few_dimensions_returns_err() {
        assert!(matches!(
            PcaProjection::fit(&[emb(&[1.0, 2.0])], RadialStrategy::Fixed(1.0)),
            Err(ProjectionError::DimensionTooLow {
                got: 2,
                required: 3
            })
        ));
    }

    #[test]
    fn pca_inconsistent_dimension_returns_err() {
        // The third row is short; the error must name it and both dimensions.
        let corpus = vec![
            emb(&[1.0, 0.0, 0.0]),
            emb(&[0.0, 1.0, 0.0]),
            emb(&[1.0, 2.0]),
        ];
        assert!(matches!(
            PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)),
            Err(ProjectionError::InconsistentDimension {
                index: 2,
                expected: 3,
                got: 2
            })
        ));
    }

    #[test]
    #[should_panic(expected = "expected dimension 10")]
    fn pca_dimension_mismatch_panics() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let _ = pca.project(&emb(&[1.0, 2.0, 3.0]));
    }

    #[test]
    fn pca_explained_variance_ratio_is_a_fraction() {
        let pca = PcaProjection::fit(&corpus_10d(), RadialStrategy::Fixed(1.0)).unwrap();
        let ratio = pca.explained_variance_ratio();
        assert!(
            (0.0..=1.0).contains(&ratio),
            "ratio must be in [0, 1], got {ratio}"
        );
        // corpus_10d is dominated by its first three coordinates.
        assert!(
            ratio > 0.5,
            "three axes should explain most of the variance, got {ratio}"
        );

        // A single point has zero variance, which is reported as fully explained.
        let single =
            PcaProjection::fit(&[emb(&[1.0, 0.0, 0.0])], RadialStrategy::Fixed(1.0)).unwrap();
        assert_eq!(single.explained_variance_ratio(), 1.0);
    }

    #[test]
    fn pca_volumetric_radius_is_bounded_and_varies() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0))
            .unwrap()
            .with_volumetric(true);
        for e in &corpus {
            let sp = pca.project(e);
            assert_valid_spherical(&sp);
            // Centered unit vectors have norm <= 2, and projecting onto
            // orthonormal axes cannot lengthen them.
            assert!(sp.r <= 2.0 + 1e-9, "volumetric r out of range: {}", sp.r);
        }
        // Volumetric mode must not just reuse the fixed radius.
        assert!(
            corpus
                .iter()
                .any(|e| (pca.project(e).r - 1.0).abs() > 1e-6),
            "volumetric radii should depend on the PCA coordinates"
        );
    }

    #[test]
    fn pca_direction_ignores_embedding_scale() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let base = [0.7, 0.2, -0.1, 0.05, 0.0, 0.02, -0.03, 0.01, 0.0, 0.04];
        let scaled: Vec<f64> = base.iter().map(|v| v * 3.0).collect();
        let d = angular_distance(&pca.project(&emb(&base)), &pca.project(&emb(&scaled)));
        assert!(d < 1e-6, "scaling an embedding should not move it, got {d}");
    }

    // --- Random projection tests ---

    #[test]
    fn random_produces_valid_spherical_points() {
        let rp = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        for i in 0..20 {
            let e = emb(&[i as f64 * 0.1 + 0.01; 10]);
            assert_valid_spherical(&rp.project(&e));
        }
    }

    #[test]
    fn random_deterministic_with_same_seed() {
        let rp1 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        let rp2 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        let e = emb(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let sp1 = rp1.project(&e);
        let sp2 = rp2.project(&e);
        assert!((sp1.theta - sp2.theta).abs() < 1e-12);
        assert!((sp1.phi - sp2.phi).abs() < 1e-12);
    }

    #[test]
    fn random_different_seeds_differ() {
        let rp1 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        let rp2 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 999);
        let e = emb(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let d = angular_distance(&rp1.project(&e), &rp2.project(&e));
        assert!(
            d > 1e-6,
            "different seeds should produce different projections"
        );
    }

    #[test]
    fn random_direction_ignores_embedding_scale() {
        let rp = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 7);
        let base = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let halved: Vec<f64> = base.iter().map(|v| v * 0.5).collect();
        let d = angular_distance(&rp.project(&emb(&base)), &rp.project(&emb(&halved)));
        assert!(d < 1e-6, "scaling an embedding should not move it, got {d}");
    }

    #[test]
    fn random_dimensionality() {
        let rp = RandomProjection::new(768, RadialStrategy::Fixed(1.0), 0);
        assert_eq!(rp.dimensionality(), 768);
    }

    #[test]
    #[should_panic(expected = "embedding dimension must be >= 3")]
    fn random_too_few_dimensions_panics() {
        RandomProjection::new(2, RadialStrategy::Fixed(1.0), 0);
    }

    // --- Arc delegation ---

    #[test]
    fn arc_projection_delegates() {
        let rp = Arc::new(RandomProjection::new_default(10));
        let e = emb(&[1.0; 10]);
        let sp = rp.project(&e);
        assert!(sp.r > 0.0);
        assert_eq!(rp.dimensionality(), 10);
    }

    // --- SplitMix64 sanity ---

    #[test]
    fn prng_produces_distinct_values() {
        let mut rng = SplitMix64::new(42);
        let vals: Vec<f64> = (0..100).map(|_| rng.next_f64()).collect();
        for i in 0..vals.len() {
            for j in (i + 1)..vals.len() {
                assert_ne!(vals[i].to_bits(), vals[j].to_bits());
            }
        }
    }

    #[test]
    fn prng_normal_distribution_reasonable() {
        let mut rng = SplitMix64::new(12345);
        let samples: Vec<f64> = (0..10_000).map(|_| rng.normal()).collect();

        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        let variance =
            samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;

        assert!(mean.abs() < 0.05, "mean should be near 0, got {mean}");
        assert!(
            (variance - 1.0).abs() < 0.1,
            "variance should be near 1, got {variance}"
        );
    }
}