1use std::f64::consts::PI;
2use std::sync::Arc;
3
4use sphereql_core::{CartesianPoint, SphericalPoint, cartesian_to_spherical};
5
6use crate::types::{Embedding, ProjectedPoint, RadialStrategy};
7
8#[derive(Debug, Clone, PartialEq, thiserror::Error)]
15pub enum ProjectionError {
16 #[error("need at least one embedding to fit a projection")]
18 EmptyCorpus,
19
20 #[error("embedding dimension {got} is below the minimum {required} for this projection")]
23 DimensionTooLow { got: usize, required: usize },
24
25 #[error("embedding {index} has dimension {got}, expected {expected}")]
28 InconsistentDimension {
29 index: usize,
30 expected: usize,
31 got: usize,
32 },
33
34 #[error("need at least {required} embeddings, got {got}")]
37 TooFewEmbeddings { got: usize, required: usize },
38
39 #[error("kernel bandwidth σ must be positive, got {got}")]
41 InvalidSigma { got: f64 },
42
43 #[error("auxiliary slice has length {got}, expected {expected}")]
46 SliceLengthMismatch { expected: usize, got: usize },
47}
48
49pub trait Projection: Send + Sync {
55 fn project(&self, embedding: &Embedding) -> SphericalPoint;
56
57 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
59 let position = self.project(embedding);
60 ProjectedPoint::from_position(position, embedding.magnitude())
61 }
62
63 fn dimensionality(&self) -> usize;
64}
65
66impl<P: Projection> Projection for Arc<P> {
67 fn project(&self, embedding: &Embedding) -> SphericalPoint {
68 (**self).project(embedding)
69 }
70 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
71 (**self).project_rich(embedding)
72 }
73 fn dimensionality(&self) -> usize {
74 (**self).dimensionality()
75 }
76}
77
78#[derive(Clone)]
88pub struct PcaProjection {
89 components: [Vec<f64>; 3],
90 mean: Vec<f64>,
91 dim: usize,
92 radial: RadialStrategy,
93 volumetric: bool,
94 eigenvalues: [f64; 3],
97 total_variance: f64,
100}
101
102const PCA_MIN_DIM: usize = 3;
104
105impl PcaProjection {
106 fn validate_embeddings(embeddings: &[Embedding]) -> Result<usize, ProjectionError> {
110 if embeddings.is_empty() {
111 return Err(ProjectionError::EmptyCorpus);
112 }
113 let dim = embeddings[0].dimension();
114 if dim < PCA_MIN_DIM {
115 return Err(ProjectionError::DimensionTooLow {
116 got: dim,
117 required: PCA_MIN_DIM,
118 });
119 }
120 for (i, e) in embeddings.iter().enumerate() {
121 if e.dimension() != dim {
122 return Err(ProjectionError::InconsistentDimension {
123 index: i,
124 expected: dim,
125 got: e.dimension(),
126 });
127 }
128 }
129 Ok(dim)
130 }
131
132 fn from_eigendecomp(
137 components: Vec<Vec<f64>>,
138 eigenvalues: Vec<f64>,
139 mean: Vec<f64>,
140 dim: usize,
141 radial: RadialStrategy,
142 total_variance: f64,
143 ) -> Self {
144 Self {
145 components: [
146 components[0].clone(),
147 components[1].clone(),
148 components[2].clone(),
149 ],
150 mean,
151 dim,
152 radial,
153 volumetric: false,
154 eigenvalues: [
155 eigenvalues.first().copied().unwrap_or(0.0),
156 eigenvalues.get(1).copied().unwrap_or(0.0),
157 eigenvalues.get(2).copied().unwrap_or(0.0),
158 ],
159 total_variance,
160 }
161 }
162
163 pub fn fit(embeddings: &[Embedding], radial: RadialStrategy) -> Result<Self, ProjectionError> {
172 let dim = Self::validate_embeddings(embeddings)?;
173
174 let normalized: Vec<Vec<f64>> = embeddings.iter().map(|e| e.normalized()).collect();
175 let n = normalized.len();
176
177 let mut mean = vec![0.0; dim];
178 for v in &normalized {
179 for (i, &val) in v.iter().enumerate() {
180 mean[i] += val;
181 }
182 }
183 for m in &mut mean {
184 *m /= n as f64;
185 }
186
187 let centered: Vec<Vec<f64>> = normalized
188 .iter()
189 .map(|v| {
190 v.iter()
191 .zip(mean.iter())
192 .map(|(&val, &m)| val - m)
193 .collect()
194 })
195 .collect();
196
197 let (components, eigenvalues) = top_k_eigenvectors(¢ered, 3, dim);
198
199 let total_variance: f64 = centered
201 .iter()
202 .map(|row| row.iter().map(|x| x * x).sum::<f64>())
203 .sum::<f64>()
204 / centered.len() as f64;
205
206 Ok(Self::from_eigendecomp(
207 components,
208 eigenvalues,
209 mean,
210 dim,
211 radial,
212 total_variance,
213 ))
214 }
215
216 pub fn fit_default(embeddings: &[Embedding]) -> Result<Self, ProjectionError> {
217 Self::fit(embeddings, RadialStrategy::default())
218 }
219
220 pub fn fit_weighted(
240 embeddings: &[Embedding],
241 weights: &[f64],
242 radial: RadialStrategy,
243 ) -> Result<Self, ProjectionError> {
244 if weights.len() != embeddings.len() {
247 return Err(ProjectionError::SliceLengthMismatch {
248 expected: embeddings.len(),
249 got: weights.len(),
250 });
251 }
252 let dim = Self::validate_embeddings(embeddings)?;
253
254 let clamped: Vec<f64> = weights.iter().map(|&w| w.max(0.0)).collect();
255 let w_sum: f64 = clamped.iter().sum();
256 if w_sum < f64::EPSILON {
257 return Self::fit(embeddings, radial);
260 }
261
262 let normalized: Vec<Vec<f64>> = embeddings.iter().map(|e| e.normalized()).collect();
263
264 let mut mean = vec![0.0; dim];
266 for (v, &w) in normalized.iter().zip(clamped.iter()) {
267 for (i, &val) in v.iter().enumerate() {
268 mean[i] += w * val;
269 }
270 }
271 for m in &mut mean {
272 *m /= w_sum;
273 }
274
275 let scaled: Vec<Vec<f64>> = normalized
280 .iter()
281 .zip(clamped.iter())
282 .map(|(v, &w)| {
283 let s = w.sqrt();
284 v.iter()
285 .zip(mean.iter())
286 .map(|(&val, &m)| s * (val - m))
287 .collect()
288 })
289 .collect();
290
291 let (components, eigenvalues) = top_k_eigenvectors(&scaled, 3, dim);
292
293 let total_variance: f64 = scaled
297 .iter()
298 .map(|row| row.iter().map(|x| x * x).sum::<f64>())
299 .sum::<f64>()
300 / scaled.len() as f64;
301
302 Ok(Self::from_eigendecomp(
303 components,
304 eigenvalues,
305 mean,
306 dim,
307 radial,
308 total_variance,
309 ))
310 }
311
312 pub fn with_volumetric(mut self, enabled: bool) -> Self {
316 self.volumetric = enabled;
317 self
318 }
319
320 pub fn explained_variance_ratio(&self) -> f64 {
323 if self.total_variance < f64::EPSILON {
324 return 1.0;
325 }
326 let explained: f64 = self.eigenvalues.iter().sum();
327 (explained / self.total_variance).clamp(0.0, 1.0)
328 }
329
330 fn project_xyz_residual(&self, embedding: &Embedding) -> (f64, f64, f64, f64) {
342 let values = &embedding.values;
343 debug_assert_eq!(values.len(), self.dim);
344
345 let mag = embedding.magnitude();
346 let inv_mag = if mag < f64::EPSILON { 0.0 } else { 1.0 / mag };
347
348 let mut x = 0.0f64;
349 let mut y = 0.0f64;
350 let mut z = 0.0f64;
351 let mut total_sq = 0.0f64;
352 let c0 = &self.components[0];
353 let c1 = &self.components[1];
354 let c2 = &self.components[2];
355 for i in 0..self.dim {
356 let n = values[i] * inv_mag;
357 let c = n - self.mean[i];
358 x += c * c0[i];
359 y += c * c1[i];
360 z += c * c2[i];
361 total_sq += c * c;
362 }
363 let projected_sq = x * x + y * y + z * z;
364 let residual_sq = (total_sq - projected_sq).max(0.0);
365 (x, y, z, residual_sq)
366 }
367}
368
369impl Projection for PcaProjection {
370 fn project(&self, embedding: &Embedding) -> SphericalPoint {
371 assert_eq!(
375 embedding.dimension(),
376 self.dim,
377 "expected dimension {}, got {}",
378 self.dim,
379 embedding.dimension()
380 );
381
382 let (x, y, z, residual_sq) = self.project_xyz_residual(embedding);
383
384 if self.volumetric {
385 let sp = cartesian_to_spherical(&CartesianPoint::new(x, y, z));
386 if sp.r < f64::EPSILON {
387 return SphericalPoint::new_unchecked(0.0, 0.0, 0.0);
388 }
389 SphericalPoint::new_unchecked(sp.r, sp.theta, sp.phi)
390 } else {
391 let projection_magnitude = (x * x + y * y + z * z).sqrt();
392 let intensity = embedding.magnitude();
393 let certainty = pca_certainty(embedding, &self.mean, intensity, residual_sq);
394 let r = self.radial.compute_rich(&crate::types::RadialContext::full(
395 intensity,
396 projection_magnitude,
397 certainty,
398 ));
399 project_xyz_to_spherical(x, y, z, r)
400 }
401 }
402
403 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
404 assert_eq!(
406 embedding.dimension(),
407 self.dim,
408 "expected dimension {}, got {}",
409 self.dim,
410 embedding.dimension()
411 );
412
413 let intensity = embedding.magnitude();
414 let (x, y, z, residual_sq) = self.project_xyz_residual(embedding);
415 let projection_magnitude = (x * x + y * y + z * z).sqrt();
416 let certainty = pca_certainty(embedding, &self.mean, intensity, residual_sq);
417
418 let position = if self.volumetric {
419 let sp = cartesian_to_spherical(&CartesianPoint::new(x, y, z));
420 if sp.r < f64::EPSILON {
421 SphericalPoint::new_unchecked(0.0, 0.0, 0.0)
422 } else {
423 SphericalPoint::new_unchecked(sp.r, sp.theta, sp.phi)
424 }
425 } else {
426 let r = self.radial.compute_rich(&crate::types::RadialContext::full(
427 intensity,
428 projection_magnitude,
429 certainty,
430 ));
431 project_xyz_to_spherical(x, y, z, r)
432 };
433
434 ProjectedPoint::new(position, certainty, intensity, projection_magnitude)
435 }
436
437 fn dimensionality(&self) -> usize {
438 self.dim
439 }
440}
441
442#[derive(Clone)]
451pub struct RandomProjection {
452 matrix: [Vec<f64>; 3],
453 dim: usize,
454 radial: RadialStrategy,
455}
456
457impl RandomProjection {
458 pub fn new(dim: usize, radial: RadialStrategy, seed: u64) -> Self {
459 assert!(dim >= 3, "embedding dimension must be >= 3");
463 let mut rng = SplitMix64::new(seed);
464 let matrix = std::array::from_fn(|_| (0..dim).map(|_| rng.normal()).collect());
465 Self {
466 matrix,
467 dim,
468 radial,
469 }
470 }
471
472 pub fn new_default(dim: usize) -> Self {
473 Self::new(dim, RadialStrategy::default(), 42)
474 }
475}
476
477impl Projection for RandomProjection {
478 fn project(&self, embedding: &Embedding) -> SphericalPoint {
479 assert_eq!(
480 embedding.dimension(),
481 self.dim,
482 "expected dimension {}, got {}",
483 self.dim,
484 embedding.dimension()
485 );
486
487 let magnitude = embedding.magnitude();
488 let normalized = embedding.normalized();
489
490 let x = dot(&normalized, &self.matrix[0]);
491 let y = dot(&normalized, &self.matrix[1]);
492 let z = dot(&normalized, &self.matrix[2]);
493
494 let projection_magnitude = (x * x + y * y + z * z).sqrt();
498 let r = self.radial.compute_rich(&crate::types::RadialContext::full(
499 magnitude,
500 projection_magnitude,
501 1.0,
502 ));
503
504 project_xyz_to_spherical(x, y, z, r)
505 }
506
507 fn dimensionality(&self) -> usize {
508 self.dim
509 }
510}
511
512fn pca_certainty(embedding: &Embedding, mean: &[f64], intensity: f64, residual_sq: f64) -> f64 {
518 let zero_intensity = intensity < f64::EPSILON;
523 let inv_mag = if zero_intensity { 0.0 } else { 1.0 / intensity };
524 let total_sq: f64 = (0..mean.len())
525 .map(|i| {
526 let normalized_i = if zero_intensity {
527 if i == 0 { 1.0 } else { 0.0 }
528 } else {
529 embedding.values[i] * inv_mag
530 };
531 let c = normalized_i - mean[i];
532 c * c
533 })
534 .sum();
535 if total_sq < f64::EPSILON {
536 0.0
537 } else {
538 (1.0 - residual_sq / total_sq).clamp(0.0, 1.0)
539 }
540}
541
542pub(crate) fn project_xyz_to_spherical(x: f64, y: f64, z: f64, r: f64) -> SphericalPoint {
543 let cart = CartesianPoint::new(x, y, z).normalize();
544 if cart.magnitude() < f64::EPSILON {
545 return SphericalPoint::new_unchecked(r, 0.0, 0.0);
546 }
547 let sp = cartesian_to_spherical(&cart);
548 SphericalPoint::new_unchecked(r, sp.theta, sp.phi)
549}
550
/// Dot product of `a` and `b`.
///
/// If the lengths differ, only the common prefix is used (the extra
/// elements of the longer slice are ignored).
pub(crate) fn dot(a: &[f64], b: &[f64]) -> f64 {
    std::iter::zip(a, b).map(|(x, y)| x * y).sum()
}
556
/// Scales `v` to unit Euclidean length in place and returns its original length.
///
/// A vector whose length is not greater than `f64::EPSILON` is left
/// untouched; this also covers a NaN length.
pub(crate) fn normalize_vec(v: &mut [f64]) -> f64 {
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm > f64::EPSILON {
        v.iter_mut().for_each(|x| *x /= norm);
    }
    norm
}
566
567fn top_k_eigenvectors(data: &[Vec<f64>], k: usize, dim: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
574 let max_iters = 200;
575 let tol = 1e-10;
576 let mut vectors: Vec<Vec<f64>> = Vec::with_capacity(k);
577 let mut values: Vec<f64> = Vec::with_capacity(k);
578 let mut rng = SplitMix64::new(0xDEAD_BEEF);
579 let n = data.len() as f64;
580
581 for _ in 0..k {
582 let mut v: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
583 normalize_vec(&mut v);
584 let mut eigenvalue = 0.0;
585
586 for _ in 0..max_iters {
587 let w: Vec<f64> = data.iter().map(|row| dot(row, &v)).collect();
589
590 let mut u = vec![0.0; dim];
592 for (row, &wi) in data.iter().zip(w.iter()) {
593 for (uj, &rj) in u.iter_mut().zip(row.iter()) {
594 *uj += wi * rj;
595 }
596 }
597
598 for prev in &vectors {
600 let proj = dot(&u, prev);
601 for (uj, &pj) in u.iter_mut().zip(prev.iter()) {
602 *uj -= proj * pj;
603 }
604 }
605
606 let mag = normalize_vec(&mut u);
607 if mag < f64::EPSILON {
608 break;
609 }
610
611 eigenvalue = mag / n;
613
614 let change = (1.0 - dot(&u, &v).abs()).max(0.0);
617 v = u;
618
619 if change < tol {
620 break;
621 }
622 }
623
624 vectors.push(v);
625 values.push(eigenvalue);
626 }
627
628 while vectors.len() < k {
630 let mut v: Vec<f64> = (0..dim).map(|_| rng.normal()).collect();
631 for prev in &vectors {
632 let proj = dot(&v, prev);
633 for (vj, &pj) in v.iter_mut().zip(prev.iter()) {
634 *vj -= proj * pj;
635 }
636 }
637 normalize_vec(&mut v);
638 vectors.push(v);
639 values.push(0.0);
640 }
641
642 (vectors, values)
643}
644
/// Small, dependency-free SplitMix64 pseudo-random generator.
///
/// Used here to make projection matrices and power-iteration start vectors
/// reproducible from a seed. It is not suitable for anything security-sensitive.
pub(crate) struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Golden-ratio Weyl-sequence increment added to the state on every step.
    const GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;
    /// First multiplier of the output mixing function.
    const MIX_1: u64 = 0xbf58_476d_1ce4_e5b9;
    /// Second multiplier of the output mixing function.
    const MIX_2: u64 = 0x94d0_49bb_1331_11eb;

    /// Creates a generator whose sequence is fully determined by `seed`.
    pub(crate) fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the state and returns the next 64 pseudo-random bits.
    pub(crate) fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GAMMA);
        let z = self.state;
        let z = (z ^ (z >> 30)).wrapping_mul(Self::MIX_1);
        let z = (z ^ (z >> 27)).wrapping_mul(Self::MIX_2);
        z ^ (z >> 31)
    }

    /// Uniform sample in `[0, 1)` built from the top 53 bits of [`next_u64`](Self::next_u64).
    pub(crate) fn next_f64(&mut self) -> f64 {
        const TWO_POW_53: f64 = (1u64 << 53) as f64;
        let mantissa_bits = self.next_u64() >> 11;
        mantissa_bits as f64 / TWO_POW_53
    }

    /// Standard-normal sample via the cosine branch of the Box–Muller
    /// transform. Each call consumes two uniform draws.
    pub(crate) fn normal(&mut self) -> f64 {
        // Keep u1 strictly positive so ln(u1) stays finite.
        let u1 = self.next_f64().max(f64::MIN_POSITIVE);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        radius * (2.0 * PI * u2).cos()
    }
}
675
#[cfg(test)]
mod tests {
    use super::*;
    use sphereql_core::angular_distance;
    use std::collections::HashSet;
    use std::f64::consts::TAU;

