1use std::f64::consts::PI;
2use std::sync::Arc;
3
4use sphereql_core::{CartesianPoint, SphericalPoint, cartesian_to_spherical};
5
6use crate::types::{Embedding, ProjectedPoint, RadialStrategy};
7
/// Errors reported when a projection cannot be fitted from the given corpus or parameters.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum ProjectionError {
    /// The corpus passed to a `fit` function contained no embeddings.
    #[error("need at least one embedding to fit a projection")]
    EmptyCorpus,

    /// The embeddings have fewer dimensions than the projection needs
    /// (`PcaProjection::fit` requires at least 3).
    #[error("embedding dimension {got} is below the minimum {required} for this projection")]
    DimensionTooLow { got: usize, required: usize },

    /// The embedding at position `index` has dimension `got`, while the first embedding of the
    /// corpus (which fixes the expected dimension) has dimension `expected`.
    #[error("embedding {index} has dimension {got}, expected {expected}")]
    InconsistentDimension {
        index: usize,
        expected: usize,
        got: usize,
    },

    /// The corpus is non-empty but smaller than the projection requires.
    /// NOTE(review): not returned anywhere in this file; presumably used by another projection.
    #[error("need at least {required} embeddings, got {got}")]
    TooFewEmbeddings { got: usize, required: usize },

    /// A kernel bandwidth σ was rejected because it was not positive.
    /// NOTE(review): not returned anywhere in this file; presumably used by a kernel-based
    /// projection elsewhere in the crate.
    #[error("kernel bandwidth σ must be positive, got {got}")]
    InvalidSigma { got: f64 },
}
43
44pub trait Projection: Send + Sync {
50 fn project(&self, embedding: &Embedding) -> SphericalPoint;
51
52 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
54 let position = self.project(embedding);
55 ProjectedPoint::from_position(position, embedding.magnitude())
56 }
57
58 fn dimensionality(&self) -> usize;
59}
60
61impl<P: Projection> Projection for Arc<P> {
62 fn project(&self, embedding: &Embedding) -> SphericalPoint {
63 (**self).project(embedding)
64 }
65 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
66 (**self).project_rich(embedding)
67 }
68 fn dimensionality(&self) -> usize {
69 (**self).dimensionality()
70 }
71}
72
/// Projection onto the three leading principal components of a fitted corpus.
///
/// Built with `PcaProjection::fit`. The embeddings are unit-normalized and mean-centered before
/// the principal axes are computed, so the axes describe directions in the corpus. The
/// embedding magnitude only feeds the radial strategy (unless volumetric mode is enabled).
#[derive(Clone)]
pub struct PcaProjection {
    // The three principal axes (x, y, z), each a unit vector of length `dim`.
    components: [Vec<f64>; 3],
    // Per-dimension mean of the normalized training embeddings; subtracted before projecting.
    mean: Vec<f64>,
    // Input dimensionality the projection was fitted on.
    dim: usize,
    // Maps an embedding's magnitude to the output radius when `volumetric` is false.
    radial: RadialStrategy,
    // When true, the output radius is the length of the projected (x, y, z) vector instead.
    volumetric: bool,
    // Variance captured along each principal axis (covariance eigenvalues).
    eigenvalues: [f64; 3],
    // Total variance of the centered, normalized corpus (trace of the covariance matrix);
    // denominator of `explained_variance_ratio`.
    total_variance: f64,
}
95
96impl PcaProjection {
97 pub fn fit(embeddings: &[Embedding], radial: RadialStrategy) -> Result<Self, ProjectionError> {
106 if embeddings.is_empty() {
107 return Err(ProjectionError::EmptyCorpus);
108 }
109 let dim = embeddings[0].dimension();
110 if dim < 3 {
111 return Err(ProjectionError::DimensionTooLow {
112 got: dim,
113 required: 3,
114 });
115 }
116 for (i, e) in embeddings.iter().enumerate() {
117 if e.dimension() != dim {
118 return Err(ProjectionError::InconsistentDimension {
119 index: i,
120 expected: dim,
121 got: e.dimension(),
122 });
123 }
124 }
125
126 let normalized: Vec<Vec<f64>> = embeddings.iter().map(|e| e.normalized()).collect();
127 let n = normalized.len();
128
129 let mut mean = vec![0.0; dim];
130 for v in &normalized {
131 for (i, &val) in v.iter().enumerate() {
132 mean[i] += val;
133 }
134 }
135 for m in &mut mean {
136 *m /= n as f64;
137 }
138
139 let centered: Vec<Vec<f64>> = normalized
140 .iter()
141 .map(|v| {
142 v.iter()
143 .zip(mean.iter())
144 .map(|(&val, &m)| val - m)
145 .collect()
146 })
147 .collect();
148
149 let (components, eigenvalues) = top_k_eigenvectors(¢ered, 3, dim);
150
151 let total_variance: f64 = centered
153 .iter()
154 .map(|row| row.iter().map(|x| x * x).sum::<f64>())
155 .sum::<f64>()
156 / centered.len() as f64;
157
158 Ok(Self {
159 components: [
160 components[0].clone(),
161 components[1].clone(),
162 components[2].clone(),
163 ],
164 mean,
165 dim,
166 radial,
167 volumetric: false,
168 eigenvalues: [
169 eigenvalues.first().copied().unwrap_or(0.0),
170 eigenvalues.get(1).copied().unwrap_or(0.0),
171 eigenvalues.get(2).copied().unwrap_or(0.0),
172 ],
173 total_variance,
174 })
175 }
176
177 pub fn fit_default(embeddings: &[Embedding]) -> Result<Self, ProjectionError> {
178 Self::fit(embeddings, RadialStrategy::default())
179 }
180
181 pub fn with_volumetric(mut self, enabled: bool) -> Self {
185 self.volumetric = enabled;
186 self
187 }
188
189 pub fn explained_variance_ratio(&self) -> f64 {
192 if self.total_variance < f64::EPSILON {
193 return 1.0;
194 }
195 let explained: f64 = self.eigenvalues.iter().sum();
196 (explained / self.total_variance).clamp(0.0, 1.0)
197 }
198
199 fn project_xyz_residual(&self, embedding: &Embedding) -> (f64, f64, f64, f64) {
211 let values = &embedding.values;
212 debug_assert_eq!(values.len(), self.dim);
213
214 let mag = embedding.magnitude();
215 let inv_mag = if mag < f64::EPSILON { 0.0 } else { 1.0 / mag };
216
217 let mut x = 0.0f64;
218 let mut y = 0.0f64;
219 let mut z = 0.0f64;
220 let mut total_sq = 0.0f64;
221 let c0 = &self.components[0];
222 let c1 = &self.components[1];
223 let c2 = &self.components[2];
224 for i in 0..self.dim {
225 let n = values[i] * inv_mag;
226 let c = n - self.mean[i];
227 x += c * c0[i];
228 y += c * c1[i];
229 z += c * c2[i];
230 total_sq += c * c;
231 }
232 let projected_sq = x * x + y * y + z * z;
233 let residual_sq = (total_sq - projected_sq).max(0.0);
234 (x, y, z, residual_sq)
235 }
236}
237
238impl Projection for PcaProjection {
239 fn project(&self, embedding: &Embedding) -> SphericalPoint {
240 assert_eq!(
241 embedding.dimension(),
242 self.dim,
243 "expected dimension {}, got {}",
244 self.dim,
245 embedding.dimension()
246 );
247
248 let (x, y, z, _) = self.project_xyz_residual(embedding);
249
250 if self.volumetric {
251 let sp = cartesian_to_spherical(&CartesianPoint::new(x, y, z));
252 if sp.r < f64::EPSILON {
253 return SphericalPoint::new_unchecked(0.0, 0.0, 0.0);
254 }
255 SphericalPoint::new_unchecked(sp.r, sp.theta, sp.phi)
256 } else {
257 let r = self.radial.compute(embedding.magnitude());
258 project_xyz_to_spherical(x, y, z, r)
259 }
260 }
261
262 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
263 assert_eq!(
264 embedding.dimension(),
265 self.dim,
266 "expected dimension {}, got {}",
267 self.dim,
268 embedding.dimension()
269 );
270
271 let intensity = embedding.magnitude();
272 let (x, y, z, residual_sq) = self.project_xyz_residual(embedding);
273 let projection_magnitude = (x * x + y * y + z * z).sqrt();
274
275 let inv_mag = if intensity < f64::EPSILON {
279 0.0
280 } else {
281 1.0 / intensity
282 };
283 let total_sq: f64 = (0..self.dim)
284 .map(|i| {
285 let c = embedding.values[i] * inv_mag - self.mean[i];
286 c * c
287 })
288 .sum();
289 let certainty = if total_sq < f64::EPSILON {
290 0.0
291 } else {
292 (1.0 - residual_sq / total_sq).clamp(0.0, 1.0)
293 };
294
295 let position = if self.volumetric {
296 let sp = cartesian_to_spherical(&CartesianPoint::new(x, y, z));
297 if sp.r < f64::EPSILON {
298 SphericalPoint::new_unchecked(0.0, 0.0, 0.0)
299 } else {
300 SphericalPoint::new_unchecked(sp.r, sp.theta, sp.phi)
301 }
302 } else {
303 let r = self.radial.compute(intensity);
304 project_xyz_to_spherical(x, y, z, r)
305 };
306
307 ProjectedPoint::new(position, certainty, intensity, projection_magnitude)
308 }
309
310 fn dimensionality(&self) -> usize {
311 self.dim
312 }
313}
314
/// Data-independent projection through a fixed random Gaussian matrix.
///
/// The matrix is drawn from a seeded `SplitMix64` generator, so the same `(dim, seed)` pair
/// always produces the same projection. No corpus fitting is needed.
#[derive(Clone)]
pub struct RandomProjection {
    // Three rows of `dim` standard-normal samples; row i yields the i-th Cartesian coordinate.
    matrix: [Vec<f64>; 3],
    // Input dimensionality accepted by `project`.
    dim: usize,
    // Maps an embedding's magnitude to the output radius.
    radial: RadialStrategy,
}
329
330impl RandomProjection {
331 pub fn new(dim: usize, radial: RadialStrategy, seed: u64) -> Self {
332 assert!(dim >= 3, "embedding dimension must be >= 3");
333 let mut rng = SplitMix64::new(seed);
334 let matrix = std::array::from_fn(|_| (0..dim).map(|_| rng.normal()).collect());
335 Self {
336 matrix,
337 dim,
338 radial,
339 }
340 }
341
342 pub fn new_default(dim: usize) -> Self {
343 Self::new(dim, RadialStrategy::default(), 42)
344 }
345}
346
347impl Projection for RandomProjection {
348 fn project(&self, embedding: &Embedding) -> SphericalPoint {
349 assert_eq!(
350 embedding.dimension(),
351 self.dim,
352 "expected dimension {}, got {}",
353 self.dim,
354 embedding.dimension()
355 );
356
357 let magnitude = embedding.magnitude();
358 let r = self.radial.compute(magnitude);
359 let normalized = embedding.normalized();
360
361 let x = dot(&normalized, &self.matrix[0]);
362 let y = dot(&normalized, &self.matrix[1]);
363 let z = dot(&normalized, &self.matrix[2]);
364
365 project_xyz_to_spherical(x, y, z, r)
366 }
367
368 fn dimensionality(&self) -> usize {
369 self.dim
370 }
371}
372
373pub(crate) fn project_xyz_to_spherical(x: f64, y: f64, z: f64, r: f64) -> SphericalPoint {
376 let cart = CartesianPoint::new(x, y, z).normalize();
377 if cart.magnitude() < f64::EPSILON {
378 return SphericalPoint::new_unchecked(r, 0.0, 0.0);
379 }
380 let sp = cartesian_to_spherical(&cart);
381 SphericalPoint::new_unchecked(r, sp.theta, sp.phi)
382}
383
/// Dot product of `a` and `b`.
///
/// Pairs elements with `zip`, so if the slices differ in length the extra trailing elements of
/// the longer one are ignored.
pub(crate) fn dot(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b).map(|(x, y)| x * y);
    products.sum()
}
389
/// Scales `v` in place to unit Euclidean length and returns its original length.
///
/// Vectors whose length does not exceed `f64::EPSILON` are left untouched, which avoids
/// dividing by (nearly) zero. Their length is still returned, so callers can detect the
/// degenerate case.
pub(crate) fn normalize_vec(v: &mut [f64]) -> f64 {
    let norm = v.iter().map(|c| c * c).sum::<f64>().sqrt();
    if norm > f64::EPSILON {
        v.iter_mut().for_each(|c| *c /= norm);
    }
    norm
}
399
400fn top_k_eigenvectors(data: &[Vec<f64>], k: usize, dim: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
407 let max_iters = 200;
408 let tol = 1e-10;
409 let mut vectors: Vec<Vec<f64>> = Vec::with_capacity(k);
410 let mut values: Vec<f64> = Vec::with_capacity(k);
411 let mut rng = SplitMix64::new(0xDEAD_BEEF);
412 let n = data.len() as f64;
413
414 for _ in 0..k {
415 let mut v: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
416 normalize_vec(&mut v);
417 let mut eigenvalue = 0.0;
418
419 for _ in 0..max_iters {
420 let w: Vec<f64> = data.iter().map(|row| dot(row, &v)).collect();
422
423 let mut u = vec![0.0; dim];
425 for (row, &wi) in data.iter().zip(w.iter()) {
426 for (uj, &rj) in u.iter_mut().zip(row.iter()) {
427 *uj += wi * rj;
428 }
429 }
430
431 for prev in &vectors {
433 let proj = dot(&u, prev);
434 for (uj, &pj) in u.iter_mut().zip(prev.iter()) {
435 *uj -= proj * pj;
436 }
437 }
438
439 let mag = normalize_vec(&mut u);
440 if mag < f64::EPSILON {
441 break;
442 }
443
444 eigenvalue = mag / n;
446
447 let change = (1.0 - dot(&u, &v).abs()).max(0.0);
450 v = u;
451
452 if change < tol {
453 break;
454 }
455 }
456
457 vectors.push(v);
458 values.push(eigenvalue);
459 }
460
461 while vectors.len() < k {
463 let mut v: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
464 for prev in &vectors {
465 let proj = dot(&v, prev);
466 for (vj, &pj) in v.iter_mut().zip(prev.iter()) {
467 *vj -= proj * pj;
468 }
469 }
470 normalize_vec(&mut v);
471 vectors.push(v);
472 values.push(0.0);
473 }
474
475 (vectors, values)
476}
477
/// Minimal SplitMix64 pseudo-random generator.
///
/// It provides reproducible random numbers (projection matrices, power-iteration start
/// vectors) without an external RNG dependency. It is not cryptographically secure.
pub(crate) struct SplitMix64 {
    // Weyl-sequence counter, advanced by `GAMMA` on every draw.
    state: u64,
}

