Skip to main content

sphereql_embed/
projection.rs

1use std::f64::consts::PI;
2use std::sync::Arc;
3
4use sphereql_core::{CartesianPoint, SphericalPoint, cartesian_to_spherical};
5
6use crate::types::{Embedding, ProjectedPoint, RadialStrategy};
7
/// Reasons a projection fit can fail.
///
/// Every concrete projection's `fit` used to panic via `assert!` on
/// invalid input, which turned typos in Python / WASM bindings into
/// `PanicException`s. These variants classify the same preconditions
/// so callers can surface typed errors instead.
///
/// The `Display` text of each variant is generated by `thiserror` from
/// the `#[error(...)]` attribute on it, so editing those strings changes
/// what callers see at runtime.
///
/// Only `PartialEq` (not `Eq`) is derived because
/// [`ProjectionError::InvalidSigma`] carries an `f64`.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum ProjectionError {
    /// The input slice was empty. Fitting needs at least one embedding.
    #[error("need at least one embedding to fit a projection")]
    EmptyCorpus,

    /// Embedding dimensionality below the projection's requirement.
    /// PCA and kernel PCA need `dim >= 3`; Laplacian requires `dim > 0`.
    ///
    /// `got` is the dimensionality that was supplied, `required` the
    /// minimum the projection accepts.
    #[error("embedding dimension {got} is below the minimum {required} for this projection")]
    DimensionTooLow { got: usize, required: usize },

    /// Embeddings disagreed on dimensionality. Every row must match the
    /// first one; the mismatch is reported with the offending index.
    ///
    /// `index` is the position of the first offending row in the input
    /// slice, `expected` the dimensionality of row 0, and `got` the
    /// dimensionality of the offending row.
    #[error("embedding {index} has dimension {got}, expected {expected}")]
    InconsistentDimension {
        index: usize,
        expected: usize,
        got: usize,
    },

    /// Projection needs more embeddings than were provided. Laplacian
    /// eigenmap's graph construction requires `n >= 4`.
    ///
    /// `got` is the number of embeddings supplied, `required` the
    /// minimum the projection accepts.
    #[error("need at least {required} embeddings, got {got}")]
    TooFewEmbeddings { got: usize, required: usize },

    /// `fit_with_sigma` was given a non-positive Gaussian bandwidth.
    /// `got` echoes the rejected value.
    #[error("kernel bandwidth σ must be positive, got {got}")]
    InvalidSigma { got: f64 },

    /// A parallel slice (e.g. category labels or per-sample weights) did
    /// not have the same length as the embedding slice.
    ///
    /// `expected` is the embedding count, `got` the auxiliary slice's
    /// length.
    #[error("auxiliary slice has length {got}, expected {expected}")]
    SliceLengthMismatch { expected: usize, got: usize },
}
48
49/// Maps high-dimensional embeddings to spherical coordinates.
50///
51/// The angular coordinates (theta, phi) encode semantic direction via
52/// dimensionality reduction from S^{n-1} to S^2. The radial coordinate
53/// is controlled by the projection's [`RadialStrategy`].
54pub trait Projection: Send + Sync {
55    fn project(&self, embedding: &Embedding) -> SphericalPoint;
56
57    /// Project with rich metadata: certainty, intensity, projection magnitude.
58    fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
59        let position = self.project(embedding);
60        ProjectedPoint::from_position(position, embedding.magnitude())
61    }
62
63    fn dimensionality(&self) -> usize;
64}
65
66impl<P: Projection> Projection for Arc<P> {
67    fn project(&self, embedding: &Embedding) -> SphericalPoint {
68        (**self).project(embedding)
69    }
70    fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
71        (**self).project_rich(embedding)
72    }
73    fn dimensionality(&self) -> usize {
74        (**self).dimensionality()
75    }
76}
77
/// Corpus-fitted projection via spherical PCA.
///
/// Finds the 3 principal directions of maximum angular variance in the
/// embedding space, then projects new embeddings onto them. This preserves
/// angular (cosine similarity) relationships as faithfully as possible
/// in 3 dimensions.
///
/// Fitting: O(N·n·k·iters) where N=corpus size, n=dimension, k=3.
/// Projection: O(n) per embedding.
///
/// Instances are built by [`PcaProjection::fit`],
/// [`PcaProjection::fit_default`] or [`PcaProjection::fit_weighted`];
/// all fields are private.
#[derive(Clone)]
pub struct PcaProjection {
    /// The three principal axes, each a vector of length `dim`. An
    /// embedding's x/y/z coordinates are its centered, unit-normalized
    /// values dotted with axes 0, 1 and 2 respectively.
    components: [Vec<f64>; 3],
    /// Mean of the unit-normalized corpus vectors (weighted mean for a
    /// weighted fit). Subtracted from every input before projecting.
    mean: Vec<f64>,
    /// Embedding dimensionality the projection was fitted on; `project`
    /// asserts that inputs match it.
    dim: usize,
    /// Strategy used to compute the radial coordinate when `volumetric`
    /// is off.
    radial: RadialStrategy,
    /// When true, the radius is the magnitude of the projected (x, y, z)
    /// vector instead of the output of `radial`.
    volumetric: bool,
    /// Top-3 eigenvalues from PCA (descending). Summed against
    /// `total_variance` in [`Self::explained_variance_ratio`].
    eigenvalues: [f64; 3],
    /// Total variance across all dimensions. eigenvalues[0..3].sum() / total_variance
    /// gives the global explained variance ratio.
    total_variance: f64,
}
101
/// Minimum embedding dimensionality required by PCA fits.
///
/// The projection extracts three principal axes, so inputs with fewer
/// than three dimensions are rejected with
/// [`ProjectionError::DimensionTooLow`].
const PCA_MIN_DIM: usize = 3;
104
105impl PcaProjection {
106    /// Validate that the corpus is non-empty, every row shares the same
107    /// dimensionality, and that dimensionality is at least
108    /// [`PCA_MIN_DIM`]. Returns the shared dimension on success.
109    fn validate_embeddings(embeddings: &[Embedding]) -> Result<usize, ProjectionError> {
110        if embeddings.is_empty() {
111            return Err(ProjectionError::EmptyCorpus);
112        }
113        let dim = embeddings[0].dimension();
114        if dim < PCA_MIN_DIM {
115            return Err(ProjectionError::DimensionTooLow {
116                got: dim,
117                required: PCA_MIN_DIM,
118            });
119        }
120        for (i, e) in embeddings.iter().enumerate() {
121            if e.dimension() != dim {
122                return Err(ProjectionError::InconsistentDimension {
123                    index: i,
124                    expected: dim,
125                    got: e.dimension(),
126                });
127            }
128        }
129        Ok(dim)
130    }
131
132    /// Assemble a `PcaProjection` from the eigendecomposition outputs.
133    /// Padding shorter eigenvalue/component lists with zeros keeps the
134    /// fixed-arity arrays well-defined when [`top_k_eigenvectors`]
135    /// returns fewer than 3 components.
136    fn from_eigendecomp(
137        components: Vec<Vec<f64>>,
138        eigenvalues: Vec<f64>,
139        mean: Vec<f64>,
140        dim: usize,
141        radial: RadialStrategy,
142        total_variance: f64,
143    ) -> Self {
144        Self {
145            components: [
146                components[0].clone(),
147                components[1].clone(),
148                components[2].clone(),
149            ],
150            mean,
151            dim,
152            radial,
153            volumetric: false,
154            eigenvalues: [
155                eigenvalues.first().copied().unwrap_or(0.0),
156                eigenvalues.get(1).copied().unwrap_or(0.0),
157                eigenvalues.get(2).copied().unwrap_or(0.0),
158            ],
159            total_variance,
160        }
161    }
162
163    /// Fit the top-3 principal components on `embeddings`.
164    ///
165    /// Returns [`ProjectionError::EmptyCorpus`] if the slice is empty,
166    /// [`ProjectionError::DimensionTooLow`] if `dim < 3`, and
167    /// [`ProjectionError::InconsistentDimension`] if any row's
168    /// dimensionality disagrees with the first. Previously these paths
169    /// panicked via `assert!`, which surfaced as a `PanicException` in
170    /// Python / WASM bindings.
171    pub fn fit(embeddings: &[Embedding], radial: RadialStrategy) -> Result<Self, ProjectionError> {
172        let dim = Self::validate_embeddings(embeddings)?;
173
174        let normalized: Vec<Vec<f64>> = embeddings.iter().map(|e| e.normalized()).collect();
175        let n = normalized.len();
176
177        let mut mean = vec![0.0; dim];
178        for v in &normalized {
179            for (i, &val) in v.iter().enumerate() {
180                mean[i] += val;
181            }
182        }
183        for m in &mut mean {
184            *m /= n as f64;
185        }
186
187        let centered: Vec<Vec<f64>> = normalized
188            .iter()
189            .map(|v| {
190                v.iter()
191                    .zip(mean.iter())
192                    .map(|(&val, &m)| val - m)
193                    .collect()
194            })
195            .collect();
196
197        let (components, eigenvalues) = top_k_eigenvectors(&centered, 3, dim);
198
199        // Total variance = sum of all eigenvalues = trace of covariance = sum of squared norms
200        let total_variance: f64 = centered
201            .iter()
202            .map(|row| row.iter().map(|x| x * x).sum::<f64>())
203            .sum::<f64>()
204            / centered.len() as f64;
205
206        Ok(Self::from_eigendecomp(
207            components,
208            eigenvalues,
209            mean,
210            dim,
211            radial,
212            total_variance,
213        ))
214    }
215
216    pub fn fit_default(embeddings: &[Embedding]) -> Result<Self, ProjectionError> {
217        Self::fit(embeddings, RadialStrategy::default())
218    }
219
220    /// Fit the top-3 principal components with per-sample weights.
221    ///
222    /// Weighted PCA finds the top eigenvectors of the weighted
223    /// covariance matrix `Σ wᵢ (xᵢ − μ_w)(xᵢ − μ_w)ᵀ / Σ wᵢ`, where
224    /// `μ_w = Σ wᵢ xᵢ / Σ wᵢ`. With uniform weights this collapses to
225    /// the same answer as [`Self::fit`].
226    ///
227    /// The intended use is rebalancing covariance estimates over
228    /// imbalanced corpora. Setting `wᵢ = 1 / sqrt(|category(i)|)` gives
229    /// a category of size `m` total covariance mass `m · (1/√m) = √m`,
230    /// compressing category influence from linear to square-root in its
231    /// size. For *exactly* equal per-category mass use
232    /// `wᵢ = 1 / |category(i)|`; the square-root compromise keeps large
233    /// categories' internal variance structure from being washed out
234    /// entirely while still letting small categories register.
235    ///
236    /// Returns the same error variants as [`Self::fit`], plus
237    /// [`ProjectionError::SliceLengthMismatch`] when `weights.len() !=
238    /// embeddings.len()`. Negative weights are treated as zero.
239    pub fn fit_weighted(
240        embeddings: &[Embedding],
241        weights: &[f64],
242        radial: RadialStrategy,
243    ) -> Result<Self, ProjectionError> {
244        // SliceLengthMismatch is the only error specific to the weighted
245        // path; the rest is shared with `fit`.
246        if weights.len() != embeddings.len() {
247            return Err(ProjectionError::SliceLengthMismatch {
248                expected: embeddings.len(),
249                got: weights.len(),
250            });
251        }
252        let dim = Self::validate_embeddings(embeddings)?;
253
254        let clamped: Vec<f64> = weights.iter().map(|&w| w.max(0.0)).collect();
255        let w_sum: f64 = clamped.iter().sum();
256        if w_sum < f64::EPSILON {
257            // All weights zero or negative — fall back to unweighted fit
258            // rather than producing a degenerate covariance.
259            return Self::fit(embeddings, radial);
260        }
261
262        let normalized: Vec<Vec<f64>> = embeddings.iter().map(|e| e.normalized()).collect();
263
264        // Weighted mean: μ_w = Σ wᵢ xᵢ / Σ wᵢ.
265        let mut mean = vec![0.0; dim];
266        for (v, &w) in normalized.iter().zip(clamped.iter()) {
267            for (i, &val) in v.iter().enumerate() {
268                mean[i] += w * val;
269            }
270        }
271        for m in &mut mean {
272            *m /= w_sum;
273        }
274
275        // Each row scaled by sqrt(wᵢ) so that XᵀX equals the weighted
276        // covariance (times Σwᵢ). Eigenvectors are invariant under the
277        // overall scalar, so power iteration on these scaled rows yields
278        // the weighted principal components.
279        let scaled: Vec<Vec<f64>> = normalized
280            .iter()
281            .zip(clamped.iter())
282            .map(|(v, &w)| {
283                let s = w.sqrt();
284                v.iter()
285                    .zip(mean.iter())
286                    .map(|(&val, &m)| s * (val - m))
287                    .collect()
288            })
289            .collect();
290
291        let (components, eigenvalues) = top_k_eigenvectors(&scaled, 3, dim);
292
293        // total_variance uses the same /N normalization as the
294        // eigenvalues coming back from top_k_eigenvectors, so the EVR
295        // ratio is well-defined regardless of weight scale.
296        let total_variance: f64 = scaled
297            .iter()
298            .map(|row| row.iter().map(|x| x * x).sum::<f64>())
299            .sum::<f64>()
300            / scaled.len() as f64;
301
302        Ok(Self::from_eigendecomp(
303            components,
304            eigenvalues,
305            mean,
306            dim,
307            radial,
308            total_variance,
309        ))
310    }
311
312    /// Enable volumetric mode: r comes from the PCA projection magnitude
313    /// instead of the embedding magnitude. Points distribute through the
314    /// full 3D volume rather than clustering on the sphere surface.
315    pub fn with_volumetric(mut self, enabled: bool) -> Self {
316        self.volumetric = enabled;
317        self
318    }
319
320    /// The fraction of total variance captured by the top-3 PCA components.
321    /// A global quality metric for the projection — higher means less information lost.
322    pub fn explained_variance_ratio(&self) -> f64 {
323        if self.total_variance < f64::EPSILON {
324            return 1.0;
325        }
326        let explained: f64 = self.eigenvalues.iter().sum();
327        (explained / self.total_variance).clamp(0.0, 1.0)
328    }
329
330    /// Allocation-free projection kernel: folds
331    /// `normalize(embedding) − mean` into the per-axis dot product
332    /// without materializing the intermediate `Vec<f64>`s that the
333    /// previous implementation allocated per call.
334    ///
335    /// Matches the numerics of `project_centered(&centered)` exactly:
336    /// each axis sums `(v_i/|v| − mean_i) · component_j[i]` over i,
337    /// plus a total-squared accumulator for the residual.
338    ///
339    /// Called by [`Self::project`] and [`Self::project_rich`]; callers
340    /// that want `SphericalPoint` or `ProjectedPoint` should use those.
341    fn project_xyz_residual(&self, embedding: &Embedding) -> (f64, f64, f64, f64) {
342        let values = &embedding.values;
343        debug_assert_eq!(values.len(), self.dim);
344
345        let mag = embedding.magnitude();
346        let inv_mag = if mag < f64::EPSILON { 0.0 } else { 1.0 / mag };
347
348        let mut x = 0.0f64;
349        let mut y = 0.0f64;
350        let mut z = 0.0f64;
351        let mut total_sq = 0.0f64;
352        let c0 = &self.components[0];
353        let c1 = &self.components[1];
354        let c2 = &self.components[2];
355        for i in 0..self.dim {
356            let n = values[i] * inv_mag;
357            let c = n - self.mean[i];
358            x += c * c0[i];
359            y += c * c1[i];
360            z += c * c2[i];
361            total_sq += c * c;
362        }
363        let projected_sq = x * x + y * y + z * z;
364        let residual_sq = (total_sq - projected_sq).max(0.0);
365        (x, y, z, residual_sq)
366    }
367}
368
369impl Projection for PcaProjection {
370    fn project(&self, embedding: &Embedding) -> SphericalPoint {
371        // Caller contract: embedding must have the same dimensionality as the
372        // fitted projection. Violated only by mixing projections with corpora —
373        // a programming error, not a runtime condition.
374        assert_eq!(
375            embedding.dimension(),
376            self.dim,
377            "expected dimension {}, got {}",
378            self.dim,
379            embedding.dimension()
380        );
381
382        let (x, y, z, residual_sq) = self.project_xyz_residual(embedding);
383
384        if self.volumetric {
385            let sp = cartesian_to_spherical(&CartesianPoint::new(x, y, z));
386            if sp.r < f64::EPSILON {
387                return SphericalPoint::new_unchecked(0.0, 0.0, 0.0);
388            }
389            SphericalPoint::new_unchecked(sp.r, sp.theta, sp.phi)
390        } else {
391            let projection_magnitude = (x * x + y * y + z * z).sqrt();
392            let intensity = embedding.magnitude();
393            let certainty = pca_certainty(embedding, &self.mean, intensity, residual_sq);
394            let r = self.radial.compute_rich(&crate::types::RadialContext::full(
395                intensity,
396                projection_magnitude,
397                certainty,
398            ));
399            project_xyz_to_spherical(x, y, z, r)
400        }
401    }
402
403    fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
404        // Same caller contract as `project`: dimension must match the fitted projection.
405        assert_eq!(
406            embedding.dimension(),
407            self.dim,
408            "expected dimension {}, got {}",
409            self.dim,
410            embedding.dimension()
411        );
412
413        let intensity = embedding.magnitude();
414        let (x, y, z, residual_sq) = self.project_xyz_residual(embedding);
415        let projection_magnitude = (x * x + y * y + z * z).sqrt();
416        let certainty = pca_certainty(embedding, &self.mean, intensity, residual_sq);
417
418        let position = if self.volumetric {
419            let sp = cartesian_to_spherical(&CartesianPoint::new(x, y, z));
420            if sp.r < f64::EPSILON {
421                SphericalPoint::new_unchecked(0.0, 0.0, 0.0)
422            } else {
423                SphericalPoint::new_unchecked(sp.r, sp.theta, sp.phi)
424            }
425        } else {
426            let r = self.radial.compute_rich(&crate::types::RadialContext::full(
427                intensity,
428                projection_magnitude,
429                certainty,
430            ));
431            project_xyz_to_spherical(x, y, z, r)
432        };
433
434        ProjectedPoint::new(position, certainty, intensity, projection_magnitude)
435    }
436
437    fn dimensionality(&self) -> usize {
438        self.dim
439    }
440}
441
/// Fit-free projection via random matrix (Johnson-Lindenstrauss).
///
/// Generates a fixed 3×n random matrix at construction time. Preserves
/// pairwise distances probabilistically without needing a training corpus.
/// Less accurate than PCA for any specific dataset, but useful when
/// no corpus is available or for quick prototyping.
///
/// Deterministic for a given seed.
#[derive(Clone)]
pub struct RandomProjection {
    /// The three rows of the random matrix, each of length `dim`. The
    /// x/y/z coordinates of a projected embedding are its unit-normalized
    /// values dotted with rows 0, 1 and 2 respectively.
    matrix: [Vec<f64>; 3],
    /// Embedding dimensionality the matrix was generated for; `project`
    /// asserts that inputs match it.
    dim: usize,
    /// Strategy used to compute the radial coordinate.
    radial: RadialStrategy,
}
456
457impl RandomProjection {
458    pub fn new(dim: usize, radial: RadialStrategy, seed: u64) -> Self {
459        // Caller contract: random projection needs ≥3 dimensions to produce
460        // a non-degenerate 3×n matrix (all three rows would be identical
461        // zero-padded otherwise). This parallels PcaProjection::fit's check.
462        assert!(dim >= 3, "embedding dimension must be >= 3");
463        let mut rng = SplitMix64::new(seed);
464        let matrix = std::array::from_fn(|_| (0..dim).map(|_| rng.normal()).collect());
465        Self {
466            matrix,
467            dim,
468            radial,
469        }
470    }
471
472    pub fn new_default(dim: usize) -> Self {
473        Self::new(dim, RadialStrategy::default(), 42)
474    }
475}
476
477impl Projection for RandomProjection {
478    fn project(&self, embedding: &Embedding) -> SphericalPoint {
479        assert_eq!(
480            embedding.dimension(),
481            self.dim,
482            "expected dimension {}, got {}",
483            self.dim,
484            embedding.dimension()
485        );
486
487        let magnitude = embedding.magnitude();
488        let normalized = embedding.normalized();
489
490        let x = dot(&normalized, &self.matrix[0]);
491        let y = dot(&normalized, &self.matrix[1]);
492        let z = dot(&normalized, &self.matrix[2]);
493
494        // Random projection has no fidelity signal; report certainty = 1.0
495        // so `Certainty { scale }` reduces to a constant for callers who
496        // pick this strategy here (a meaningful warning, not a silent zero).
497        let projection_magnitude = (x * x + y * y + z * z).sqrt();
498        let r = self.radial.compute_rich(&crate::types::RadialContext::full(
499            magnitude,
500            projection_magnitude,
501            1.0,
502        ));
503
504        project_xyz_to_spherical(x, y, z, r)
505    }
506
507    fn dimensionality(&self) -> usize {
508        self.dim
509    }
510}
511
512// --- Shared projection math (pub(crate) for reuse by kernel_pca) ---
513
514/// Per-point variance-captured ratio for PCA: `1 − residual_sq / total_sq`.
515/// Returns 0.0 for inputs whose centered norm is below `f64::EPSILON`
516/// (otherwise we'd divide by zero) and clamps to `[0, 1]`.
517fn pca_certainty(embedding: &Embedding, mean: &[f64], intensity: f64, residual_sq: f64) -> f64 {
518    // Mirror Embedding::normalized()'s fallback: when the input has no
519    // magnitude, the rest of the pipeline treats it as [1, 0, 0, ...].
520    // Computing total_sq off the same vector keeps certainty consistent
521    // with the projection coordinates the caller will actually see.
522    let zero_intensity = intensity < f64::EPSILON;
523    let inv_mag = if zero_intensity { 0.0 } else { 1.0 / intensity };
524    let total_sq: f64 = (0..mean.len())
525        .map(|i| {
526            let normalized_i = if zero_intensity {
527                if i == 0 { 1.0 } else { 0.0 }
528            } else {
529                embedding.values[i] * inv_mag
530            };
531            let c = normalized_i - mean[i];
532            c * c
533        })
534        .sum();
535    if total_sq < f64::EPSILON {
536        0.0
537    } else {
538        (1.0 - residual_sq / total_sq).clamp(0.0, 1.0)
539    }
540}
541
542pub(crate) fn project_xyz_to_spherical(x: f64, y: f64, z: f64, r: f64) -> SphericalPoint {
543    let cart = CartesianPoint::new(x, y, z).normalize();
544    if cart.magnitude() < f64::EPSILON {
545        return SphericalPoint::new_unchecked(r, 0.0, 0.0);
546    }
547    let sp = cartesian_to_spherical(&cart);
548    SphericalPoint::new_unchecked(r, sp.theta, sp.phi)
549}
550
551// --- Linear algebra primitives (pub(crate) for reuse by kernel_pca) ---
552
/// Dot product of `a` and `b`.
///
/// If the slices differ in length, only the common prefix contributes
/// (the pairing stops at the shorter slice).
pub(crate) fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum()
}
556
/// Scale `v` to unit Euclidean length in place and return the length it
/// had before scaling.
///
/// Vectors whose length does not exceed `f64::EPSILON` are left
/// untouched (there is no meaningful direction to preserve); their
/// original length is still returned.
pub(crate) fn normalize_vec(v: &mut [f64]) -> f64 {
    let norm_sq: f64 = v.iter().map(|x| x * x).sum();
    let norm = norm_sq.sqrt();
    if norm > f64::EPSILON {
        v.iter_mut().for_each(|x| *x /= norm);
    }
    norm
}
566
567/// Power iteration with deflation for the top-k eigenvectors of XᵀX.
568///
569/// Computes XᵀX·v as Xᵀ(Xv) to avoid forming the n×n matrix,
570/// keeping each iteration at O(N·n) instead of O(n²).
571///
572/// Returns (eigenvectors, eigenvalues) both sorted by decreasing eigenvalue.
573fn top_k_eigenvectors(data: &[Vec<f64>], k: usize, dim: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
574    let max_iters = 200;
575    let tol = 1e-10;
576    let mut vectors: Vec<Vec<f64>> = Vec::with_capacity(k);
577    let mut values: Vec<f64> = Vec::with_capacity(k);
578    let mut rng = SplitMix64::new(0xDEAD_BEEF);
579    let n = data.len() as f64;
580
581    for _ in 0..k {
582        let mut v: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
583        normalize_vec(&mut v);
584        let mut eigenvalue = 0.0;
585
586        for _ in 0..max_iters {
587            // w = Xv ∈ ℝᴺ
588            let w: Vec<f64> = data.iter().map(|row| dot(row, &v)).collect();
589
590            // u = Xᵀw ∈ ℝⁿ
591            let mut u = vec![0.0; dim];
592            for (row, &wi) in data.iter().zip(w.iter()) {
593                for (uj, &rj) in u.iter_mut().zip(row.iter()) {
594                    *uj += wi * rj;
595                }
596            }
597
598            // Deflate: remove components along previously found eigenvectors
599            for prev in &vectors {
600                let proj = dot(&u, prev);
601                for (uj, &pj) in u.iter_mut().zip(prev.iter()) {
602                    *uj -= proj * pj;
603                }
604            }
605
606            let mag = normalize_vec(&mut u);
607            if mag < f64::EPSILON {
608                break;
609            }
610
611            // The eigenvalue is vᵀ(XᵀX)v / N = mag / N (before normalization)
612            eigenvalue = mag / n;
613
614            // `.max(0.0)` clamps the FP noise that can briefly push
615            // `1 - |⟨u,v⟩|` slightly negative near convergence.
616            let change = (1.0 - dot(&u, &v).abs()).max(0.0);
617            v = u;
618
619            if change < tol {
620                break;
621            }
622        }
623
624        vectors.push(v);
625        values.push(eigenvalue);
626    }
627
628    // If some components had zero variance, fill with orthogonal random directions
629    while vectors.len() < k {
630        let mut v: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
631        for prev in &vectors {
632            let proj = dot(&v, prev);
633            for (vj, &pj) in v.iter_mut().zip(prev.iter()) {
634                *vj -= proj * pj;
635            }
636        }
637        normalize_vec(&mut v);
638        vectors.push(v);
639        values.push(0.0);
640    }
641
642    (vectors, values)
643}
644
645// --- Deterministic PRNG (SplitMix64 + Box-Muller) ---
646// pub(crate) for reuse by kernel_pca module.
647
/// Small deterministic pseudo-random generator (SplitMix64) with a
/// Box–Muller transform for normal draws. The whole generator state is a
/// single `u64`, so the same seed always reproduces the same stream.
pub(crate) struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Create a generator whose initial state is `seed`.
    pub(crate) fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advance the state by the SplitMix64 increment and return the
    /// mixed 64-bit output.
    pub(crate) fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(0x9e3779b97f4a7c15);
        let mixed = self.state;
        let mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
        let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94d049bb133111eb);
        mixed ^ (mixed >> 31)
    }

    /// Uniform draw in `[0, 1)` built from the top 53 bits of the next
    /// 64-bit output.
    pub(crate) fn next_f64(&mut self) -> f64 {
        const TWO_POW_53: f64 = (1u64 << 53) as f64;
        let mantissa = self.next_u64() >> 11;
        mantissa as f64 / TWO_POW_53
    }

    /// Standard-normal draw via Box–Muller (cosine branch only). Consumes
    /// two uniform draws per call.
    pub(crate) fn normal(&mut self) -> f64 {
        // Keep the first draw strictly positive so `ln` stays finite.
        let u1 = self.next_f64().max(f64::MIN_POSITIVE);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * PI * u2;
        radius * angle.cos()
    }
}
675
676#[cfg(test)]
677mod tests {
678    use super::*;
679    use sphereql_core::angular_distance;
680    use std::f64::consts::TAU;
681
682    fn emb(vals: &[f64]) -> Embedding {
683        Embedding::new(vals.to_vec())
684    }
685
686    fn corpus_10d() -> Vec<Embedding> {
687        vec![
688            emb(&[1.0, 0.0, 0.0, 0.1, 0.05, -0.02, 0.03, -0.01, 0.04, 0.02]),
689            emb(&[0.0, 1.0, 0.0, -0.05, 0.1, 0.03, -0.02, 0.01, -0.03, 0.04]),
690            emb(&[0.0, 0.0, 1.0, 0.02, -0.03, 0.1, 0.05, 0.02, -0.01, -0.04]),
691            emb(&[1.0, 1.0, 0.0, 0.05, 0.08, 0.01, 0.01, -0.02, 0.02, 0.03]),
692            emb(&[0.0, 1.0, 1.0, -0.02, 0.07, 0.07, 0.01, 0.02, -0.02, 0.01]),
693            emb(&[1.0, 0.0, 1.0, 0.06, 0.01, 0.05, -0.03, -0.01, 0.03, -0.02]),
694            emb(&[-1.0, 0.0, 0.0, -0.08, 0.02, 0.01, 0.02, 0.03, -0.02, 0.01]),
695            emb(&[0.0, -1.0, 0.0, 0.03, -0.09, -0.02, 0.01, -0.01, 0.02, -0.03]),
696        ]
697    }
698
699    fn assert_valid_spherical(sp: &SphericalPoint) {
700        assert!(sp.r >= 0.0, "r must be >= 0, got {}", sp.r);
701        assert!(
702            sp.theta >= 0.0 && sp.theta < TAU,
703            "theta must be in [0, 2π), got {}",
704            sp.theta
705        );
706        assert!(
707            sp.phi >= 0.0 && sp.phi <= PI,
708            "phi must be in [0, π], got {}",
709            sp.phi
710        );
711    }
712
713    // --- PCA tests ---
714
715    #[test]
716    fn pca_produces_valid_spherical_points() {
717        let corpus = corpus_10d();
718        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
719        for e in &corpus {
720            assert_valid_spherical(&pca.project(e));
721        }
722    }
723
724    #[test]
725    fn pca_preserves_angular_ordering() {
726        let corpus = corpus_10d();
727        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
728
729        // a and b are both +x-ish, c is -x: a should be closer to b than to c
730        let a = emb(&[1.0, 0.1, 0.0, 0.05, 0.02, -0.01, 0.01, 0.0, 0.02, 0.01]);
731        let b = emb(&[0.9, 0.2, 0.1, 0.04, 0.03, 0.0, 0.02, -0.01, 0.01, 0.02]);
732        let c = emb(&[-1.0, -0.1, 0.0, -0.04, 0.01, 0.02, 0.01, 0.02, -0.01, 0.01]);
733
734        let pa = pca.project(&a);
735        let pb = pca.project(&b);
736        let pc = pca.project(&c);
737
738        let d_ab = angular_distance(&pa, &pb);
739        let d_ac = angular_distance(&pa, &pc);
740
741        assert!(
742            d_ab < d_ac,
743            "similar items should be closer: d(a,b)={d_ab:.4} should be < d(a,c)={d_ac:.4}"
744        );
745    }
746
747    #[test]
748    fn pca_magnitude_radial() {
749        let corpus = corpus_10d();
750        let pca = PcaProjection::fit(&corpus, RadialStrategy::Magnitude).unwrap();
751
752        let short = emb(&[0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
753        let long = emb(&[10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
754
755        let ps = pca.project(&short);
756        let pl = pca.project(&long);
757
758        assert!(ps.r < pl.r, "longer vector should have larger radius");
759        assert!((ps.r - 0.1).abs() < 1e-10);
760        assert!((pl.r - 10.0).abs() < 1e-10);
761    }
762
763    #[test]
764    fn pca_transform_radial() {
765        let corpus = corpus_10d();
766        let pca = PcaProjection::fit(
767            &corpus,
768            RadialStrategy::MagnitudeTransform(Arc::new(|mag| mag.ln_1p())),
769        )
770        .unwrap();
771
772        let e = emb(&[3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
773        let sp = pca.project(&e);
774        assert!((sp.r - 5.0_f64.ln_1p()).abs() < 1e-10);
775    }
776
777    #[test]
778    fn pca_single_embedding() {
779        let corpus = vec![emb(&[1.0, 0.0, 0.0, 0.0, 0.0])];
780        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
781        let sp = pca.project(&corpus[0]);
782        assert!((sp.r - 1.0).abs() < 1e-12);
783        assert_valid_spherical(&sp);
784    }
785
786    #[test]
787    fn pca_dimensionality() {
788        let corpus = corpus_10d();
789        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
790        assert_eq!(pca.dimensionality(), 10);
791    }
792
793    #[test]
794    fn pca_empty_corpus_returns_err() {
795        assert!(matches!(
796            PcaProjection::fit(&[], RadialStrategy::Fixed(1.0)),
797            Err(ProjectionError::EmptyCorpus)
798        ));
799    }
800
801    #[test]
802    fn pca_too_few_dimensions_returns_err() {
803        assert!(matches!(
804            PcaProjection::fit(&[emb(&[1.0, 2.0])], RadialStrategy::Fixed(1.0)),
805            Err(ProjectionError::DimensionTooLow {
806                got: 2,
807                required: 3
808            })
809        ));
810    }
811
812    #[test]
813    #[should_panic(expected = "expected dimension 10")]
814    fn pca_dimension_mismatch_panics() {
815        let corpus = corpus_10d();
816        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
817        let _ = pca.project(&emb(&[1.0, 2.0, 3.0]));
818    }
819
820    // --- Weighted PCA tests ---
821
/// Giving every sample a weight of 1.0 must reproduce the plain fit: each
/// corpus embedding projects to the same point (angular distance below
/// 1e-9) and the explained-variance ratios agree to within 1e-9.
#[test]
fn fit_weighted_uniform_weights_matches_naive_fit() {
    let corpus = corpus_10d();
    let all_ones: Vec<f64> = vec![1.0; corpus.len()];

    let naive = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
    let uniform =
        PcaProjection::fit_weighted(&corpus, &all_ones, RadialStrategy::Fixed(1.0)).unwrap();

    // The comparison is made on projected points (via angular distance)
    // instead of on the raw principal components.
    for sample in corpus.iter() {
        let gap = angular_distance(&naive.project(sample), &uniform.project(sample));
        assert!(gap < 1e-9, "uniform-weight fit should match naive fit");
    }

    let evr_gap = naive.explained_variance_ratio() - uniform.explained_variance_ratio();
    assert!(evr_gap.abs() < 1e-9);
}
847
/// Builds an imbalanced 8-dimensional corpus: 20 near-identical samples
/// dominated by the first coordinate plus one sample on the second axis.
/// The crowd is weighted by 1/sqrt(20) each and the singleton by 1.0; the
/// weighted fit's explained-variance ratio must exceed 0.5.
#[test]
fn fit_weighted_balances_imbalanced_corpus() {
    // Majority category: first coordinate drifts slightly per sample,
    // second coordinate is a small constant.
    let mut corpus: Vec<Embedding> = (0..20)
        .map(|i| {
            let mut coords = vec![0.0; 8];
            coords[0] = 1.0 + (i as f64) * 0.001;
            coords[1] = 0.01;
            emb(&coords)
        })
        .collect();
    let mut weights: Vec<f64> = vec![1.0 / (20f64).sqrt(); 20];

    // Minority category: a single sample on the y axis at full weight.
    corpus.push(emb(&[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]));
    weights.push(1.0);

    let fitted =
        PcaProjection::fit_weighted(&corpus, &weights, RadialStrategy::Fixed(1.0)).unwrap();

    // The ratio should be meaningful, i.e. well above zero.
    let evr = fitted.explained_variance_ratio();
    assert!(evr > 0.5, "weighted EVR should be > 0.5, got {}", evr);
}
877
/// A weight slice one element shorter than the corpus is rejected with
/// `ProjectionError::SliceLengthMismatch`.
#[test]
fn fit_weighted_rejects_length_mismatch() {
    let corpus = corpus_10d();
    let short_weights = vec![1.0; corpus.len() - 1];

    let outcome =
        PcaProjection::fit_weighted(&corpus, &short_weights, RadialStrategy::Fixed(1.0));

    let rejected = matches!(outcome, Err(ProjectionError::SliceLengthMismatch { .. }));
    assert!(rejected);
}
888
/// With every weight set to zero, `fit_weighted` must still succeed and
/// project each corpus embedding to the same point as the plain fit
/// (angular distance below 1e-9).
#[test]
fn fit_weighted_zero_weights_falls_back_to_unweighted() {
    let corpus = corpus_10d();
    let all_zero = vec![0.0; corpus.len()];

    let zero_weighted =
        PcaProjection::fit_weighted(&corpus, &all_zero, RadialStrategy::Fixed(1.0)).unwrap();
    let unweighted = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();

    for sample in corpus.iter() {
        let gap = angular_distance(&unweighted.project(sample), &zero_weighted.project(sample));
        assert!(gap < 1e-9);
    }
}
902
903    // --- Random projection tests ---
904
/// Projects 20 constant-valued 10-dimensional embeddings (values
/// 0.01, 0.11, ..., 1.91) and checks each result with
/// `assert_valid_spherical`.
#[test]
fn random_produces_valid_spherical_points() {
    let projection = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
    (0..20)
        .map(|step| emb(&[step as f64 * 0.1 + 0.01; 10]))
        .for_each(|sample| assert_valid_spherical(&projection.project(&sample)));
}
913
/// Two random projections built with the same seed (42) must map the same
/// embedding to the same `theta` and `phi`, to within 1e-12.
#[test]
fn random_deterministic_with_same_seed() {
    let build = || RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
    let (first, second) = (build(), build());

    let sample = emb(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
    let (from_first, from_second) = (first.project(&sample), second.project(&sample));

    let theta_gap = (from_first.theta - from_second.theta).abs();
    let phi_gap = (from_first.phi - from_second.phi).abs();
    assert!(theta_gap < 1e-12);
    assert!(phi_gap < 1e-12);
}
924
/// Projections seeded with 42 and 999 must send the same embedding to
/// points more than 1e-6 apart in angular distance.
#[test]
fn random_different_seeds_differ() {
    let seeded_42 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
    let seeded_999 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 999);

    let sample = emb(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
    let from_42 = seeded_42.project(&sample);
    let from_999 = seeded_999.project(&sample);

    let separation = angular_distance(&from_42, &from_999);
    assert!(
        separation > 1e-6,
        "different seeds should produce different projections"
    );
}
936
/// A random projection constructed for 768 input dimensions reports 768
/// from `dimensionality()`.
#[test]
fn random_dimensionality() {
    let projection = RandomProjection::new(768, RadialStrategy::Fixed(1.0), 0);
    let reported = projection.dimensionality();
    assert_eq!(reported, 768);
}
942
/// Constructing a random projection for 2 input dimensions must panic with
/// a message containing "embedding dimension must be >= 3".
#[test]
#[should_panic(expected = "embedding dimension must be >= 3")]
fn random_too_few_dimensions_panics() {
    let below_minimum = 2;
    let _ = RandomProjection::new(below_minimum, RadialStrategy::Fixed(1.0), 0);
}
948
949    // --- Arc delegation ---
950
/// `project` and `dimensionality` can be called through an
/// `Arc<RandomProjection>`: the projected radius is positive and the
/// reported dimensionality is 10.
#[test]
fn arc_projection_delegates() {
    let shared = Arc::new(RandomProjection::new_default(10));

    let all_ones = emb(&[1.0; 10]);
    let point = shared.project(&all_ones);

    assert!(point.r > 0.0);
    assert_eq!(shared.dimensionality(), 10);
}
959
960    // --- SplitMix64 sanity ---
961
/// The first 100 `next_f64` draws from a generator seeded with 42 must be
/// pairwise distinct when compared by bit pattern.
#[test]
fn prng_produces_distinct_values() {
    let mut rng = SplitMix64::new(42);
    let draws: Vec<f64> = (0..100).map(|_| rng.next_f64()).collect();

    // Compare every draw against each draw that follows it.
    for (idx, earlier) in draws.iter().enumerate() {
        for later in &draws[idx + 1..] {
            assert_ne!(earlier.to_bits(), later.to_bits());
        }
    }
}
972
/// Draws 10,000 values from `normal()` (seed 12345) and checks that the
/// sample mean is within 0.05 of 0 and the population variance is within
/// 0.1 of 1.
#[test]
fn prng_normal_distribution_reasonable() {
    let mut rng = SplitMix64::new(12345);
    let draws: Vec<f64> = (0..10_000).map(|_| rng.normal()).collect();
    let count = draws.len() as f64;

    let mean = draws.iter().sum::<f64>() / count;

    // Divides by the sample count, not count - 1.
    let squared_deviation = |x: &f64| (x - mean).powi(2);
    let variance = draws.iter().map(squared_deviation).sum::<f64>() / count;

    assert!(mean.abs() < 0.05, "mean should be near 0, got {mean}");
    assert!(
        (variance - 1.0).abs() < 0.1,
        "variance should be near 1, got {variance}"
    );
}
988}