    /// Builds an [`Embedding`] from a borrowed slice.
    fn emb(vals: &[f64]) -> Embedding {
        Embedding::new(Vec::from(vals))
    }

    /// Embedding of length `len` holding `value` in slot 0 and zeros elsewhere.
    fn leading(value: f64, len: usize) -> Embedding {
        let mut vals = vec![0.0; len];
        vals[0] = value;
        Embedding::new(vals)
    }

    /// Eight 10-D embeddings: strong signal in the first three axes, small
    /// noise in the remaining seven.
    fn corpus_10d() -> Vec<Embedding> {
        const ROWS: [[f64; 10]; 8] = [
            [1.0, 0.0, 0.0, 0.1, 0.05, -0.02, 0.03, -0.01, 0.04, 0.02],
            [0.0, 1.0, 0.0, -0.05, 0.1, 0.03, -0.02, 0.01, -0.03, 0.04],
            [0.0, 0.0, 1.0, 0.02, -0.03, 0.1, 0.05, 0.02, -0.01, -0.04],
            [1.0, 1.0, 0.0, 0.05, 0.08, 0.01, 0.01, -0.02, 0.02, 0.03],
            [0.0, 1.0, 1.0, -0.02, 0.07, 0.07, 0.01, 0.02, -0.02, 0.01],
            [1.0, 0.0, 1.0, 0.06, 0.01, 0.05, -0.03, -0.01, 0.03, -0.02],
            [-1.0, 0.0, 0.0, -0.08, 0.02, 0.01, 0.02, 0.03, -0.02, 0.01],
            [0.0, -1.0, 0.0, 0.03, -0.09, -0.02, 0.01, -0.01, 0.02, -0.03],
        ];
        ROWS.iter().map(|row| emb(row)).collect()
    }