impl SplitMix64 {
    /// Per-draw increment of the internal counter (the standard SplitMix64 constant).
    const GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;
    /// Multipliers of the SplitMix64 output mixing function.
    const MIX1: u64 = 0xbf58_476d_1ce4_e5b9;
    const MIX2: u64 = 0x94d0_49bb_1331_11eb;
    /// 2^-53. Multiplying a 53-bit integer by it maps into `[0, 1)`, exactly like dividing by
    /// 2^53, because scaling by a power of two is exact.
    const UNIT_SCALE: f64 = 1.0 / (1u64 << 53) as f64;

    /// Creates a generator whose whole output sequence is determined by `seed`.
    pub(crate) fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Returns the next 64 pseudo-random bits.
    pub(crate) fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GAMMA);
        let mixed = (self.state ^ (self.state >> 30)).wrapping_mul(Self::MIX1);
        let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(Self::MIX2);
        mixed ^ (mixed >> 31)
    }

    /// Returns a uniform sample in `[0, 1)` built from the top 53 bits of `next_u64`.
    pub(crate) fn next_f64(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        top_bits as f64 * Self::UNIT_SCALE
    }

    /// Returns a standard-normal sample via the Box–Muller transform (cosine branch).
    ///
    /// Consumes two uniform draws. The first is clamped to `f64::MIN_POSITIVE` so `ln` stays
    /// finite.
    pub(crate) fn normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(f64::MIN_POSITIVE);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * PI * u2;
        radius * angle.cos()
    }
}
508
#[cfg(test)]
mod tests {
    use super::*;
    use sphereql_core::angular_distance;
    use std::f64::consts::TAU;

