1use std::f64::consts::PI;
2use std::sync::Arc;
3
4use sphereql_core::{CartesianPoint, SphericalPoint, cartesian_to_spherical};
5
6use crate::types::{Embedding, ProjectedPoint, RadialStrategy};
7
8pub trait Projection: Send + Sync {
14 fn project(&self, embedding: &Embedding) -> SphericalPoint;
15
16 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
18 let position = self.project(embedding);
19 ProjectedPoint::from_position(position, embedding.magnitude())
20 }
21
22 fn dimensionality(&self) -> usize;
23}
24
25impl<P: Projection> Projection for Arc<P> {
26 fn project(&self, embedding: &Embedding) -> SphericalPoint {
27 (**self).project(embedding)
28 }
29 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
30 (**self).project_rich(embedding)
31 }
32 fn dimensionality(&self) -> usize {
33 (**self).dimensionality()
34 }
35}
36
/// Projects embeddings onto the three leading principal components of a
/// fitted corpus and expresses the result in spherical coordinates.
///
/// Construct it with [`PcaProjection::fit`] or `PcaProjection::fit_default`.
#[derive(Clone)]
pub struct PcaProjection {
    /// The three principal axes (each of length `dim`), estimated by power
    /// iteration on the centered, normalized corpus.
    components: [Vec<f64>; 3],
    /// Mean of the normalized corpus embeddings; subtracted before projecting.
    mean: Vec<f64>,
    /// Embedding dimension the projection was fitted on.
    dim: usize,
    /// Maps an embedding's magnitude to the output radius (non-volumetric mode).
    radial: RadialStrategy,
    /// When `true`, the output radius comes from the PCA coordinates rather
    /// than from `radial`.
    volumetric: bool,
    /// Variance estimates along each of the three components.
    eigenvalues: [f64; 3],
    /// Mean squared norm of the centered corpus vectors (the covariance trace);
    /// denominator of `explained_variance_ratio`.
    total_variance: f64,
}
59
60impl PcaProjection {
61 pub fn fit(embeddings: &[Embedding], radial: RadialStrategy) -> Self {
62 assert!(
63 !embeddings.is_empty(),
64 "need at least one embedding to fit PCA"
65 );
66 let dim = embeddings[0].dimension();
67 assert!(dim >= 3, "embedding dimension must be >= 3");
68 for (i, e) in embeddings.iter().enumerate() {
69 assert_eq!(
70 e.dimension(),
71 dim,
72 "embedding {i} has dimension {}, expected {dim}",
73 e.dimension()
74 );
75 }
76
77 let normalized: Vec<Vec<f64>> = embeddings.iter().map(|e| e.normalized()).collect();
78 let n = normalized.len();
79
80 let mut mean = vec![0.0; dim];
81 for v in &normalized {
82 for (i, &val) in v.iter().enumerate() {
83 mean[i] += val;
84 }
85 }
86 for m in &mut mean {
87 *m /= n as f64;
88 }
89
90 let centered: Vec<Vec<f64>> = normalized
91 .iter()
92 .map(|v| {
93 v.iter()
94 .zip(mean.iter())
95 .map(|(&val, &m)| val - m)
96 .collect()
97 })
98 .collect();
99
100 let (components, eigenvalues) = top_k_eigenvectors(¢ered, 3, dim);
101
102 let total_variance: f64 = centered
104 .iter()
105 .map(|row| row.iter().map(|x| x * x).sum::<f64>())
106 .sum::<f64>()
107 / centered.len() as f64;
108
109 Self {
110 components: [
111 components[0].clone(),
112 components[1].clone(),
113 components[2].clone(),
114 ],
115 mean,
116 dim,
117 radial,
118 volumetric: false,
119 eigenvalues: [
120 eigenvalues.first().copied().unwrap_or(0.0),
121 eigenvalues.get(1).copied().unwrap_or(0.0),
122 eigenvalues.get(2).copied().unwrap_or(0.0),
123 ],
124 total_variance,
125 }
126 }
127
128 pub fn fit_default(embeddings: &[Embedding]) -> Self {
129 Self::fit(embeddings, RadialStrategy::default())
130 }
131
132 pub fn with_volumetric(mut self, enabled: bool) -> Self {
136 self.volumetric = enabled;
137 self
138 }
139
140 pub fn explained_variance_ratio(&self) -> f64 {
143 if self.total_variance < f64::EPSILON {
144 return 1.0;
145 }
146 let explained: f64 = self.eigenvalues.iter().sum();
147 (explained / self.total_variance).clamp(0.0, 1.0)
148 }
149
150 fn project_centered(&self, centered: &[f64]) -> (f64, f64, f64, f64) {
152 let x = dot(centered, &self.components[0]);
153 let y = dot(centered, &self.components[1]);
154 let z = dot(centered, &self.components[2]);
155
156 let projected_sq = x * x + y * y + z * z;
157 let total_sq: f64 = centered.iter().map(|v| v * v).sum();
158 let residual_sq = (total_sq - projected_sq).max(0.0);
159
160 (x, y, z, residual_sq)
161 }
162}
163
164impl Projection for PcaProjection {
165 fn project(&self, embedding: &Embedding) -> SphericalPoint {
166 assert_eq!(
167 embedding.dimension(),
168 self.dim,
169 "expected dimension {}, got {}",
170 self.dim,
171 embedding.dimension()
172 );
173
174 let normalized = embedding.normalized();
175
176 let centered: Vec<f64> = normalized
177 .iter()
178 .zip(self.mean.iter())
179 .map(|(&v, &m)| v - m)
180 .collect();
181
182 let (x, y, z, _) = self.project_centered(¢ered);
183
184 if self.volumetric {
185 let sp = cartesian_to_spherical(&CartesianPoint::new(x, y, z));
186 if sp.r < f64::EPSILON {
187 return SphericalPoint::new_unchecked(0.0, 0.0, 0.0);
188 }
189 SphericalPoint::new_unchecked(sp.r, sp.theta, sp.phi)
190 } else {
191 let r = self.radial.compute(embedding.magnitude());
192 project_xyz_to_spherical(x, y, z, r)
193 }
194 }
195
196 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
197 assert_eq!(
198 embedding.dimension(),
199 self.dim,
200 "expected dimension {}, got {}",
201 self.dim,
202 embedding.dimension()
203 );
204
205 let intensity = embedding.magnitude();
206 let normalized = embedding.normalized();
207
208 let centered: Vec<f64> = normalized
209 .iter()
210 .zip(self.mean.iter())
211 .map(|(&v, &m)| v - m)
212 .collect();
213
214 let (x, y, z, residual_sq) = self.project_centered(¢ered);
215 let projection_magnitude = (x * x + y * y + z * z).sqrt();
216
217 let total_sq: f64 = centered.iter().map(|v| v * v).sum();
219 let certainty = if total_sq < f64::EPSILON {
220 0.0
221 } else {
222 (1.0 - residual_sq / total_sq).clamp(0.0, 1.0)
223 };
224
225 let position = if self.volumetric {
226 let sp = cartesian_to_spherical(&CartesianPoint::new(x, y, z));
227 if sp.r < f64::EPSILON {
228 SphericalPoint::new_unchecked(0.0, 0.0, 0.0)
229 } else {
230 SphericalPoint::new_unchecked(sp.r, sp.theta, sp.phi)
231 }
232 } else {
233 let r = self.radial.compute(intensity);
234 project_xyz_to_spherical(x, y, z, r)
235 };
236
237 ProjectedPoint::new(position, certainty, intensity, projection_magnitude)
238 }
239
240 fn dimensionality(&self) -> usize {
241 self.dim
242 }
243}
244
/// Projects embeddings onto three fixed, seeded random Gaussian directions.
///
/// The direction of the output point comes from dotting the normalized
/// embedding with each matrix row. The radius comes from the [`RadialStrategy`].
/// No centering is applied.
#[derive(Clone)]
pub struct RandomProjection {
    /// Three rows of `dim` standard-normal samples drawn from a seeded
    /// `SplitMix64` generator.
    matrix: [Vec<f64>; 3],
    /// Embedding dimension this projection accepts.
    dim: usize,
    /// Maps an embedding's magnitude to the output radius.
    radial: RadialStrategy,
}
259
260impl RandomProjection {
261 pub fn new(dim: usize, radial: RadialStrategy, seed: u64) -> Self {
262 assert!(dim >= 3, "embedding dimension must be >= 3");
263 let mut rng = SplitMix64::new(seed);
264 let matrix = std::array::from_fn(|_| (0..dim).map(|_| rng.normal()).collect());
265 Self {
266 matrix,
267 dim,
268 radial,
269 }
270 }
271
272 pub fn new_default(dim: usize) -> Self {
273 Self::new(dim, RadialStrategy::default(), 42)
274 }
275}
276
277impl Projection for RandomProjection {
278 fn project(&self, embedding: &Embedding) -> SphericalPoint {
279 assert_eq!(
280 embedding.dimension(),
281 self.dim,
282 "expected dimension {}, got {}",
283 self.dim,
284 embedding.dimension()
285 );
286
287 let magnitude = embedding.magnitude();
288 let r = self.radial.compute(magnitude);
289 let normalized = embedding.normalized();
290
291 let x = dot(&normalized, &self.matrix[0]);
292 let y = dot(&normalized, &self.matrix[1]);
293 let z = dot(&normalized, &self.matrix[2]);
294
295 project_xyz_to_spherical(x, y, z, r)
296 }
297
298 fn dimensionality(&self) -> usize {
299 self.dim
300 }
301}
302
303pub(crate) fn project_xyz_to_spherical(x: f64, y: f64, z: f64, r: f64) -> SphericalPoint {
306 let cart = CartesianPoint::new(x, y, z).normalize();
307 if cart.magnitude() < f64::EPSILON {
308 return SphericalPoint::new_unchecked(r, 0.0, 0.0);
309 }
310 let sp = cartesian_to_spherical(&cart);
311 SphericalPoint::new_unchecked(r, sp.theta, sp.phi)
312}
313
/// Inner product of two slices.
///
/// Pairs are taken with `zip`, so if the lengths differ the extra tail of the
/// longer slice is ignored.
pub(crate) fn dot(a: &[f64], b: &[f64]) -> f64 {
    std::iter::zip(a, b).map(|(x, y)| x * y).sum()
}
319
/// Scales `v` in place to unit Euclidean length and returns its original length.
///
/// Vectors whose length is not above `f64::EPSILON` are left untouched, so a
/// zero vector is never divided by zero.
pub(crate) fn normalize_vec(v: &mut [f64]) -> f64 {
    let norm = v.iter().map(|&c| c * c).sum::<f64>().sqrt();
    if norm > f64::EPSILON {
        v.iter_mut().for_each(|c| *c /= norm);
    }
    norm
}
329
330fn top_k_eigenvectors(data: &[Vec<f64>], k: usize, dim: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
337 let max_iters = 200;
338 let tol = 1e-10;
339 let mut vectors: Vec<Vec<f64>> = Vec::with_capacity(k);
340 let mut values: Vec<f64> = Vec::with_capacity(k);
341 let mut rng = SplitMix64::new(0xDEAD_BEEF);
342 let n = data.len() as f64;
343
344 for _ in 0..k {
345 let mut v: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
346 normalize_vec(&mut v);
347 let mut eigenvalue = 0.0;
348
349 for _ in 0..max_iters {
350 let w: Vec<f64> = data.iter().map(|row| dot(row, &v)).collect();
352
353 let mut u = vec![0.0; dim];
355 for (row, &wi) in data.iter().zip(w.iter()) {
356 for (uj, &rj) in u.iter_mut().zip(row.iter()) {
357 *uj += wi * rj;
358 }
359 }
360
361 for prev in &vectors {
363 let proj = dot(&u, prev);
364 for (uj, &pj) in u.iter_mut().zip(prev.iter()) {
365 *uj -= proj * pj;
366 }
367 }
368
369 let mag = normalize_vec(&mut u);
370 if mag < f64::EPSILON {
371 break;
372 }
373
374 eigenvalue = mag / n;
376
377 let change = 1.0 - dot(&u, &v).abs();
378 v = u;
379
380 if change < tol {
381 break;
382 }
383 }
384
385 vectors.push(v);
386 values.push(eigenvalue);
387 }
388
389 while vectors.len() < k {
391 let mut v: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
392 for prev in &vectors {
393 let proj = dot(&v, prev);
394 for (vj, &pj) in v.iter_mut().zip(prev.iter()) {
395 *vj -= proj * pj;
396 }
397 }
398 normalize_vec(&mut v);
399 vectors.push(v);
400 values.push(0.0);
401 }
402
403 (vectors, values)
404}
405
/// Minimal SplitMix64 pseudo-random number generator.
///
/// This module uses it so the random projection matrix and the
/// power-iteration starting vectors are reproducible from a `u64` seed.
pub(crate) struct SplitMix64 {
    /// Current counter; each `next_u64` call adds a fixed increment and then
    /// mixes the result.
    state: u64,
}
412
413impl SplitMix64 {
414 pub(crate) fn new(seed: u64) -> Self {
415 Self { state: seed }
416 }
417
418 pub(crate) fn next_u64(&mut self) -> u64 {
419 self.state = self.state.wrapping_add(0x9e3779b97f4a7c15);
420 let mut z = self.state;
421 z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
422 z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
423 z ^ (z >> 31)
424 }
425
426 pub(crate) fn next_f64(&mut self) -> f64 {
427 (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
428 }
429
430 pub(crate) fn normal(&mut self) -> f64 {
431 let u1 = self.next_f64().max(f64::MIN_POSITIVE);
432 let u2 = self.next_f64();
433 (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use sphereql_core::angular_distance;
    use std::f64::consts::TAU;

    /// Shorthand for building an `Embedding` from a slice.
    fn emb(vals: &[f64]) -> Embedding {
        Embedding::new(vals.to_vec())
    }

    /// Small 10-dimensional corpus. Each vector's ±1 entries sit in the first
    /// three coordinates; the remaining seven hold small noise.
    fn corpus_10d() -> Vec<Embedding> {
        vec![
            emb(&[1.0, 0.0, 0.0, 0.1, 0.05, -0.02, 0.03, -0.01, 0.04, 0.02]),
            emb(&[0.0, 1.0, 0.0, -0.05, 0.1, 0.03, -0.02, 0.01, -0.03, 0.04]),
            emb(&[0.0, 0.0, 1.0, 0.02, -0.03, 0.1, 0.05, 0.02, -0.01, -0.04]),
            emb(&[1.0, 1.0, 0.0, 0.05, 0.08, 0.01, 0.01, -0.02, 0.02, 0.03]),
            emb(&[0.0, 1.0, 1.0, -0.02, 0.07, 0.07, 0.01, 0.02, -0.02, 0.01]),
            emb(&[1.0, 0.0, 1.0, 0.06, 0.01, 0.05, -0.03, -0.01, 0.03, -0.02]),
            emb(&[-1.0, 0.0, 0.0, -0.08, 0.02, 0.01, 0.02, 0.03, -0.02, 0.01]),
            emb(&[0.0, -1.0, 0.0, 0.03, -0.09, -0.02, 0.01, -0.01, 0.02, -0.03]),
        ]
    }

    /// Asserts `sp` has a non-negative radius, theta in [0, 2π) and phi in [0, π].
    fn assert_valid_spherical(sp: &SphericalPoint) {
        assert!(sp.r >= 0.0, "r must be >= 0, got {}", sp.r);
        assert!(
            sp.theta >= 0.0 && sp.theta < TAU,
            "theta must be in [0, 2π), got {}",
            sp.theta
        );
        assert!(
            sp.phi >= 0.0 && sp.phi <= PI,
            "phi must be in [0, π], got {}",
            sp.phi
        );
    }

    // Every corpus member must map to well-formed spherical coordinates.
    #[test]
    fn pca_produces_valid_spherical_points() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
        for e in &corpus {
            assert_valid_spherical(&pca.project(e));
        }
    }

    // Two nearby inputs must land closer together on the sphere than an
    // input pointing the opposite way.
    #[test]
    fn pca_preserves_angular_ordering() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));

        let a = emb(&[1.0, 0.1, 0.0, 0.05, 0.02, -0.01, 0.01, 0.0, 0.02, 0.01]);
        // `b` is a small perturbation of `a`; `c` points roughly opposite.
        let b = emb(&[0.9, 0.2, 0.1, 0.04, 0.03, 0.0, 0.02, -0.01, 0.01, 0.02]);
        let c = emb(&[-1.0, -0.1, 0.0, -0.04, 0.01, 0.02, 0.01, 0.02, -0.01, 0.01]);

        let pa = pca.project(&a);
        let pb = pca.project(&b);
        let pc = pca.project(&c);

        let d_ab = angular_distance(&pa, &pb);
        let d_ac = angular_distance(&pa, &pc);

        assert!(
            d_ab < d_ac,
            "similar items should be closer: d(a,b)={d_ab:.4} should be < d(a,c)={d_ac:.4}"
        );
    }

    // With `RadialStrategy::Magnitude` the radius must equal the input's
    // Euclidean length.
    #[test]
    fn pca_magnitude_radial() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Magnitude);

        let short = emb(&[0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let long = emb(&[10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let ps = pca.project(&short);
        let pl = pca.project(&long);

        assert!(ps.r < pl.r, "longer vector should have larger radius");
        assert!((ps.r - 0.1).abs() < 1e-10);
        assert!((pl.r - 10.0).abs() < 1e-10);
    }

    // A custom magnitude transform (here ln(1 + |v|)) must be applied to the
    // input's length (|(3, 4, 0, ...)| = 5).
    #[test]
    fn pca_transform_radial() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(
            &corpus,
            RadialStrategy::MagnitudeTransform(Arc::new(|mag| mag.ln_1p())),
        );

        let e = emb(&[3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let sp = pca.project(&e);
        assert!((sp.r - 5.0_f64.ln_1p()).abs() < 1e-10);
    }

    // Degenerate corpus: one embedding means zero variance after centering;
    // projection must still produce a valid point.
    #[test]
    fn pca_single_embedding() {
        let corpus = vec![emb(&[1.0, 0.0, 0.0, 0.0, 0.0])];
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
        let sp = pca.project(&corpus[0]);
        assert!((sp.r - 1.0).abs() < 1e-12);
        assert_valid_spherical(&sp);
    }

    #[test]
    fn pca_dimensionality() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
        assert_eq!(pca.dimensionality(), 10);
    }

    #[test]
    #[should_panic(expected = "need at least one embedding")]
    fn pca_empty_corpus_panics() {
        PcaProjection::fit(&[], RadialStrategy::Fixed(1.0));
    }

    #[test]
    #[should_panic(expected = "embedding dimension must be >= 3")]
    fn pca_too_few_dimensions_panics() {
        PcaProjection::fit(&[emb(&[1.0, 2.0])], RadialStrategy::Fixed(1.0));
    }

    // Projecting an input whose dimension differs from the fitted one must panic.
    #[test]
    #[should_panic(expected = "expected dimension 10")]
    fn pca_dimension_mismatch_panics() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
        let _ = pca.project(&emb(&[1.0, 2.0, 3.0]));
    }

    // Random projection must also yield well-formed coordinates for a range
    // of (constant-valued) inputs.
    #[test]
    fn random_produces_valid_spherical_points() {
        let rp = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        for i in 0..20 {
            let e = emb(&[i as f64 * 0.1 + 0.01; 10]);
            assert_valid_spherical(&rp.project(&e));
        }
    }

    // The same seed must reproduce the same matrix, hence identical angles.
    #[test]
    fn random_deterministic_with_same_seed() {
        let rp1 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        let rp2 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        let e = emb(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let sp1 = rp1.project(&e);
        let sp2 = rp2.project(&e);
        assert!((sp1.theta - sp2.theta).abs() < 1e-12);
        assert!((sp1.phi - sp2.phi).abs() < 1e-12);
    }

    #[test]
    fn random_different_seeds_differ() {
        let rp1 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        let rp2 = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 999);
        let e = emb(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let d = angular_distance(&rp1.project(&e), &rp2.project(&e));
        assert!(
            d > 1e-6,
            "different seeds should produce different projections"
        );
    }

    #[test]
    fn random_dimensionality() {
        let rp = RandomProjection::new(768, RadialStrategy::Fixed(1.0), 0);
        assert_eq!(rp.dimensionality(), 768);
    }

    #[test]
    #[should_panic(expected = "embedding dimension must be >= 3")]
    fn random_too_few_dimensions_panics() {
        RandomProjection::new(2, RadialStrategy::Fixed(1.0), 0);
    }

    // `Arc<P>` must forward `project` and `dimensionality` to the inner projection.
    #[test]
    fn arc_projection_delegates() {
        let rp = Arc::new(RandomProjection::new_default(10));
        let e = emb(&[1.0; 10]);
        let sp = rp.project(&e);
        assert!(sp.r > 0.0);
        assert_eq!(rp.dimensionality(), 10);
    }

    // Sanity check on the PRNG: the first 100 uniforms contain no repeats
    // (bitwise comparison).
    #[test]
    fn prng_produces_distinct_values() {
        let mut rng = SplitMix64::new(42);
        let vals: Vec<f64> = (0..100).map(|_| rng.next_f64()).collect();
        for i in 0..vals.len() {
            for j in (i + 1)..vals.len() {
                assert_ne!(vals[i].to_bits(), vals[j].to_bits());
            }
        }
    }

    // Sample mean and variance of 10k normal draws should be close to 0 and 1.
    #[test]
    fn prng_normal_distribution_reasonable() {
        let mut rng = SplitMix64::new(12345);
        let samples: Vec<f64> = (0..10_000).map(|_| rng.normal()).collect();

        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        let variance =
            samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;

        assert!(mean.abs() < 0.05, "mean should be near 0, got {mean}");
        assert!(
            (variance - 1.0).abs() < 0.1,
            "variance should be near 1, got {variance}"
        );
    }
}