    /// Checks `r >= 0`, `theta ∈ [0, 2π)` and `phi ∈ [0, π]`.
    fn assert_valid_spherical(sp: &SphericalPoint) {
        assert!(sp.r >= 0.0, "r must be >= 0, got {}", sp.r);
        assert!(
            (0.0..TAU).contains(&sp.theta),
            "theta must be in [0, 2π), got {}",
            sp.theta
        );
        assert!(
            (0.0..=PI).contains(&sp.phi),
            "phi must be in [0, π], got {}",
            sp.phi
        );
    }

    // ----- PCA -----

    #[test]
    fn pca_produces_valid_spherical_points() {
        let corpus = corpus_10d();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        corpus
            .iter()
            .map(|e| pca.project(e))
            .for_each(|sp| assert_valid_spherical(&sp));
    }

    #[test]
    fn pca_preserves_angular_ordering() {
        let pca = PcaProjection::fit(&corpus_10d(), RadialStrategy::Fixed(1.0)).unwrap();

        // `anchor` and `near` point the same way; `far` points roughly the opposite way.
        let anchor =
            pca.project(&emb(&[1.0, 0.1, 0.0, 0.05, 0.02, -0.01, 0.01, 0.0, 0.02, 0.01]));
        let near =
            pca.project(&emb(&[0.9, 0.2, 0.1, 0.04, 0.03, 0.0, 0.02, -0.01, 0.01, 0.02]));
        let far =
            pca.project(&emb(&[-1.0, -0.1, 0.0, -0.04, 0.01, 0.02, 0.01, 0.02, -0.01, 0.01]));

        let d_ab = angular_distance(&anchor, &near);
        let d_ac = angular_distance(&anchor, &far);
        assert!(
            d_ab < d_ac,
            "similar items should be closer: d(a,b)={d_ab:.4} should be < d(a,c)={d_ac:.4}"
        );
    }