    /// Builds an `Embedding` from a slice literal.
    fn emb(vals: &[f64]) -> Embedding {
        Embedding::new(vals.to_vec())
    }

    /// Eight 10-D embeddings dominated by the first three axes (including two opposite
    /// directions), with low-amplitude noise in the remaining seven dimensions.
    fn corpus_10d() -> Vec<Embedding> {
        vec![
            emb(&[1.0, 0.0, 0.0, 0.1, 0.05, -0.02, 0.03, -0.01, 0.04, 0.02]),
            emb(&[0.0, 1.0, 0.0, -0.05, 0.1, 0.03, -0.02, 0.01, -0.03, 0.04]),
            emb(&[0.0, 0.0, 1.0, 0.02, -0.03, 0.1, 0.05, 0.02, -0.01, -0.04]),
            emb(&[1.0, 1.0, 0.0, 0.05, 0.08, 0.01, 0.01, -0.02, 0.02, 0.03]),
            emb(&[0.0, 1.0, 1.0, -0.02, 0.07, 0.07, 0.01, 0.02, -0.02, 0.01]),
            emb(&[1.0, 0.0, 1.0, 0.06, 0.01, 0.05, -0.03, -0.01, 0.03, -0.02]),
            emb(&[-1.0, 0.0, 0.0, -0.08, 0.02, 0.01, 0.02, 0.03, -0.02, 0.01]),
            emb(&[0.0, -1.0, 0.0, 0.03, -0.09, -0.02, 0.01, -0.01, 0.02, -0.03]),
        ]
    }

