1use std::collections::{BinaryHeap, HashMap};
2
3use sphereql_core::*;
4use sphereql_index::*;
5
6use crate::projection::Projection;
7use crate::types::{Embedding, ProjectedPoint};
8
9#[derive(Debug, Clone)]
10pub struct EmbeddingItem {
11 pub id: String,
12 pub position: SphericalPoint,
13 pub original_magnitude: f64,
14 pub projected: Option<ProjectedPoint>,
16}
17
18impl SpatialItem for EmbeddingItem {
19 type Id = String;
20 fn id(&self) -> &String {
21 &self.id
22 }
23 fn position(&self) -> &SphericalPoint {
24 &self.position
25 }
26}
27
28impl EmbeddingItem {
29 pub fn certainty(&self) -> f64 {
31 self.projected.map_or(1.0, |p| p.certainty)
32 }
33
34 pub fn intensity(&self) -> f64 {
36 self.projected
37 .map_or(self.original_magnitude, |p| p.intensity)
38 }
39
40 pub fn projection_magnitude(&self) -> f64 {
43 self.projected.map_or(1.0, |p| p.projection_magnitude)
44 }
45}
46
47pub struct EmbeddingIndexBuilder<P> {
48 projection: P,
49 inner: SpatialIndexBuilder,
50}
51
52impl<P: Projection> EmbeddingIndexBuilder<P> {
53 pub fn new(projection: P) -> Self {
54 Self {
55 projection,
56 inner: SpatialIndexBuilder::new(),
57 }
58 }
59
60 pub fn shell_boundary(mut self, r: f64) -> Self {
61 self.inner = self.inner.shell_boundary(r);
62 self
63 }
64
65 pub fn uniform_shells(mut self, count: usize, max_r: f64) -> Self {
66 self.inner = self.inner.uniform_shells(count, max_r);
67 self
68 }
69
70 pub fn theta_divisions(mut self, n: usize) -> Self {
71 self.inner = self.inner.theta_divisions(n);
72 self
73 }
74
75 pub fn phi_divisions(mut self, n: usize) -> Self {
76 self.inner = self.inner.phi_divisions(n);
77 self
78 }
79
80 pub fn build(self) -> EmbeddingIndex<P> {
81 EmbeddingIndex {
82 projection: self.projection,
83 index: self.inner.build(),
84 }
85 }
86}
87
88pub struct EmbeddingIndex<P> {
89 projection: P,
90 index: SpatialIndex<EmbeddingItem>,
91}
92
93impl<P: Projection> EmbeddingIndex<P> {
94 pub fn builder(projection: P) -> EmbeddingIndexBuilder<P> {
95 EmbeddingIndexBuilder::new(projection)
96 }
97
98 pub fn insert(&mut self, id: impl Into<String>, embedding: &Embedding) {
99 let rich = self.projection.project_rich(embedding);
100 self.index.insert(EmbeddingItem {
101 id: id.into(),
102 position: rich.position,
103 original_magnitude: embedding.magnitude(),
104 projected: Some(rich),
105 });
106 }
107
108 pub fn insert_with_radius(&mut self, id: impl Into<String>, embedding: &Embedding, r: f64) {
112 let rich = self.projection.project_rich(embedding);
113 let position = SphericalPoint::new_unchecked(r, rich.position.theta, rich.position.phi);
114 self.index.insert(EmbeddingItem {
115 id: id.into(),
116 position,
117 original_magnitude: embedding.magnitude(),
118 projected: Some(ProjectedPoint { position, ..rich }),
119 });
120 }
121
122 pub fn search_nearest(&self, query: &Embedding, k: usize) -> Vec<NearestResult<EmbeddingItem>> {
124 let projected = self.projection.project(query);
125 self.index.nearest(&projected, k)
126 }
127
128 pub fn search_similar(
133 &self,
134 query: &Embedding,
135 min_cosine_similarity: f64,
136 ) -> SpatialQueryResult<EmbeddingItem> {
137 let projected = self.projection.project(query);
138 let max_angle = min_cosine_similarity.clamp(-1.0, 1.0).acos();
139 self.index.within_distance(&projected, max_angle)
140 }
141
142 pub fn search_region(&self, region: &Region) -> SpatialQueryResult<EmbeddingItem> {
143 self.index.query_region(region)
144 }
145
146 pub fn remove(&mut self, id: &str) -> Option<EmbeddingItem> {
147 self.index.remove(&id.to_string())
148 }
149
150 pub fn get(&self, id: &str) -> Option<&EmbeddingItem> {
151 self.index.get(&id.to_string())
152 }
153
154 pub fn len(&self) -> usize {
155 self.index.len()
156 }
157
158 pub fn is_empty(&self) -> bool {
159 self.index.is_empty()
160 }
161
162 pub fn projection(&self) -> &P {
163 &self.projection
164 }
165
166 pub fn all_items(&self) -> Vec<&EmbeddingItem> {
167 self.index.all_items()
168 }
169
170 pub fn concept_path(&self, source_id: &str, target_id: &str, k: usize) -> Option<ConceptPath> {
181 let items = self.index.all_items();
182 let n = items.len();
183 if n < 2 {
184 return None;
185 }
186
187 let id_to_idx: HashMap<&str, usize> = items
188 .iter()
189 .enumerate()
190 .map(|(i, item)| (item.id.as_str(), i))
191 .collect();
192
193 let source_idx = *id_to_idx.get(source_id)?;
194 let target_idx = *id_to_idx.get(target_id)?;
195
196 let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
198 for (i, item) in items.iter().enumerate() {
199 let nearest = self.index.nearest(item.position(), k + 1);
200 for result in &nearest {
201 if let Some(&j) = id_to_idx.get(result.item.id.as_str())
202 && i != j
203 {
204 adj[i].push((j, result.distance));
205 }
206 }
207 }
208 let snapshot: Vec<Vec<(usize, f64)>> = adj.clone();
210 for (i, edges) in snapshot.iter().enumerate() {
211 for &(j, d) in edges {
212 if !adj[j].iter().any(|&(k, _)| k == i) {
213 adj[j].push((i, d));
214 }
215 }
216 }
217
218 let mut dist = vec![f64::INFINITY; n];
220 let mut prev: Vec<Option<usize>> = vec![None; n];
221 let mut heap = BinaryHeap::new();
222
223 dist[source_idx] = 0.0;
224 heap.push(DijkstraEntry {
225 dist: 0.0,
226 node: source_idx,
227 });
228
229 while let Some(entry) = heap.pop() {
230 let u = entry.node;
231 if entry.dist > dist[u] {
232 continue;
233 }
234 if u == target_idx {
235 break;
236 }
237 for &(v, w) in &adj[u] {
238 let nd = dist[u] + w;
239 if nd < dist[v] {
240 dist[v] = nd;
241 prev[v] = Some(u);
242 heap.push(DijkstraEntry { dist: nd, node: v });
243 }
244 }
245 }
246
247 if dist[target_idx].is_infinite() {
248 return None;
249 }
250
251 let mut path = Vec::new();
253 let mut cur = target_idx;
254 loop {
255 path.push(PathStep {
256 id: items[cur].id.clone(),
257 cumulative_distance: dist[cur],
258 });
259 match prev[cur] {
260 Some(p) => cur = p,
261 None => break,
262 }
263 }
264 path.reverse();
265
266 Some(ConceptPath {
267 total_distance: dist[target_idx],
268 steps: path,
269 })
270 }
271}
272
/// A route between two indexed items, as found by
/// [`EmbeddingIndex::concept_path`].
#[derive(Debug, Clone)]
pub struct ConceptPath {
    /// Items visited in order, from the source to the target (both inclusive).
    pub steps: Vec<PathStep>,
    /// Total edge distance along the route; equals the last step's
    /// `cumulative_distance`.
    pub total_distance: f64,
}