    #[test]
    fn pca_magnitude_radial() {
        let pca = PcaProjection::fit(&corpus_10d(), RadialStrategy::Magnitude).unwrap();

        let short_r = pca.project(&leading(0.1, 10)).r;
        let long_r = pca.project(&leading(10.0, 10)).r;

        assert!(short_r < long_r, "longer vector should have larger radius");
        assert!((short_r - 0.1).abs() < 1e-10);
        assert!((long_r - 10.0).abs() < 1e-10);
    }

    #[test]
    fn pca_transform_radial() {
        let strategy = RadialStrategy::MagnitudeTransform(Arc::new(|mag| mag.ln_1p()));
        let pca = PcaProjection::fit(&corpus_10d(), strategy).unwrap();

        // |(3, 4, 0, …)| = 5.
        let sp = pca.project(&emb(&[3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]));
        assert!((sp.r - 5.0_f64.ln_1p()).abs() < 1e-10);
    }

    #[test]
    fn pca_single_embedding() {
        let only = leading(1.0, 5);
        let pca =
            PcaProjection::fit(std::slice::from_ref(&only), RadialStrategy::Fixed(1.0)).unwrap();
        let sp = pca.project(&only);
        assert!((sp.r - 1.0).abs() < 1e-12);
        assert_valid_spherical(&sp);
    }