    /// Asserts `sp` is in canonical range: r >= 0, theta in [0, 2π), phi in [0, π].
    fn assert_valid_spherical(sp: &SphericalPoint) {
        assert!(sp.r >= 0.0, "r must be >= 0, got {}", sp.r);
        assert!(
            sp.theta >= 0.0 && sp.theta < TAU,
            "theta must be in [0, 2π), got {}",
            sp.theta
        );
        assert!(
            sp.phi >= 0.0 && sp.phi <= PI,
            "phi must be in [0, π], got {}",
            sp.phi
        );
    }

    /// Every training embedding projects to a point in canonical spherical range.
    #[test]
    fn pca_produces_valid_spherical_points() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        for e in &corpus {
            assert_valid_spherical(&pca.project(e));
        }
    }

    /// Two nearly parallel embeddings end up angularly closer than an opposite one.
    #[test]
    fn pca_preserves_angular_ordering() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();

        // `a` and `b` point roughly along +axis0; `c` points the opposite way.
        let a = emb(&[1.0, 0.1, 0.0, 0.05, 0.02, -0.01, 0.01, 0.0, 0.02, 0.01]);
        let b = emb(&[0.9, 0.2, 0.1, 0.04, 0.03, 0.0, 0.02, -0.01, 0.01, 0.02]);
        let c = emb(&[-1.0, -0.1, 0.0, -0.04, 0.01, 0.02, 0.01, 0.02, -0.01, 0.01]);

        let pa = pca.project(&a);
        let pb = pca.project(&b);
        let pc = pca.project(&c);

        let d_ab = angular_distance(&pa, &pb);
        let d_ac = angular_distance(&pa, &pc);

        assert!(
            d_ab < d_ac,
            "similar items should be closer: d(a,b)={d_ab:.4} should be < d(a,c)={d_ac:.4}"
        );
    }

    /// With `RadialStrategy::Magnitude` the radius equals the embedding's magnitude.
    #[test]
    fn pca_magnitude_radial() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Magnitude).unwrap();

        let short = emb(&[0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let long = emb(&[10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let ps = pca.project(&short);
        let pl = pca.project(&long);

        assert!(ps.r < pl.r, "longer vector should have larger radius");
        assert!((ps.r - 0.1).abs() < 1e-10);
        assert!((pl.r - 10.0).abs() < 1e-10);
    }

    /// A custom magnitude transform is applied to the magnitude (|(3, 4, 0, …)| = 5).
    #[test]
    fn pca_transform_radial() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(
            &corpus,
            RadialStrategy::MagnitudeTransform(Arc::new(|mag| mag.ln_1p())),
        )
        .unwrap();

        let e = emb(&[3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let sp = pca.project(&e);
        assert!((sp.r - 5.0_f64.ln_1p()).abs() < 1e-10);
    }

    /// A one-embedding corpus (zero variance) still fits and projects to a valid point.
    #[test]
    fn pca_single_embedding() {
        let corpus = vec![emb(&[1.0, 0.0, 0.0, 0.0, 0.0])];
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let sp = pca.project(&corpus[0]);
        assert!((sp.r - 1.0).abs() < 1e-12);
        assert_valid_spherical(&sp);
    }

    /// `dimensionality` reports the fitted input dimension.
    #[test]
    fn pca_dimensionality() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        assert_eq!(pca.dimensionality(), 10);
    }

    /// Fitting an empty corpus is an error, not a panic.
    #[test]
    fn pca_empty_corpus_returns_err() {
        assert!(matches!(
            PcaProjection::fit(&[], RadialStrategy::Fixed(1.0)),
            Err(ProjectionError::EmptyCorpus)
        ));
    }

    /// Fewer than 3 input dimensions is rejected with `DimensionTooLow`.
    #[test]
    fn pca_too_few_dimensions_returns_err() {
        assert!(matches!(
            PcaProjection::fit(&[emb(&[1.0, 2.0])], RadialStrategy::Fixed(1.0)),
            Err(ProjectionError::DimensionTooLow {
                got: 2,
                required: 3
            })
        ));
    }

    /// Projecting an embedding of the wrong dimension panics with a descriptive message.
    #[test]
    #[should_panic(expected = "expected dimension 10")]
    fn pca_dimension_mismatch_panics() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let _ = pca.project(&emb(&[1.0, 2.0, 3.0]));
    }

    /// Random projections also yield points in canonical spherical range.
    #[test]
    fn random_produces_valid_spherical_points() {
        let rp = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        for i in 0..20 {
            let e = emb(&[i as f64 * 0.1 + 0.01; 10]);
            assert_valid_spherical(&rp.project(&e));
        }
    }

    /// The same seed reproduces the same projection.
    #[test]
    fn random_deterministic_with_same_seed() {
        let rp1 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        let rp2 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        let e = emb(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let sp1 = rp1.project(&e);
        let sp2 = rp2.project(&e);
        assert!((sp1.theta - sp2.theta).abs() < 1e-12);
        assert!((sp1.phi - sp2.phi).abs() < 1e-12);
    }

    /// Different seeds give different projections of the same embedding.
    #[test]
    fn random_different_seeds_differ() {
        let rp1 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        let rp2 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 999);
        let e = emb(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let d = angular_distance(&rp1.project(&e), &rp2.project(&e));
        assert!(
            d > 1e-6,
            "different seeds should produce different projections"
        );
    }

    /// `dimensionality` reports the constructed input dimension.
    #[test]
    fn random_dimensionality() {
        let rp = RandomProjection::new(768, RadialStrategy::Fixed(1.0), 0);
        assert_eq!(rp.dimensionality(), 768);
    }

    /// Constructing a random projection with fewer than 3 dimensions panics.
    #[test]
    #[should_panic(expected = "embedding dimension must be >= 3")]
    fn random_too_few_dimensions_panics() {
        RandomProjection::new(2, RadialStrategy::Fixed(1.0), 0);
    }

    /// The `Arc<P>` blanket impl forwards to the inner projection.
    #[test]
    fn arc_projection_delegates() {
        let rp = Arc::new(RandomProjection::new_default(10));
        let e = emb(&[1.0; 10]);
        let sp = rp.project(&e);
        assert!(sp.r > 0.0);
        assert_eq!(rp.dimensionality(), 10);
    }

    /// 100 consecutive uniform draws are pairwise distinct (bit-for-bit).
    #[test]
    fn prng_produces_distinct_values() {
        let mut rng = SplitMix64::new(42);
        let vals: Vec<f64> = (0..100).map(|_| rng.next_f64()).collect();
        for i in 0..vals.len() {
            for j in (i + 1)..vals.len() {
                assert_ne!(vals[i].to_bits(), vals[j].to_bits());
            }
        }
    }

    /// Sample mean and variance of 10 000 normal draws are close to 0 and 1.
    #[test]
    fn prng_normal_distribution_reasonable() {
        let mut rng = SplitMix64::new(12345);
        let samples: Vec<f64> = (0..10_000).map(|_| rng.normal()).collect();

        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        let variance =
            samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;

        assert!(mean.abs() < 0.05, "mean should be near 0, got {mean}");
        assert!(
            (variance - 1.0).abs() < 0.1,
            "variance should be near 1, got {variance}"
        );
    }
}