/// One item on a [`ConceptPath`].
#[derive(Debug, Clone)]
pub struct PathStep {
    /// Id of the item at this step.
    pub id: String,
    /// Distance travelled from the source to reach this item.
    pub cumulative_distance: f64,
}
286
/// Priority-queue entry for Dijkstra's algorithm.
///
/// [`BinaryHeap`] is a max-heap, so the ordering on `dist` is reversed: the
/// entry with the *smallest* tentative distance compares greatest and is popped
/// first.
struct DijkstraEntry {
    dist: f64,
    node: usize,
}

impl PartialEq for DijkstraEntry {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so that `==` and `Ord` always agree, as `Ord`
        // requires (a derived impl would also compare `node`, which `cmp` used
        // to ignore).
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for DijkstraEntry {}

impl PartialOrd for DijkstraEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DijkstraEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` is a genuine total order (NaN included), and the node
        // tie-break keeps distinct entries with equal distance distinct.
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.node.cmp(&self.node))
    }
}
312
313#[derive(Debug, Clone)]
324pub struct SlicingManifold {
325 pub centroid: [f64; 3],
326 pub normal: [f64; 3],
327 pub basis_u: [f64; 3],
328 pub basis_v: [f64; 3],
329 pub variance_ratio: f64,
330}
331
332impl SlicingManifold {
333 pub fn fit(points: &[[f64; 3]]) -> Self {
336 let n = points.len() as f64;
337 assert!(n >= 3.0, "need at least 3 points to fit a plane");
338
339 let mut c = [0.0; 3];
341 for p in points {
342 for i in 0..3 {
343 c[i] += p[i];
344 }
345 }
346 for ci in &mut c {
347 *ci /= n;
348 }
349
350 let mut cov = [[0.0f64; 3]; 3];
352 for p in points {
353 let d = [p[0] - c[0], p[1] - c[1], p[2] - c[2]];
354 for i in 0..3 {
355 for j in 0..3 {
356 cov[i][j] += d[i] * d[j];
357 }
358 }
359 }
360 for row in &mut cov {
361 for v in row.iter_mut() {
362 *v /= n;
363 }
364 }
365
366 let (eigenvalues, eigenvectors) = eigen_symmetric_3x3(&cov);
368
369 let total_var = eigenvalues[0] + eigenvalues[1] + eigenvalues[2];
372 let variance_ratio = if total_var > 0.0 {
373 (eigenvalues[0] + eigenvalues[1]) / total_var
374 } else {
375 1.0
376 };
377
378 Self {
379 centroid: c,
380 normal: eigenvectors[2],
381 basis_u: eigenvectors[0],
382 basis_v: eigenvectors[1],
383 variance_ratio,
384 }
385 }
386
387 pub fn project_2d(&self, point: &[f64; 3]) -> (f64, f64) {
389 let d = [
390 point[0] - self.centroid[0],
391 point[1] - self.centroid[1],
392 point[2] - self.centroid[2],
393 ];
394 let u = d[0] * self.basis_u[0] + d[1] * self.basis_u[1] + d[2] * self.basis_u[2];
395 let v = d[0] * self.basis_v[0] + d[1] * self.basis_v[1] + d[2] * self.basis_v[2];
396 (u, v)
397 }
398
399 pub fn distance(&self, point: &[f64; 3]) -> f64 {
401 let d = [
402 point[0] - self.centroid[0],
403 point[1] - self.centroid[1],
404 point[2] - self.centroid[2],
405 ];
406 d[0] * self.normal[0] + d[1] * self.normal[1] + d[2] * self.normal[2]
407 }
408
409 pub fn fit_local(query: &[f64; 3], all_points: &[[f64; 3]], k: usize) -> Self {
420 let mut dists: Vec<(usize, f64)> = all_points
421 .iter()
422 .enumerate()
423 .map(|(i, p)| (i, dist3(query, p)))
424 .collect();
425 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
426
427 let neighborhood: Vec<[f64; 3]> = dists
428 .iter()
429 .take(k.max(3))
430 .map(|&(i, _)| all_points[i])
431 .collect();
432
433 Self::fit(&neighborhood)
434 }
435}
436
/// Eigen-decomposition of a symmetric 3×3 matrix by the classical Jacobi method.
///
/// Each rotation annihilates the largest off-diagonal entry.
///
/// Returns `(eigenvalues, eigenvectors)` sorted by descending eigenvalue, where
/// `eigenvectors[i]` is the unit eigenvector belonging to `eigenvalues[i]`.
/// `m` is assumed to be symmetric.
#[allow(clippy::needless_range_loop)] // index loops mirror the rotation formulas
fn eigen_symmetric_3x3(m: &[[f64; 3]; 3]) -> ([f64; 3], [[f64; 3]; 3]) {
    const MAX_ROTATIONS: usize = 50;
    const RELATIVE_TOLERANCE: f64 = 1e-15;

    let mut a = *m;
    let mut v = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    // Convergence is judged relative to the Frobenius norm, which Jacobi
    // rotations preserve. An absolute threshold would return small-magnitude
    // inputs (e.g. the covariance of a tight neighbourhood) undiagonalised.
    let scale = m.iter().flatten().map(|x| x * x).sum::<f64>().sqrt();
    let tolerance = RELATIVE_TOLERANCE * scale;

    for _ in 0..MAX_ROTATIONS {
        // Locate the largest off-diagonal element (p, q) with p < q.
        let mut p = 0;
        let mut q = 1;
        let mut max_val = a[0][1].abs();
        for i in 0..3 {
            for j in (i + 1)..3 {
                if a[i][j].abs() > max_val {
                    max_val = a[i][j].abs();
                    p = i;
                    q = j;
                }
            }
        }
        if max_val <= tolerance {
            break;
        }

        // Rotation angle that zeroes a[p][q]: tan(2θ) = 2·a_pq / (a_pp − a_qq).
        let diff = a[p][p] - a[q][q];
        let theta = if diff == 0.0 {
            std::f64::consts::FRAC_PI_4
        } else {
            // An exact-zero test keeps this scale-invariant. A huge ratio just
            // saturates atan at ±π/2, giving the required θ = ±π/4.
            0.5 * (2.0 * a[p][q] / diff).atan()
        };
        let c = theta.cos();
        let s = theta.sin();

        // A ← Jᵀ A J: first rotate columns p and q, then rows p and q.
        let mut new_a = a;
        for i in 0..3 {
            new_a[i][p] = c * a[i][p] + s * a[i][q];
            new_a[i][q] = -s * a[i][p] + c * a[i][q];
        }
        let snapshot = new_a;
        for j in 0..3 {
            new_a[p][j] = c * snapshot[p][j] + s * snapshot[q][j];
            new_a[q][j] = -s * snapshot[p][j] + c * snapshot[q][j];
        }
        // Set exactly to zero to discard rounding residue.
        new_a[p][q] = 0.0;
        new_a[q][p] = 0.0;
        a = new_a;

        // Accumulate V ← V J; the eigenvectors are the columns of V.
        let mut new_v = v;
        for i in 0..3 {
            new_v[i][p] = c * v[i][p] + s * v[i][q];
            new_v[i][q] = -s * v[i][p] + c * v[i][q];
        }
        v = new_v;
    }

    let eigenvalues = [a[0][0], a[1][1], a[2][2]];

    // Sort indices by descending eigenvalue (stable). NaN, possible only with
    // non-finite input, ranks lowest instead of panicking the sort.
    let rank = |i: usize| {
        let x = eigenvalues[i];
        if x.is_nan() { f64::NEG_INFINITY } else { x }
    };
    let mut order = [0usize, 1, 2];
    order.sort_by(|&i, &j| {
        rank(j)
            .partial_cmp(&rank(i))
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let sorted_vals = [
        eigenvalues[order[0]],
        eigenvalues[order[1]],
        eigenvalues[order[2]],
    ];
    // Extract columns of V in eigenvalue order.
    let sorted_vecs = [
        [v[0][order[0]], v[1][order[0]], v[2][order[0]]],
        [v[0][order[1]], v[1][order[1]], v[2][order[1]]],
        [v[0][order[2]], v[1][order[2]], v[2][order[2]]],
    ];

    (sorted_vals, sorted_vecs)
}
515
/// A cluster of points produced by [`GlobResult::detect`].
#[derive(Debug, Clone)]
pub struct ConceptGlob {
    /// Index of the k-means cluster this glob was built from.
    pub id: usize,
    /// Normalised sum of the members' unit vectors (their mean direction);
    /// the zero vector if those directions cancel out.
    pub centroid: [f64; 3],
    /// Ids of the member points.
    pub member_ids: Vec<String>,
    /// Angular distance (radians) of each member from `centroid`, parallel to
    /// `member_ids`.
    pub member_distances: Vec<f64>,
    /// Largest entry of `member_distances`.
    pub radius: f64,
}
526}
527
528#[derive(Debug, Clone)]
530pub struct GlobResult {
531 pub globs: Vec<ConceptGlob>,
532 pub k: usize,
533 pub silhouette: f64,
534}
535
536impl GlobResult {
537 pub fn detect(points: &[[f64; 3]], ids: &[String], k: Option<usize>, max_k: usize) -> Self {
542 let n = points.len();
543 assert_eq!(n, ids.len());
544 assert!(n >= 2, "need at least 2 points for clustering");
545
546 let max_k = max_k.min(n);
547
548 if let Some(k) = k {
549 let k = k.clamp(2, max_k);
550 let (assignments, silhouette) = kmeans_3d(points, k);
551 let globs = build_globs(points, ids, &assignments, k);
552 return Self {
553 globs,
554 k,
555 silhouette,
556 };
557 }
558
559 let mut best_k = 2;
561 let mut best_sil = f64::NEG_INFINITY;
562 let mut best_assignments = vec![0usize; n];
563
564 for trial_k in 2..=max_k {
565 let (assignments, sil) = kmeans_3d(points, trial_k);
566 if sil > best_sil {
567 best_sil = sil;
568 best_k = trial_k;
569 best_assignments = assignments;
570 }
571 }
572
573 let globs = build_globs(points, ids, &best_assignments, best_k);
574 Self {
575 globs,
576 k: best_k,
577 silhouette: best_sil,
578 }
579 }
580}
581
582fn kmeans_3d(points: &[[f64; 3]], k: usize) -> (Vec<usize>, f64) {
583 let n = points.len();
584 let max_iter = 50;
585
586 let mut centers: Vec<[f64; 3]> = (0..k).map(|i| points[i * n / k]).collect();
588
589 let mut assignments = vec![0usize; n];
590
591 for _ in 0..max_iter {
592 let mut changed = false;
593
594 for (i, p) in points.iter().enumerate() {
596 let mut best = 0;
597 let mut best_d = f64::MAX;
598 for (j, c) in centers.iter().enumerate() {
599 let d = angular_dist3(p, c);
600 if d < best_d {
601 best_d = d;
602 best = j;
603 }
604 }
605 if assignments[i] != best {
606 assignments[i] = best;
607 changed = true;
608 }
609 }
610
611 if !changed {
612 break;
613 }
614
615 let mut sums = vec![[0.0f64; 3]; k];
618 let mut counts = vec![0usize; k];
619 for (i, &a) in assignments.iter().enumerate() {
620 let norm_p = normalize3(&points[i]);
621 for (d, &np) in norm_p.iter().enumerate() {
622 sums[a][d] += np;
623 }
624 counts[a] += 1;
625 }
626 for j in 0..k {
627 if counts[j] > 0 {
628 centers[j] = normalize3(&sums[j]);
629 }
630 }
631 }
632
633 let sil = silhouette_score(points, &assignments, k);
634 (assignments, sil)
635}
636
637fn silhouette_score(points: &[[f64; 3]], assignments: &[usize], k: usize) -> f64 {
638 let n = points.len();
639 if n <= 1 || k <= 1 {
640 return 0.0;
641 }
642
643 let mut total = 0.0;
644 for i in 0..n {
645 let ci = assignments[i];
646
647 let mut a_sum = 0.0;
649 let mut a_cnt = 0;
650 for j in 0..n {
651 if j != i && assignments[j] == ci {
652 a_sum += angular_dist3(&points[i], &points[j]);
653 a_cnt += 1;
654 }
655 }
656 let a = if a_cnt > 0 { a_sum / a_cnt as f64 } else { 0.0 };
657
658 let mut b = f64::MAX;
660 for ck in 0..k {
661 if ck == ci {
662 continue;
663 }
664 let mut b_sum = 0.0;
665 let mut b_cnt = 0;
666 for j in 0..n {
667 if assignments[j] == ck {
668 b_sum += angular_dist3(&points[i], &points[j]);
669 b_cnt += 1;
670 }
671 }
672 if b_cnt > 0 {
673 b = b.min(b_sum / b_cnt as f64);
674 }
675 }
676 if b == f64::MAX {
677 b = 0.0;
678 }
679
680 let denom = a.max(b);
681 total += if denom > 0.0 { (b - a) / denom } else { 0.0 };
682 }
683
684 total / n as f64
685}
686
687fn build_globs(
688 points: &[[f64; 3]],
689 ids: &[String],
690 assignments: &[usize],
691 k: usize,
692) -> Vec<ConceptGlob> {
693 let mut globs = Vec::with_capacity(k);
694
695 for cluster_id in 0..k {
696 let member_indices: Vec<usize> = assignments
697 .iter()
698 .enumerate()
699 .filter(|&(_, &a)| a == cluster_id)
700 .map(|(i, _)| i)
701 .collect();
702
703 if member_indices.is_empty() {
704 continue;
705 }
706
707 let mut centroid = [0.0; 3];
709 for &i in &member_indices {
710 let norm_p = normalize3(&points[i]);
711 for (d, c) in centroid.iter_mut().enumerate() {
712 *c += norm_p[d];
713 }
714 }
715 centroid = normalize3(¢roid);
716
717 let member_distances: Vec<f64> = member_indices
719 .iter()
720 .map(|&i| angular_dist3(&points[i], ¢roid))
721 .collect();
722
723 let radius = member_distances.iter().cloned().fold(0.0f64, f64::max);
724
725 let member_ids: Vec<String> = member_indices.iter().map(|&i| ids[i].clone()).collect();
726
727 globs.push(ConceptGlob {
728 id: cluster_id,
729 centroid,
730 member_ids,
731 member_distances,
732 radius,
733 });
734 }
735
736 globs
737}
738
/// Euclidean distance between two points in 3-D space.
fn dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y) * (x - y))
        .sum::<f64>()
        .sqrt()
}
745
/// Angle in radians (in `[0, π]`) between `a` and `b`, viewed as vectors from
/// the origin.
///
/// Returns `0.0` when the product of the two norms is below `f64::EPSILON`
/// (i.e. either vector is numerically zero).
fn angular_dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let dot = |u: &[f64; 3], w: &[f64; 3]| u[0] * w[0] + u[1] * w[1] + u[2] * w[2];
    let norm_product = dot(a, a).sqrt() * dot(b, b).sqrt();
    if norm_product < f64::EPSILON {
        0.0
    } else {
        // Clamp guards acos against rounding just outside [-1, 1].
        let cosine = (dot(a, b) / norm_product).clamp(-1.0, 1.0);
        cosine.acos()
    }
}
758
/// Scales `v` to unit length.
///
/// Vectors shorter than `f64::EPSILON` map to the zero vector rather than
/// dividing by (near) zero.
fn normalize3(v: &[f64; 3]) -> [f64; 3] {
    let [x, y, z] = *v;
    let length = (x * x + y * y + z * z).sqrt();
    if length < f64::EPSILON {
        [0.0; 3]
    } else {
        [x / length, y / length, z / length]
    }
}
767
768pub struct SemanticQuery;
770
771impl SemanticQuery {
772 pub fn within_angle<P: Projection>(
774 query: &Embedding,
775 projection: &P,
776 max_angular_distance: f64,
777 ) -> Region {
778 let point = projection.project(query);
779 let half_angle = max_angular_distance.clamp(1e-10, std::f64::consts::PI);
780 Region::Cap(
781 Cap::new(
782 SphericalPoint::new_unchecked(1.0, point.theta, point.phi),
783 half_angle,
784 )
785 .unwrap(),
786 )
787 }
788
789 pub fn above_similarity<P: Projection>(
792 query: &Embedding,
793 projection: &P,
794 min_similarity: f64,
795 ) -> Region {
796 let half_angle = min_similarity.clamp(-1.0, 1.0).acos();
797 Self::within_angle(query, projection, half_angle)
798 }
799
800 pub fn in_shell(inner: f64, outer: f64) -> Region {
802 Region::Shell(Shell::new(inner, outer).expect("invalid shell bounds"))
803 }
804
805 pub fn similar_in_shell<P: Projection>(
808 query: &Embedding,
809 projection: &P,
810 min_similarity: f64,
811 shell_inner: f64,
812 shell_outer: f64,
813 ) -> Region {
814 Region::intersection(vec![
815 Self::above_similarity(query, projection, min_similarity),
816 Self::in_shell(shell_inner, shell_outer),
817 ])
818 }
819}
820
#[cfg(test)]
mod tests {
    use super::*;
    use crate::projection::{PcaProjection, RandomProjection};
    use crate::types::RadialStrategy;
    use sphereql_core::angular_distance;

    /// Wraps a slice as an [`Embedding`].
    fn emb(vals: &[f64]) -> Embedding {
        Embedding::new(vals.to_vec())
    }

    /// Six small 5-D embeddings pointing in distinct directions.
    fn test_corpus() -> Vec<Embedding> {
        vec![
            emb(&[1.0, 0.0, 0.0, 0.1, 0.0]),
            emb(&[0.0, 1.0, 0.0, 0.0, 0.1]),
            emb(&[0.0, 0.0, 1.0, 0.1, 0.0]),
            emb(&[1.0, 1.0, 0.0, 0.05, 0.05]),
            emb(&[-1.0, 0.0, 0.0, -0.1, 0.0]),
            emb(&[0.0, -1.0, 0.0, 0.0, -0.1]),
        ]
    }

    /// Index over `test_corpus()` with ids `item-0` ..= `item-5`.
    fn corpus_index() -> EmbeddingIndex<PcaProjection> {
        let corpus = test_corpus();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
        let mut idx = EmbeddingIndex::builder(pca).build();
        for (i, e) in corpus.iter().enumerate() {
            idx.insert(format!("item-{i}"), e);
        }
        idx
    }

    #[test]
    fn insert_and_get() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        idx.insert("a", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
        idx.insert("b", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));

        assert_eq!(idx.len(), 2);
        assert!(!idx.is_empty());
        assert!(idx.get("a").is_some());
        assert!(idx.get("b").is_some());
        assert!(idx.get("c").is_none());
    }

    #[test]
    fn remove() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp).build();

        idx.insert("a", &emb(&[1.0; 5]));
        assert_eq!(idx.len(), 1);

        let removed = idx.remove("a");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().id, "a");
        assert_eq!(idx.len(), 0);
        assert!(idx.get("a").is_none());
    }

    #[test]
    fn remove_nonexistent() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp).build();
        assert!(idx.remove("nope").is_none());
    }

    #[test]
    fn search_nearest_returns_sorted() {
        let corpus = test_corpus();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
        let mut idx = EmbeddingIndex::builder(pca)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        for (i, e) in corpus.iter().enumerate() {
            idx.insert(format!("item-{i}"), e);
        }

        let query = emb(&[0.95, 0.1, 0.0, 0.05, 0.0]);
        let results = idx.search_nearest(&query, 3);

        assert_eq!(results.len(), 3);
        assert!(results[0].distance <= results[1].distance);
        assert!(results[1].distance <= results[2].distance);
    }

    #[test]
    fn search_similar_respects_threshold() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        idx.insert("close_a", &emb(&[1.0, 0.1, 0.0, 0.0, 0.0]));
        idx.insert("close_b", &emb(&[0.9, 0.2, 0.0, 0.0, 0.0]));
        idx.insert("far", &emb(&[-1.0, 0.0, 0.0, 0.0, 0.0]));

        let query = emb(&[1.0, 0.0, 0.0, 0.0, 0.0]);
        let projected_query = idx.projection().project(&query);
        let result = idx.search_similar(&query, 0.5);

        let max_angle = 0.5_f64.acos();
        for item in &result.items {
            let d = angular_distance(&projected_query, item.position());
            assert!(d <= max_angle + 1e-10, "item {} too far: {d}", item.id);
        }
    }

    #[test]
    fn insert_with_radius_overrides() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp).build();

        idx.insert_with_radius("custom", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]), 42.0);
        let item = idx.get("custom").unwrap();
        assert!((item.position.r - 42.0).abs() < 1e-12);
    }

    #[test]
    fn original_magnitude_stored() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp).build();

        let e = emb(&[3.0, 4.0, 0.0, 0.0, 0.0]);
        idx.insert("vec", &e);
        let item = idx.get("vec").unwrap();
        assert!((item.original_magnitude - 5.0).abs() < 1e-10);
    }

    #[test]
    fn magnitude_radial_with_shell_query() {
        let corpus = test_corpus();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Magnitude);
        let mut idx = EmbeddingIndex::builder(pca)
            .uniform_shells(5, 10.0)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        idx.insert("small", &emb(&[0.1, 0.0, 0.0, 0.0, 0.0]));
        idx.insert("medium", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
        idx.insert("large", &emb(&[5.0, 0.0, 0.0, 0.0, 0.0]));

        let shell = Shell::new(0.5, 2.0).unwrap();
        let result = idx.search_region(&Region::Shell(shell));

        let ids: Vec<&str> = result.items.iter().map(|i| i.id.as_str()).collect();
        assert!(
            ids.contains(&"medium"),
            "medium (mag=1.0) should be in [0.5, 2.0]"
        );
        assert!(
            !ids.contains(&"large"),
            "large (mag=5.0) should not be in [0.5, 2.0]"
        );
    }

    #[test]
    fn empty_index() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let idx = EmbeddingIndex::builder(rp).build();

        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert!(idx.get("x").is_none());

        let results = idx.search_nearest(&emb(&[1.0; 5]), 5);
        assert!(results.is_empty());
    }

    #[test]
    fn projection_accessor() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let idx = EmbeddingIndex::builder(rp).build();
        assert_eq!(idx.projection().dimensionality(), 5);
    }

    #[test]
    fn above_similarity_creates_cap() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let region = SemanticQuery::above_similarity(&emb(&[1.0; 5]), &rp, 0.8);
        assert!(matches!(region, Region::Cap(_)));
    }

    #[test]
    fn within_angle_creates_cap() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let region = SemanticQuery::within_angle(&emb(&[1.0; 5]), &rp, 0.5);
        assert!(matches!(region, Region::Cap(_)));
    }

    #[test]
    fn in_shell_creates_shell() {
        let region = SemanticQuery::in_shell(1.0, 5.0);
        assert!(matches!(region, Region::Shell(_)));
    }

    #[test]
    fn similar_in_shell_creates_intersection() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let region = SemanticQuery::similar_in_shell(&emb(&[1.0; 5]), &rp, 0.7, 1.0, 5.0);

        match region {
            Region::Intersection(parts) => {
                assert_eq!(parts.len(), 2);
                assert!(matches!(parts[0], Region::Cap(_)));
                assert!(matches!(parts[1], Region::Shell(_)));
            }
            other => panic!("expected Intersection, got {other:?}"),
        }
    }

    #[test]
    fn semantic_query_region_used_in_index() {
        let corpus = test_corpus();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
        let projection_clone = pca.clone();
        let mut idx = EmbeddingIndex::builder(pca)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        for (i, e) in corpus.iter().enumerate() {
            idx.insert(format!("item-{i}"), e);
        }

        let query = emb(&[1.0, 0.0, 0.0, 0.05, 0.0]);
        let region = SemanticQuery::above_similarity(&query, &projection_clone, 0.5);
        let result = idx.search_region(&region);

        for item in &result.items {
            assert!(region.contains(item.position()));
        }
    }

    #[test]
    fn concept_path_same_source_and_target() {
        let idx = corpus_index();
        let path = idx
            .concept_path("item-0", "item-0", 2)
            .expect("a node always reaches itself");
        assert_eq!(path.steps.len(), 1);
        assert_eq!(path.steps[0].id, "item-0");
        assert_eq!(path.total_distance, 0.0);
    }

    #[test]
    fn concept_path_unknown_id_is_none() {
        let idx = corpus_index();
        assert!(idx.concept_path("item-0", "missing", 2).is_none());
        assert!(idx.concept_path("missing", "item-0", 2).is_none());
    }

    #[test]
    fn concept_path_endpoints_and_monotone_distance() {
        let idx = corpus_index();
        // k = 5 on six items links every item to every other one.
        let path = idx
            .concept_path("item-0", "item-4", 5)
            .expect("complete neighbour graph is connected");
        assert_eq!(path.steps.first().unwrap().id, "item-0");
        assert_eq!(path.steps.last().unwrap().id, "item-4");
        for pair in path.steps.windows(2) {
            assert!(pair[0].cumulative_distance <= pair[1].cumulative_distance);
        }
        let last = path.steps.last().unwrap().cumulative_distance;
        assert!((path.total_distance - last).abs() < 1e-12);
    }

    #[test]
    fn dijkstra_entries_pop_in_ascending_distance() {
        let mut heap = std::collections::BinaryHeap::new();
        heap.push(DijkstraEntry { dist: 2.0, node: 0 });
        heap.push(DijkstraEntry { dist: 0.5, node: 1 });
        heap.push(DijkstraEntry { dist: 1.0, node: 2 });
        let order: Vec<usize> = std::iter::from_fn(|| heap.pop()).map(|e| e.node).collect();
        assert_eq!(order, vec![1, 2, 0]);
    }

    #[test]
    fn slicing_manifold_recovers_plane_normal() {
        let points = [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [-1.0, 0.0, 0.0],
            [0.0, -1.0, 0.0],
            [0.5, 0.5, 0.0],
        ];
        let m = SlicingManifold::fit(&points);
        assert!((m.normal[2].abs() - 1.0).abs() < 1e-12);
        assert!((m.variance_ratio - 1.0).abs() < 1e-12);
        assert!((m.distance(&[0.0, 0.0, 2.0]).abs() - 2.0).abs() < 1e-12);
        let (u, v) = m.project_2d(&m.centroid);
        assert!(u.abs() < 1e-12 && v.abs() < 1e-12);
    }

    #[test]
    fn glob_detect_separates_opposite_groups() {
        let points = [
            [1.0, 0.1, 0.0],
            [1.0, -0.1, 0.0],
            [1.0, 0.0, 0.1],
            [-1.0, 0.1, 0.0],
            [-1.0, -0.1, 0.0],
            [-1.0, 0.0, 0.1],
        ];
        let ids: Vec<String> = (0..points.len()).map(|i| format!("p{i}")).collect();
        let result = GlobResult::detect(&points, &ids, Some(2), 4);
        assert_eq!(result.k, 2);
        assert_eq!(result.globs.len(), 2);
        for glob in &result.globs {
            assert_eq!(glob.member_ids.len(), 3);
            assert_eq!(glob.member_ids.len(), glob.member_distances.len());
        }
        assert!(result.silhouette > 0.5);
    }
}