    #[test]
    fn pca_dimensionality() {
        let pca = PcaProjection::fit(&corpus_10d(), RadialStrategy::Fixed(1.0)).unwrap();
        assert_eq!(pca.dimensionality(), 10);
    }

    #[test]
    fn pca_empty_corpus_returns_err() {
        let result = PcaProjection::fit(&[], RadialStrategy::Fixed(1.0));
        assert!(matches!(result, Err(ProjectionError::EmptyCorpus)));
    }

    #[test]
    fn pca_too_few_dimensions_returns_err() {
        let result = PcaProjection::fit(&[emb(&[1.0, 2.0])], RadialStrategy::Fixed(1.0));
        assert!(matches!(
            result,
            Err(ProjectionError::DimensionTooLow {
                got: 2,
                required: 3
            })
        ));
    }

    #[test]
    #[should_panic(expected = "expected dimension 10")]
    fn pca_dimension_mismatch_panics() {
        let pca = PcaProjection::fit(&corpus_10d(), RadialStrategy::Fixed(1.0)).unwrap();
        let _ = pca.project(&emb(&[1.0, 2.0, 3.0]));
    }

    // ----- weighted PCA -----

    #[test]
    fn fit_weighted_uniform_weights_matches_naive_fit() {
        let corpus = corpus_10d();
        let uniform = vec![1.0; corpus.len()];

        let plain = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let weighted =
            PcaProjection::fit_weighted(&corpus, &uniform, RadialStrategy::Fixed(1.0)).unwrap();

        for e in &corpus {
            let gap = angular_distance(&plain.project(e), &weighted.project(e));
            assert!(gap < 1e-9, "uniform-weight fit should match naive fit");
        }
        let evr_gap =
            (plain.explained_variance_ratio() - weighted.explained_variance_ratio()).abs();
        assert!(evr_gap < 1e-9);
    }

    #[test]
    fn fit_weighted_balances_imbalanced_corpus() {
        // Twenty near-duplicates along axis 0, each down-weighted by 1/√20,
        // plus one lone embedding along axis 1 with full weight.
        let (mut corpus, mut weights): (Vec<Embedding>, Vec<f64>) = (0..20)
            .map(|i| {
                let mut v = vec![0.0; 8];
                v[0] = 1.0 + (i as f64) * 0.001;
                v[1] = 0.01;
                (emb(&v), 1.0 / (20f64).sqrt())
            })
            .unzip();
        corpus.push(emb(&[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]));
        weights.push(1.0);

        let weighted =
            PcaProjection::fit_weighted(&corpus, &weights, RadialStrategy::Fixed(1.0)).unwrap();
        let evr = weighted.explained_variance_ratio();
        assert!(evr > 0.5, "weighted EVR should be > 0.5, got {}", evr);
    }

    #[test]
    fn fit_weighted_rejects_length_mismatch() {
        let corpus = corpus_10d();
        let short_weights = vec![1.0; corpus.len() - 1];
        let result =
            PcaProjection::fit_weighted(&corpus, &short_weights, RadialStrategy::Fixed(1.0));
        assert!(matches!(
            result,
            Err(ProjectionError::SliceLengthMismatch { .. })
        ));
    }

    #[test]
    fn fit_weighted_zero_weights_falls_back_to_unweighted() {
        let corpus = corpus_10d();
        let zeros = vec![0.0; corpus.len()];
        let weighted =
            PcaProjection::fit_weighted(&corpus, &zeros, RadialStrategy::Fixed(1.0)).unwrap();
        let plain = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        for e in &corpus {
            assert!(angular_distance(&plain.project(e), &weighted.project(e)) < 1e-9);
        }
    }

    // ----- random projection -----

    #[test]
    fn random_produces_valid_spherical_points() {
        let rp = RandomProjection::new(10, RadialStrategy::Fixed(1.0), 42);
        (0..20)
            .map(|i| emb(&[i as f64 * 0.1 + 0.01; 10]))
            .for_each(|e| assert_valid_spherical(&rp.project(&e)));
    }

    #[test]
    fn random_deterministic_with_same_seed() {
        let vals: Vec<f64> = (1..=10).map(f64::from).collect();
        let e = emb(&vals);
        let project_with =
            |seed: u64| RandomProjection::new(10, RadialStrategy::Fixed(1.0), seed).project(&e);

        let first = project_with(42);
        let second = project_with(42);
        assert!((first.theta - second.theta).abs() < 1e-12);
        assert!((first.phi - second.phi).abs() < 1e-12);
    }

    #[test]
    fn random_different_seeds_differ() {
        let vals: Vec<f64> = (1..=10).map(f64::from).collect();
        let e = emb(&vals);
        let project_with =
            |seed: u64| RandomProjection::new(10, RadialStrategy::Fixed(1.0), seed).project(&e);

        let d = angular_distance(&project_with(42), &project_with(999));
        assert!(
            d > 1e-6,
            "different seeds should produce different projections"
        );
    }

    #[test]
    fn random_dimensionality() {
        let rp = RandomProjection::new(768, RadialStrategy::Fixed(1.0), 0);
        assert_eq!(rp.dimensionality(), 768);
    }

    #[test]
    #[should_panic(expected = "embedding dimension must be >= 3")]
    fn random_too_few_dimensions_panics() {
        RandomProjection::new(2, RadialStrategy::Fixed(1.0), 0);
    }

    #[test]
    fn arc_projection_delegates() {
        let shared = Arc::new(RandomProjection::new_default(10));
        let sp = shared.project(&emb(&[1.0; 10]));
        assert!(sp.r > 0.0);
        assert_eq!(shared.dimensionality(), 10);
    }

    // ----- PRNG -----

    #[test]
    fn prng_produces_distinct_values() {
        let mut rng = SplitMix64::new(42);
        let mut seen = HashSet::new();
        for _ in 0..100 {
            let bits = rng.next_f64().to_bits();
            assert!(seen.insert(bits), "duplicate value {bits:#x}");
        }
    }

    #[test]
    fn prng_normal_distribution_reasonable() {
        let mut rng = SplitMix64::new(12345);
        let samples: Vec<f64> = std::iter::repeat_with(|| rng.normal())
            .take(10_000)
            .collect();

        let count = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / count;
        let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count;

        assert!(mean.abs() < 0.05, "mean should be near 0, got {mean}");
        assert!(
            (variance - 1.0).abs() < 0.1,
            "variance should be near 1, got {variance}"
        );
    }
}